Skip to content

Commit

Permalink
deprecate sax.nn
Browse files Browse the repository at this point in the history
  • Loading branch information
flaport committed Jun 16, 2024
1 parent 15be025 commit 07b71e4
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 365 deletions.
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ dev = [

[tool.setuptools.packages.find]
where = ["."]
include = ["sax", "sax.nn", "sax.backends"]
include = ["sax", "sax.backends"]
exclude = []
namespaces = true

Expand All @@ -86,6 +86,3 @@ reportPrivateImportUsage = false

[tool.flake8]
ignore = ['F722']

[tool.coverage.run]
omit = "sax/nn/*"
7 changes: 7 additions & 0 deletions sax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from . import utils as utils
from .circuit import circuit as circuit
from .circuit import get_required_circuit_models as get_required_circuit_models
from .loss import huber_loss as huber_loss
from .loss import l2_reg as l2_reg
from .loss import mse as mse
from .models import get_models as get_models
from .models import passthru as passthru
from .multimode import multimode as multimode
Expand Down Expand Up @@ -70,8 +73,10 @@
from .saxtypes import sdict as sdict
from .saxtypes import validate_model as validate_model
from .utils import block_diag as block_diag
from .utils import cartesian_product as cartesian_product
from .utils import clean_string as clean_string
from .utils import copy_settings as copy_settings
from .utils import denormalize as denormalize
from .utils import flatten_dict as flatten_dict
from .utils import get_inputs_outputs as get_inputs_outputs
from .utils import get_port_combinations as get_port_combinations
Expand All @@ -80,6 +85,8 @@
from .utils import grouped_interp as grouped_interp
from .utils import merge_dicts as merge_dicts
from .utils import mode_combinations as mode_combinations
from .utils import normalization as normalization
from .utils import normalize as normalize
from .utils import reciprocal as reciprocal
from .utils import rename_params as rename_params
from .utils import rename_ports as rename_ports
Expand Down
9 changes: 4 additions & 5 deletions sax/nn/loss.py → sax/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import jax.numpy as jnp

from ..saxtypes import ComplexArrayND
from .saxtypes import ComplexArrayND


def mse(x: ComplexArrayND, y: ComplexArrayND) -> float:
Expand All @@ -16,9 +16,8 @@ def mse(x: ComplexArrayND, y: ComplexArrayND) -> float:

def huber_loss(x: ComplexArrayND, y: ComplexArrayND, delta: float = 0.5) -> float:
"""huber loss"""
return cast(
float, ((delta**2) * ((1.0 + (abs(x - y) / delta) ** 2) ** 0.5 - 1.0)).mean()
)
loss = ((delta**2) * ((1.0 + (abs(x - y) / delta) ** 2) ** 0.5 - 1.0)).mean()
return cast(float, loss)


def l2_reg(weights: Dict[str, ComplexArrayND]) -> float:
Expand All @@ -28,4 +27,4 @@ def l2_reg(weights: Dict[str, ComplexArrayND]) -> float:
for w in (v for k, v in weights.items() if k[0] in ("w", "b")):
numel = numel + w.size
loss = loss + (jnp.abs(w) ** 2).sum()
return loss / numel
return cast(float, loss / numel)
21 changes: 0 additions & 21 deletions sax/nn/__init__.py

This file was deleted.

81 changes: 0 additions & 81 deletions sax/nn/core.py

This file was deleted.

197 changes: 0 additions & 197 deletions sax/nn/io.py

This file was deleted.

Loading

0 comments on commit 07b71e4

Please sign in to comment.