From 84947c93c5661b245b4bd3ea9a9b018979e7de2a Mon Sep 17 00:00:00 2001 From: flaport Date: Sun, 16 Jun 2024 20:47:38 +0200 Subject: [PATCH] fix pydantic v2 transition --- CHANGELOG.md | 6 +- pyproject.toml | 1 + sax/circuit.py | 36 ++++++------ sax/netlist.py | 130 ++++++++++++++++++++++-------------------- tests/test_netlist.py | 15 +++++ 5 files changed, 106 insertions(+), 82 deletions(-) create mode 100644 tests/test_netlist.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 93e1207..41f1da9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,10 @@ # [Changelog](https://keepachangelog.com/en/1.0.0/) -## [Unreleased](https://github.com/flaport/sax/compare/0.10.2...main) +## [0.13.0](https://github.com/flaport/sax/compare/0.12.2...0.13.0) + +- Deprecate `sax.nn`. +- Remove support for pydantic v1. +- Add support for gdsfactory 8 netlists (i.e. `nets` in stead of `connections`) ## [0.10.2](https://github.com/flaport/sax/compare/0.10.1...0.10.2) diff --git a/pyproject.toml b/pyproject.toml index 21fca9b..b820148 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,7 @@ dependencies = [ "klujax>=0.2.5", "matplotlib", "natsort", + "typing-extensions>=4.10.0", "networkx", "numpy", "orjson", diff --git a/sax/circuit.py b/sax/circuit.py index f080528..14379c8 100644 --- a/sax/circuit.py +++ b/sax/circuit.py @@ -34,7 +34,7 @@ class CircuitInfo(NamedTuple): """Information about the circuit function you created.""" - dag: nx.Graph + dag: nx.DiGraph[str] models: Dict[str, Model] @@ -60,9 +60,9 @@ def circuit( # TODO: do the following two steps *after* recursive netlist parsing. netlist = remove_unused_instances(netlist) netlist, instance_models = _extract_instance_models(netlist) + recnet: RecursiveNetlist = _validate_net(netlist) + dependency_dag: nx.DiGraph[str] = _validate_dag(_create_dag(recnet, models)) - recnet: RecursiveNetlist = _validate_net(netlist) # type: ignore - dependency_dag: nx.Graph = _validate_dag(_create_dag(recnet, models)) models = _validate_models({**(models or {}), **instance_models}, dependency_dag) backend = _validate_circuit_backend(backend) @@ -75,7 +75,7 @@ def circuit( new_models[model_name] = models[model_name] continue - flatnet = recnet.__root__[model_name] + flatnet = recnet.root[model_name] current_models |= new_models new_models = {} @@ -96,7 +96,7 @@ def circuit( def _create_dag( netlist: RecursiveNetlist, models: Optional[Dict[str, Any]] = None, -): +) -> nx.DiGraph[str]: if models is None: models = {} assert isinstance(models, dict) @@ -104,7 +104,7 @@ def _create_dag( all_models = {} g = nx.DiGraph() - for model_name, subnetlist in netlist.model_dump()["__root__"].items(): + for model_name, subnetlist in netlist.model_dump().items(): if model_name not in all_models: all_models[model_name] = models.get(model_name, subnetlist) g.add_node(model_name) @@ -122,11 +122,11 @@ def _create_dag( g.add_edge(model_name, component) # we only need the nodes that depend on the parent... - parent_node = next(iter(netlist.__root__.keys())) + parent_node = next(iter(netlist.root.keys())) nodes = [parent_node, *nx.descendants(g, parent_node)] g = nx.induced_subgraph(g, nodes) - return g + return g # type: ignore def draw_dag(dag, with_labels=True, **kwargs): @@ -331,11 +331,9 @@ def _enforce_return_type(model, return_type): return stype_func(model) -def _ensure_recursive_netlist_dict(netlist): +def _ensure_recursive_netlist_dict(netlist: Any) -> RecursiveNetlistDict: if not isinstance(netlist, dict): netlist = netlist.model_dump() - if "__root__" in netlist: - netlist = netlist["__root__"] if "instances" in netlist: netlist = {"top_level": netlist} netlist = {**netlist} @@ -379,14 +377,16 @@ def _validate_circuit_backend(backend): return backend -def _validate_net(netlist: Union[Netlist, RecursiveNetlist]) -> RecursiveNetlist: +def _validate_net( + netlist: Union[Netlist, RecursiveNetlist, NetlistDict, RecursiveNetlistDict] +) -> RecursiveNetlist: if isinstance(netlist, dict): try: - netlist = Netlist.parse_obj(netlist) + netlist = Netlist.model_validate(netlist) except ValidationError: - netlist = RecursiveNetlist.parse_obj(netlist) - elif isinstance(netlist, Netlist): - netlist = RecursiveNetlist(__root__={"top_level": netlist}) + netlist = RecursiveNetlist.model_validate(netlist) + if isinstance(netlist, Netlist): + netlist = RecursiveNetlist(root={"top_level": netlist}) return netlist @@ -419,13 +419,13 @@ def get_required_circuit_models( # TODO: do the following two steps *after* recursive netlist parsing. netlist = remove_unused_instances(netlist) netlist, _ = _extract_instance_models(netlist) - recnet: RecursiveNetlist = _validate_net(netlist) # type: ignore + recnet: RecursiveNetlist = _validate_net(netlist) missing_models = {} missing_model_names = [] g = nx.DiGraph() - for model_name, subnetlist in recnet.model_dump()["__root__"].items(): + for model_name, subnetlist in recnet.model_dump().items(): if model_name not in missing_models: missing_models[model_name] = models.get(model_name, subnetlist) g.add_node(model_name) diff --git a/sax/netlist.py b/sax/netlist.py index dea4562..9c4ba75 100644 --- a/sax/netlist.py +++ b/sax/netlist.py @@ -8,19 +8,28 @@ from copy import deepcopy from enum import Enum from functools import lru_cache -from typing import Any, Dict, Optional, TypedDict, Union +from typing import Any, TypedDict import black import networkx as nx import numpy as np import yaml from natsort import natsorted -from pydantic import BaseModel, ConfigDict, Field, ValidationError, validator +from pydantic import ( + BaseModel, + BeforeValidator, + ConfigDict, + Field, + RootModel, + ValidationError, + field_validator, +) +from typing_extensions import Annotated from .utils import clean_string, hash_dict -def netlist(dic: Dict) -> RecursiveNetlist: +def netlist(dic: dict) -> RecursiveNetlist: """return a netlist from a given dictionary""" if isinstance(dic, RecursiveNetlist): return dic @@ -38,7 +47,7 @@ class _BaseModel(BaseModel): # type: ignore model_config = ConfigDict( extra="ignore", frozen=True, - json_encoders = {np.ndarray: lambda arr: np.round(arr, 12).tolist()}, + json_encoders={np.ndarray: lambda arr: np.round(arr, 12).tolist()}, ) def __repr__(self): @@ -54,11 +63,10 @@ def __hash__(self): class Component(_BaseModel): - component: Union[str, Dict[str, Any]] = Field(..., title="Component") - settings: Optional[Dict[str, Any]] = Field(None, title="Settings") - info: Optional[Dict[str, Any]] = Field(None, title="Info") + component: str + settings: dict[str, Any] = Field(default_factory=dict) - @validator("component") + @field_validator("component") def validate_component_name(cls, value): if "," in value: raise ValueError( @@ -82,53 +90,33 @@ class PortEnum(Enum): class Placement(_BaseModel): - x: Optional[Union[str, float]] = Field(0, title="X") - y: Optional[Union[str, float]] = Field(0, title="Y") - xmin: Optional[Union[str, float]] = Field(None, title="Xmin") - ymin: Optional[Union[str, float]] = Field(None, title="Ymin") - xmax: Optional[Union[str, float]] = Field(None, title="Xmax") - ymax: Optional[Union[str, float]] = Field(None, title="Ymax") - dx: Optional[float] = Field(0, title="Dx") - dy: Optional[float] = Field(0, title="Dy") - port: Optional[Union[str, PortEnum]] = Field(None, title="Port") - rotation: Optional[int] = Field(0, title="Rotation") - mirror: Optional[bool] = Field(False, title="Mirror") + x: str | float = 0.0 + y: str | float = 0.0 + dx: str | float = 0.0 + dy: str | float = 0.0 + rotation: float = 0.0 + mirror: bool = False + xmin: str | float | None = None + xmax: str | float | None = None + ymin: str | float | None = None + ymax: str | float | None = None + port: str | PortEnum | None = None -class Route(_BaseModel): - links: Dict[str, str] = Field(..., title="Links", default_factory=dict) - settings: Optional[Dict[str, Any]] = Field(None, title="Settings", default_factory=dict) - routing_strategy: Optional[str] = Field(None, title="Routing Strategy", default_factory=dict) +def _str_to_component(s: Any) -> Component: + if isinstance(s, str): + return Component(component=s) + return Component.model_validate(s) + + +CoercingComponent = Annotated[Component | str, BeforeValidator(_str_to_component)] class Netlist(_BaseModel): - instances: Dict[str, Component] = Field(..., title="Instances", default_factory=dict) - connections: Optional[Dict[str, str]] = Field(None, title="Connections", default_factory=dict) - ports: Optional[Dict[str, str]] = Field(None, title="Ports", default_factory=dict) - placements: Optional[Dict[str, Placement]] = Field(None, title="Placements", default_factory=dict) - - # these were removed (irrelevant for SAX): - - # routes: Optional[Dict[str, Route]] = Field(None, title='Routes') - # name: Optional[str] = Field(None, title='Name') - # info: Optional[Dict[str, Any]] = Field(None, title='Info') - # settings: Optional[Dict[str, Any]] = Field(None, title='Settings') - # pdk: Optional[str] = Field(None, title='Pdk') - - # these are extra additions: - - @validator("instances", pre=True) - def coerce_different_type_instance_into_component_model(cls, instances): - new_instances = {} - for k, v in instances.items(): - if isinstance(v, str): - v = { - "component": v, - "settings": {}, - } - new_instances[k] = v - - return new_instances + instances: dict[str, CoercingComponent] = Field(default_factory=dict) + connections: dict[str, str] = Field(default_factory=dict) + ports: dict[str, str] = Field(default_factory=dict) + placements: dict[str, Placement] = Field(default_factory=dict) @staticmethod def clean_instance_string(value): @@ -138,11 +126,11 @@ def clean_instance_string(value): ) return clean_string(value) - @validator("instances") + @field_validator("instances") def validate_instance_names(cls, instances): return {cls.clean_instance_string(k): v for k, v in instances.items()} - @validator("placements") + @field_validator("placements") def validate_placement_names(cls, placements): if placements is not None: return {cls.clean_instance_string(k): v for k, v in placements.items()} @@ -154,14 +142,14 @@ def clean_connection_string(cls, value): comp = cls.clean_instance_string(",".join(comp)) return f"{comp},{port}" - @validator("connections") + @field_validator("connections") def validate_connection_names(cls, connections): return { cls.clean_connection_string(k): cls.clean_connection_string(v) for k, v in connections.items() } - @validator("ports") + @field_validator("ports") def validate_port_names(cls, ports): return { cls.clean_instance_string(k): cls.clean_connection_string(v) @@ -169,17 +157,33 @@ def validate_port_names(cls, ports): } -class RecursiveNetlist(_BaseModel): - __root__: Dict[str, Netlist] +class RecursiveNetlist(RootModel): + root: dict[str, Netlist] + + model_config = ConfigDict( + frozen=True, + json_encoders={np.ndarray: lambda arr: np.round(arr, 12).tolist()}, + ) + + def __repr__(self): + s = super().__repr__() + s = black.format_str(s, mode=black.Mode()) + return s + + def __str__(self): + return self.__repr__() + + def __hash__(self): + return hash_dict(self.model_dump()) class NetlistDict(TypedDict): - instances: Dict - connections: Dict[str, str] - ports: Dict[str, str] + instances: dict + connections: dict[str, str] + ports: dict[str, str] -RecursiveNetlistDict = Dict[str, NetlistDict] +RecursiveNetlistDict = dict[str, NetlistDict] @lru_cache() @@ -197,7 +201,7 @@ def _clean_string(path: str) -> str: return clean_string(re.sub(ext, "", os.path.split(path)[-1])) # the circuit we're interested in should come first: - netlists: Dict[str, Netlist] = {_clean_string(pic_path): Netlist()} + netlists: dict[str, Netlist] = {_clean_string(pic_path): Netlist()} for filename in os.listdir(folder_path): path = os.path.join(folder_path, filename) @@ -222,7 +226,7 @@ def get_netlist_instances_by_prefix( Returns: A list of all instances with the given prefix. """ - recursive_netlist_root = recursive_netlist.model_dump()["__root__"] + recursive_netlist_root = recursive_netlist.model_dump() result = [] for key in recursive_netlist_root.keys(): if key.startswith(prefix): @@ -247,7 +251,7 @@ def get_component_instances( A dictionary of all instances of the given component. """ instance_names = [] - recursive_netlist_root = recursive_netlist.model_dump()["__root__"] + recursive_netlist_root = recursive_netlist.model_dump() # Should only be one in a netlist-to-digraph. Can always be very specified. top_level_prefixes = get_netlist_instances_by_prefix( diff --git a/tests/test_netlist.py b/tests/test_netlist.py new file mode 100644 index 0000000..2ae3f71 --- /dev/null +++ b/tests/test_netlist.py @@ -0,0 +1,15 @@ +from sax.netlist import Netlist, RecursiveNetlist + + +def test_empty_netlist(): + assert Netlist() + + +def test_coercing_instances(): + assert Netlist.model_validate({"instances": {"mmi": "mmi"}}) + + +def test_recursive_netlist(): + net = Netlist.model_validate({"instances": {"mmi": "mmi"}}) + recnet = RecursiveNetlist.model_validate({"net": net}) + assert recnet