Skip to content

Commit

Permalink
fix pydantic v2 transition
Browse files Browse the repository at this point in the history
  • Loading branch information
flaport committed Jun 16, 2024
1 parent 07b71e4 commit 84947c9
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 82 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# [Changelog](https://keepachangelog.com/en/1.0.0/)

## [Unreleased](https://github.com/flaport/sax/compare/0.10.2...main)
## [0.13.0](https://github.com/flaport/sax/compare/0.12.2...0.13.0)

- Deprecate `sax.nn`.
- Remove support for pydantic v1.
- Add support for gdsfactory 8 netlists (i.e. `nets` in stead of `connections`)

## [0.10.2](https://github.com/flaport/sax/compare/0.10.1...0.10.2)

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ dependencies = [
"klujax>=0.2.5",
"matplotlib",
"natsort",
"typing-extensions>=4.10.0",
"networkx",
"numpy",
"orjson",
Expand Down
36 changes: 18 additions & 18 deletions sax/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
class CircuitInfo(NamedTuple):
"""Information about the circuit function you created."""

dag: nx.Graph
dag: nx.DiGraph[str]
models: Dict[str, Model]


Expand All @@ -60,9 +60,9 @@ def circuit(
# TODO: do the following two steps *after* recursive netlist parsing.
netlist = remove_unused_instances(netlist)
netlist, instance_models = _extract_instance_models(netlist)
recnet: RecursiveNetlist = _validate_net(netlist)
dependency_dag: nx.DiGraph[str] = _validate_dag(_create_dag(recnet, models))

recnet: RecursiveNetlist = _validate_net(netlist) # type: ignore
dependency_dag: nx.Graph = _validate_dag(_create_dag(recnet, models))
models = _validate_models({**(models or {}), **instance_models}, dependency_dag)
backend = _validate_circuit_backend(backend)

Expand All @@ -75,7 +75,7 @@ def circuit(
new_models[model_name] = models[model_name]
continue

flatnet = recnet.__root__[model_name]
flatnet = recnet.root[model_name]
current_models |= new_models
new_models = {}

Expand All @@ -96,15 +96,15 @@ def circuit(
def _create_dag(
netlist: RecursiveNetlist,
models: Optional[Dict[str, Any]] = None,
):
) -> nx.DiGraph[str]:
if models is None:
models = {}
assert isinstance(models, dict)

all_models = {}
g = nx.DiGraph()

for model_name, subnetlist in netlist.model_dump()["__root__"].items():
for model_name, subnetlist in netlist.model_dump().items():
if model_name not in all_models:
all_models[model_name] = models.get(model_name, subnetlist)
g.add_node(model_name)
Expand All @@ -122,11 +122,11 @@ def _create_dag(
g.add_edge(model_name, component)

# we only need the nodes that depend on the parent...
parent_node = next(iter(netlist.__root__.keys()))
parent_node = next(iter(netlist.root.keys()))
nodes = [parent_node, *nx.descendants(g, parent_node)]
g = nx.induced_subgraph(g, nodes)

return g
return g # type: ignore


def draw_dag(dag, with_labels=True, **kwargs):
Expand Down Expand Up @@ -331,11 +331,9 @@ def _enforce_return_type(model, return_type):
return stype_func(model)


def _ensure_recursive_netlist_dict(netlist):
def _ensure_recursive_netlist_dict(netlist: Any) -> RecursiveNetlistDict:
if not isinstance(netlist, dict):
netlist = netlist.model_dump()
if "__root__" in netlist:
netlist = netlist["__root__"]
if "instances" in netlist:
netlist = {"top_level": netlist}
netlist = {**netlist}
Expand Down Expand Up @@ -379,14 +377,16 @@ def _validate_circuit_backend(backend):
return backend


def _validate_net(netlist: Union[Netlist, RecursiveNetlist]) -> RecursiveNetlist:
def _validate_net(
netlist: Union[Netlist, RecursiveNetlist, NetlistDict, RecursiveNetlistDict]
) -> RecursiveNetlist:
if isinstance(netlist, dict):
try:
netlist = Netlist.parse_obj(netlist)
netlist = Netlist.model_validate(netlist)
except ValidationError:
netlist = RecursiveNetlist.parse_obj(netlist)
elif isinstance(netlist, Netlist):
netlist = RecursiveNetlist(__root__={"top_level": netlist})
netlist = RecursiveNetlist.model_validate(netlist)
if isinstance(netlist, Netlist):
netlist = RecursiveNetlist(root={"top_level": netlist})
return netlist


Expand Down Expand Up @@ -419,13 +419,13 @@ def get_required_circuit_models(
# TODO: do the following two steps *after* recursive netlist parsing.
netlist = remove_unused_instances(netlist)
netlist, _ = _extract_instance_models(netlist)
recnet: RecursiveNetlist = _validate_net(netlist) # type: ignore
recnet: RecursiveNetlist = _validate_net(netlist)

missing_models = {}
missing_model_names = []
g = nx.DiGraph()

for model_name, subnetlist in recnet.model_dump()["__root__"].items():
for model_name, subnetlist in recnet.model_dump().items():
if model_name not in missing_models:
missing_models[model_name] = models.get(model_name, subnetlist)
g.add_node(model_name)
Expand Down
130 changes: 67 additions & 63 deletions sax/netlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,28 @@
from copy import deepcopy
from enum import Enum
from functools import lru_cache
from typing import Any, Dict, Optional, TypedDict, Union
from typing import Any, TypedDict

import black
import networkx as nx
import numpy as np
import yaml
from natsort import natsorted
from pydantic import BaseModel, ConfigDict, Field, ValidationError, validator
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
RootModel,
ValidationError,
field_validator,
)
from typing_extensions import Annotated

from .utils import clean_string, hash_dict


def netlist(dic: Dict) -> RecursiveNetlist:
def netlist(dic: dict) -> RecursiveNetlist:
"""return a netlist from a given dictionary"""
if isinstance(dic, RecursiveNetlist):
return dic
Expand All @@ -38,7 +47,7 @@ class _BaseModel(BaseModel): # type: ignore
model_config = ConfigDict(
extra="ignore",
frozen=True,
json_encoders = {np.ndarray: lambda arr: np.round(arr, 12).tolist()},
json_encoders={np.ndarray: lambda arr: np.round(arr, 12).tolist()},
)

def __repr__(self):
Expand All @@ -54,11 +63,10 @@ def __hash__(self):


class Component(_BaseModel):
component: Union[str, Dict[str, Any]] = Field(..., title="Component")
settings: Optional[Dict[str, Any]] = Field(None, title="Settings")
info: Optional[Dict[str, Any]] = Field(None, title="Info")
component: str
settings: dict[str, Any] = Field(default_factory=dict)

@validator("component")
@field_validator("component")
def validate_component_name(cls, value):
if "," in value:
raise ValueError(
Expand All @@ -82,53 +90,33 @@ class PortEnum(Enum):


class Placement(_BaseModel):
x: Optional[Union[str, float]] = Field(0, title="X")
y: Optional[Union[str, float]] = Field(0, title="Y")
xmin: Optional[Union[str, float]] = Field(None, title="Xmin")
ymin: Optional[Union[str, float]] = Field(None, title="Ymin")
xmax: Optional[Union[str, float]] = Field(None, title="Xmax")
ymax: Optional[Union[str, float]] = Field(None, title="Ymax")
dx: Optional[float] = Field(0, title="Dx")
dy: Optional[float] = Field(0, title="Dy")
port: Optional[Union[str, PortEnum]] = Field(None, title="Port")
rotation: Optional[int] = Field(0, title="Rotation")
mirror: Optional[bool] = Field(False, title="Mirror")
x: str | float = 0.0
y: str | float = 0.0
dx: str | float = 0.0
dy: str | float = 0.0
rotation: float = 0.0
mirror: bool = False
xmin: str | float | None = None
xmax: str | float | None = None
ymin: str | float | None = None
ymax: str | float | None = None
port: str | PortEnum | None = None


class Route(_BaseModel):
links: Dict[str, str] = Field(..., title="Links", default_factory=dict)
settings: Optional[Dict[str, Any]] = Field(None, title="Settings", default_factory=dict)
routing_strategy: Optional[str] = Field(None, title="Routing Strategy", default_factory=dict)
def _str_to_component(s: Any) -> Component:
if isinstance(s, str):
return Component(component=s)
return Component.model_validate(s)


CoercingComponent = Annotated[Component | str, BeforeValidator(_str_to_component)]


class Netlist(_BaseModel):
instances: Dict[str, Component] = Field(..., title="Instances", default_factory=dict)
connections: Optional[Dict[str, str]] = Field(None, title="Connections", default_factory=dict)
ports: Optional[Dict[str, str]] = Field(None, title="Ports", default_factory=dict)
placements: Optional[Dict[str, Placement]] = Field(None, title="Placements", default_factory=dict)

# these were removed (irrelevant for SAX):

# routes: Optional[Dict[str, Route]] = Field(None, title='Routes')
# name: Optional[str] = Field(None, title='Name')
# info: Optional[Dict[str, Any]] = Field(None, title='Info')
# settings: Optional[Dict[str, Any]] = Field(None, title='Settings')
# pdk: Optional[str] = Field(None, title='Pdk')

# these are extra additions:

@validator("instances", pre=True)
def coerce_different_type_instance_into_component_model(cls, instances):
new_instances = {}
for k, v in instances.items():
if isinstance(v, str):
v = {
"component": v,
"settings": {},
}
new_instances[k] = v

return new_instances
instances: dict[str, CoercingComponent] = Field(default_factory=dict)
connections: dict[str, str] = Field(default_factory=dict)
ports: dict[str, str] = Field(default_factory=dict)
placements: dict[str, Placement] = Field(default_factory=dict)

@staticmethod
def clean_instance_string(value):
Expand All @@ -138,11 +126,11 @@ def clean_instance_string(value):
)
return clean_string(value)

@validator("instances")
@field_validator("instances")
def validate_instance_names(cls, instances):
return {cls.clean_instance_string(k): v for k, v in instances.items()}

@validator("placements")
@field_validator("placements")
def validate_placement_names(cls, placements):
if placements is not None:
return {cls.clean_instance_string(k): v for k, v in placements.items()}
Expand All @@ -154,32 +142,48 @@ def clean_connection_string(cls, value):
comp = cls.clean_instance_string(",".join(comp))
return f"{comp},{port}"

@validator("connections")
@field_validator("connections")
def validate_connection_names(cls, connections):
return {
cls.clean_connection_string(k): cls.clean_connection_string(v)
for k, v in connections.items()
}

@validator("ports")
@field_validator("ports")
def validate_port_names(cls, ports):
return {
cls.clean_instance_string(k): cls.clean_connection_string(v)
for k, v in ports.items()
}


class RecursiveNetlist(_BaseModel):
__root__: Dict[str, Netlist]
class RecursiveNetlist(RootModel):
root: dict[str, Netlist]

model_config = ConfigDict(
frozen=True,
json_encoders={np.ndarray: lambda arr: np.round(arr, 12).tolist()},
)

def __repr__(self):
s = super().__repr__()
s = black.format_str(s, mode=black.Mode())
return s

def __str__(self):
return self.__repr__()

def __hash__(self):
return hash_dict(self.model_dump())


class NetlistDict(TypedDict):
instances: Dict
connections: Dict[str, str]
ports: Dict[str, str]
instances: dict
connections: dict[str, str]
ports: dict[str, str]


RecursiveNetlistDict = Dict[str, NetlistDict]
RecursiveNetlistDict = dict[str, NetlistDict]


@lru_cache()
Expand All @@ -197,7 +201,7 @@ def _clean_string(path: str) -> str:
return clean_string(re.sub(ext, "", os.path.split(path)[-1]))

# the circuit we're interested in should come first:
netlists: Dict[str, Netlist] = {_clean_string(pic_path): Netlist()}
netlists: dict[str, Netlist] = {_clean_string(pic_path): Netlist()}

for filename in os.listdir(folder_path):
path = os.path.join(folder_path, filename)
Expand All @@ -222,7 +226,7 @@ def get_netlist_instances_by_prefix(
Returns:
A list of all instances with the given prefix.
"""
recursive_netlist_root = recursive_netlist.model_dump()["__root__"]
recursive_netlist_root = recursive_netlist.model_dump()
result = []
for key in recursive_netlist_root.keys():
if key.startswith(prefix):
Expand All @@ -247,7 +251,7 @@ def get_component_instances(
A dictionary of all instances of the given component.
"""
instance_names = []
recursive_netlist_root = recursive_netlist.model_dump()["__root__"]
recursive_netlist_root = recursive_netlist.model_dump()

# Should only be one in a netlist-to-digraph. Can always be very specified.
top_level_prefixes = get_netlist_instances_by_prefix(
Expand Down
15 changes: 15 additions & 0 deletions tests/test_netlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from sax.netlist import Netlist, RecursiveNetlist


def test_empty_netlist():
assert Netlist()


def test_coercing_instances():
assert Netlist.model_validate({"instances": {"mmi": "mmi"}})


def test_recursive_netlist():
net = Netlist.model_validate({"instances": {"mmi": "mmi"}})
recnet = RecursiveNetlist.model_validate({"net": net})
assert recnet

0 comments on commit 84947c9

Please sign in to comment.