diff --git a/pyproject.toml b/pyproject.toml
index edbfe82..21fca9b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -72,7 +72,7 @@ dev = [
 
 [tool.setuptools.packages.find]
 where = ["."]
-include = ["sax", "sax.nn", "sax.backends"]
+include = ["sax", "sax.backends"]
 exclude = []
 namespaces = true
 
@@ -86,6 +86,3 @@ reportPrivateImportUsage = false
 
 [tool.flake8]
 ignore = ['F722']
-
-[tool.coverage.run]
-omit = "sax/nn/*"
diff --git a/sax/__init__.py b/sax/__init__.py
index 6c66ddd..cbfab8d 100644
--- a/sax/__init__.py
+++ b/sax/__init__.py
@@ -23,6 +23,9 @@ from . import utils as utils
 
 from .circuit import circuit as circuit
 from .circuit import get_required_circuit_models as get_required_circuit_models
+from .loss import huber_loss as huber_loss
+from .loss import l2_reg as l2_reg
+from .loss import mse as mse
 from .models import get_models as get_models
 from .models import passthru as passthru
 from .multimode import multimode as multimode
@@ -70,8 +73,10 @@ from .saxtypes import sdict as sdict
 from .saxtypes import validate_model as validate_model
 
 from .utils import block_diag as block_diag
+from .utils import cartesian_product as cartesian_product
 from .utils import clean_string as clean_string
 from .utils import copy_settings as copy_settings
+from .utils import denormalize as denormalize
 from .utils import flatten_dict as flatten_dict
 from .utils import get_inputs_outputs as get_inputs_outputs
 from .utils import get_port_combinations as get_port_combinations
@@ -80,6 +85,8 @@ from .utils import grouped_interp as grouped_interp
 from .utils import merge_dicts as merge_dicts
 from .utils import mode_combinations as mode_combinations
+from .utils import normalization as normalization
+from .utils import normalize as normalize
 from .utils import reciprocal as reciprocal
 from .utils import rename_params as rename_params
 from .utils import rename_ports as rename_ports
diff --git a/sax/nn/loss.py b/sax/loss.py
similarity index 77%
rename from sax/nn/loss.py
rename to sax/loss.py
index 53644cc..b166626 100644
--- a/sax/nn/loss.py
+++ b/sax/loss.py
@@ -6,7 +6,7 @@
 
 import jax.numpy as jnp
 
-from ..saxtypes import ComplexArrayND
+from .saxtypes import ComplexArrayND
 
 
 def mse(x: ComplexArrayND, y: ComplexArrayND) -> float:
@@ -16,9 +16,8 @@ def mse(x: ComplexArrayND, y: ComplexArrayND) -> float:
 
 def huber_loss(x: ComplexArrayND, y: ComplexArrayND, delta: float = 0.5) -> float:
     """huber loss"""
-    return cast(
-        float, ((delta**2) * ((1.0 + (abs(x - y) / delta) ** 2) ** 0.5 - 1.0)).mean()
-    )
+    loss = ((delta**2) * ((1.0 + (abs(x - y) / delta) ** 2) ** 0.5 - 1.0)).mean()
+    return cast(float, loss)
 
 
 def l2_reg(weights: Dict[str, ComplexArrayND]) -> float:
@@ -28,4 +27,4 @@ def l2_reg(weights: Dict[str, ComplexArrayND]) -> float:
     for w in (v for k, v in weights.items() if k[0] in ("w", "b")):
         numel = numel + w.size
         loss = loss + (jnp.abs(w) ** 2).sum()
-    return loss / numel
+    return cast(float, loss / numel)
diff --git a/sax/nn/__init__.py b/sax/nn/__init__.py
deleted file mode 100644
index 258f3c0..0000000
--- a/sax/nn/__init__.py
+++ /dev/null
@@ -1,21 +0,0 @@
-""" SAX Neural Network Module """
-
-from __future__ import annotations
-
-from .core import dense as dense
-from .core import generate_dense_weights as generate_dense_weights
-from .core import preprocess as preprocess
-from .io import get_available_sizes as get_available_sizes
-from .io import get_dense_weights_path as get_dense_weights_path
-from .io import get_norm_path as get_norm_path
-from .io import load_nn_dense as load_nn_dense
-from .io import load_nn_weights_json as load_nn_weights_json
-from .io import save_nn_weights_json as save_nn_weights_json
-from .loss import huber_loss as huber_loss
-from .loss import l2_reg as l2_reg
-from .loss import mse as mse
-from .utils import cartesian_product as cartesian_product
-from .utils import denormalize as denormalize
-from .utils import get_df_columns as get_df_columns
-from .utils import get_normalization as get_normalization
-from .utils import normalize as normalize
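
The hunks above move the loss functions from `sax.nn.loss` to `sax.loss` and re-export them from the top-level package, while the old `sax.nn` entry point is removed. A minimal sketch of the new import surface (array values are illustrative):

```python
import jax.numpy as jnp
import sax

pred = jnp.array([0.9, 2.1, 2.8])    # hypothetical model output
target = jnp.array([1.0, 2.0, 3.0])  # hypothetical reference data

print(sax.mse(pred, target))         # mean squared error
print(sax.huber_loss(pred, target))  # smooth (pseudo-)Huber loss, delta=0.5
print(sax.l2_reg({"w0": jnp.ones((3, 3)), "b0": jnp.zeros(3)}))  # weight-decay term
```
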
diff --git a/sax/nn/core.py b/sax/nn/core.py
deleted file mode 100644
index 2243b0a..0000000
--- a/sax/nn/core.py
+++ /dev/null
@@ -1,81 +0,0 @@
-""" SAX Neural Network Core Utils """
-
-from __future__ import annotations
-
-from typing import Callable, Dict, Optional, Tuple, Union
-
-import jax
-import jax.numpy as jnp
-
-from ..saxtypes import Array, ComplexArrayND
-from .utils import denormalize, normalize
-
-
-def preprocess(*params: ComplexArrayND) -> ComplexArrayND:
-    """preprocess parameters
-
-    > Note: (1) all arguments are first casted into the same shape. (2) then pairs
-    of arguments are divided into each other to create relative arguments. (3) all
-    arguments are then stacked into one big tensor
-    """
-    x = jnp.stack(jnp.broadcast_arrays(*params), -1)
-    assert isinstance(x, jnp.ndarray)
-    to_concatenate = [x]
-    for i in range(1, x.shape[-1]):
-        _x = jnp.roll(x, shift=i, axis=-1)
-        to_concatenate.append(x / _x)
-        to_concatenate.append(_x / x)
-    x = jnp.concatenate(to_concatenate, -1)
-    assert isinstance(x, jnp.ndarray)
-    return x
-
-
-def dense(
-    weights: Dict[str, Array],
-    *params: ComplexArrayND,
-    x_norm: Tuple[float, float] = (0.0, 1.0),
-    y_norm: Tuple[float, float] = (0.0, 1.0),
-    preprocess: Callable = preprocess,
-    activation: Callable = jax.nn.leaky_relu,
-) -> ComplexArrayND:
-    """simple dense neural network"""
-    x_mean, x_std = x_norm
-    y_mean, y_std = y_norm
-    x = preprocess(*params)
-    x = normalize(x, mean=x_mean, std=x_std)
-    for i in range(len([w for w in weights if w.startswith("w")])):
-        x = activation(x @ weights[f"w{i}"] + weights.get(f"b{i}", 0.0))
-    y = denormalize(x, mean=y_mean, std=y_std)
-    return y
-
-
-def generate_dense_weights(
-    key: Union[int, Array],
-    sizes: Tuple[int, ...],
-    input_names: Optional[Tuple[str, ...]] = None,
-    output_names: Optional[Tuple[str, ...]] = None,
-    preprocess=preprocess,
-) -> Dict[str, ComplexArrayND]:
-    """Generate the weights for a dense neural network"""
-
-    if isinstance(key, int):
-        random_key = jax.random.PRNGKey(key)
-    else:
-        random_key = key
-
-    sizes = tuple(s for s in sizes)
-    if input_names:
-        arr = preprocess(*jnp.ones(len(input_names)))
-        assert isinstance(arr, jnp.ndarray)
-        sizes = (arr.shape[-1],) + sizes
-    if output_names:
-        sizes = sizes + (len(output_names),)
-
-    keys = jax.random.split(random_key, 2 * len(sizes))
-    rand = jax.nn.initializers.lecun_normal()
-    weights = {}
-    for i, (m, n) in enumerate(zip(sizes[:-1], sizes[1:])):
-        weights[f"w{i}"] = rand(keys[2 * i], (m, n))
-        weights[f"b{i}"] = rand(keys[2 * i + 1], (1, n)).ravel()
-
-    return weights
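
For downstream users: `dense`, `preprocess`, and `generate_dense_weights` are removed without a top-level replacement. For anyone pinning to an older release, here is a sketch of the API that disappears with this change (names and sizes are illustrative):

```python
import jax.numpy as jnp
from sax.nn import dense, generate_dense_weights  # no longer available after this PR

# Two hidden layers of 16 units; input/output layer sizes are derived from the names.
weights = generate_dense_weights(
    key=42,
    sizes=(16, 16),
    input_names=("wl", "t"),
    output_names=("amp", "phi"),
)
wl = jnp.linspace(1.5, 1.6, 100)
t = jnp.full(100, 25.0)
y = dense(weights, wl, t)  # shape (100, 2): one column per output name
```
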
diff --git a/sax/nn/io.py b/sax/nn/io.py
deleted file mode 100644
index 435c187..0000000
--- a/sax/nn/io.py
+++ /dev/null
@@ -1,197 +0,0 @@
-""" SAX Neural Network I/O Utilities """
-
-from __future__ import annotations
-
-import json
-import os
-import re
-from typing import Callable, Dict, List, Optional, Tuple
-
-import jax.numpy as jnp
-
-from ..saxtypes import ComplexArrayND
-from .core import dense, preprocess
-from .utils import norm
-
-
-def load_nn_weights_json(path: str) -> Dict[str, ComplexArrayND]:
-    """Load json weights from given path"""
-    path = os.path.abspath(os.path.expanduser(path))
-    weights = {}
-    if os.path.exists(path):
-        with open(path, "r") as file:
-            for k, v in json.load(file).items():
-                _v = jnp.array(v, dtype=float)
-                assert isinstance(_v, jnp.ndarray)
-                weights[k] = _v
-    return weights
-
-
-def save_nn_weights_json(weights: Dict[str, ComplexArrayND], path: str):
-    """Save json weights to given path"""
-    path = os.path.abspath(os.path.expanduser(path))
-    os.makedirs(os.path.dirname(path), exist_ok=True)
-    with open(path, "w") as file:
-        _weights = {}
-        for k, v in weights.items():
-            v = jnp.atleast_1d(jnp.array(v))
-            assert isinstance(v, jnp.ndarray)
-            _weights[k] = v.tolist()
-        json.dump(_weights, file)
-
-
-def get_available_sizes(
-    dirpath: str,
-    prefix: str,
-    input_names: Tuple[str, ...],
-    output_names: Tuple[str, ...],
-) -> List[Tuple[int, ...]]:
-    """Get all available json weight hidden sizes given filename parameters
-
-    > Note: this function does NOT return the input size and the output size
-    of the neural network. ONLY the hidden sizes are reported. The input
-    and output sizes can easily be derived from `input_names` (after
-    preprocessing) and `output_names`.
-    """
-    all_weightfiles = os.listdir(dirpath)
-    possible_weightfiles = (
-        s for s in all_weightfiles if s.endswith(f"-{'-'.join(output_names)}.json")
-    )
-    possible_weightfiles = (
-        s
-        for s in possible_weightfiles
-        if s.startswith(f"{prefix}-{'-'.join(input_names)}")
-    )
-    possible_weightfiles = (re.sub("[^0-9x]", "", s) for s in possible_weightfiles)
-    possible_weightfiles = (re.sub("^x*", "", s) for s in possible_weightfiles)
-    possible_weightfiles = (re.sub("x[^0-9]*$", "", s) for s in possible_weightfiles)
-    possible_hidden_sizes = (s.strip() for s in possible_weightfiles if s.strip())
-    possible_hidden_sizes = (
-        tuple(hs.strip() for hs in s.split("x") if hs.strip())
-        for s in possible_hidden_sizes
-    )
-    possible_hidden_sizes = (
-        tuple(int(hs) for hs in s[1:-1]) for s in possible_hidden_sizes if len(s) > 2
-    )
-    possible_hidden_sizes = sorted(
-        possible_hidden_sizes, key=lambda hs: (len(hs), max(hs))
-    )
-    return possible_hidden_sizes
-
-
-def get_dense_weights_path(
-    *sizes: int,
-    input_names: Optional[Tuple[str, ...]] = None,
-    output_names: Optional[Tuple[str, ...]] = None,
-    dirpath: str = "weights",
-    prefix: str = "dense",
-    preprocess=preprocess,
-):
-    """Create the SAX conventional path for a given weight dictionary"""
-    if input_names:
-        num_inputs = preprocess(*jnp.ones(len(input_names))).shape[0]
-        sizes = (num_inputs,) + sizes
-    if output_names:
-        sizes = sizes + (len(output_names),)
-    path = os.path.abspath(os.path.join(dirpath, prefix))
-    if input_names:
-        path = f"{path}-{'-'.join(input_names)}"
-    if sizes:
-        path = f"{path}-{'x'.join(str(s) for s in sizes)}"
-    if output_names:
-        path = f"{path}-{'-'.join(output_names)}"
-    return f"{path}.json"
-
-
-def get_norm_path(
-    *shape: int,
-    input_names: Optional[Tuple[str, ...]] = None,
-    output_names: Optional[Tuple[str, ...]] = None,
-    dirpath: str = "norms",
-    prefix: str = "norm",
-    preprocess=preprocess,
-):
-    """Create the SAX conventional path for the normalization constants"""
-    if input_names and output_names:
-        raise ValueError(
-            "To get the norm name, one can only specify "
-            "`input_names` OR `output_names`."
-        )
-    if input_names:
-        num_inputs = preprocess(*jnp.ones(len(input_names))).shape[0]
-        shape = (num_inputs,) + shape
-    if output_names:
-        shape = shape + (len(output_names),)
-    path = os.path.abspath(os.path.join(dirpath, prefix))
-    if input_names:
-        path = f"{path}-{'-'.join(input_names)}"
-    if shape:
-        path = f"{path}-{'x'.join(str(s) for s in shape)}"
-    if output_names:
-        path = f"{path}-{'-'.join(output_names)}"
-    return f"{path}.json"
-
-
-class _PartialDense:
-    def __init__(self, weights, x_norm, y_norm, input_names, output_names):
-        self.weights = weights
-        self.x_norm = x_norm
-        self.y_norm = y_norm
-        self.input_names = input_names
-        self.output_names = output_names
-
-    def __call__(self, *params: ComplexArrayND) -> ComplexArrayND:
-        return dense(self.weights, *params, x_norm=self.x_norm, y_norm=self.y_norm)
-
-    def __repr__(self):
-        return (
-            f"{self.__class__.__name__}"
-            f"{repr(self.input_names)}"
-            f"->{repr(self.output_names)}"
-        )
-
-
-def load_nn_dense(
-    *sizes: int,
-    input_names: Optional[Tuple[str, ...]] = None,
-    output_names: Optional[Tuple[str, ...]] = None,
-    weightprefix="dense",
-    weightdirpath="weights",
-    normdirpath="norms",
-    normprefix="norm",
-    preprocess=preprocess,
-) -> Callable:
-    """Load a pre-trained dense model"""
-    weights_path = get_dense_weights_path(
-        *sizes,
-        input_names=input_names,
-        output_names=output_names,
-        prefix=weightprefix,
-        dirpath=weightdirpath,
-        preprocess=preprocess,
-    )
-    if not os.path.exists(weights_path):
-        raise ValueError("Cannot find weights path for given parameters")
-    x_norm_path = get_norm_path(
-        input_names=input_names,
-        prefix=normprefix,
-        dirpath=normdirpath,
-        preprocess=preprocess,
-    )
-    if not os.path.exists(x_norm_path):
-        raise ValueError("Cannot find normalization for input parameters")
-    y_norm_path = get_norm_path(
-        output_names=output_names,
-        prefix=normprefix,
-        dirpath=normdirpath,
-        preprocess=preprocess,
-    )
-    if not os.path.exists(x_norm_path):
-        raise ValueError("Cannot find normalization for output parameters")
-    weights = load_nn_weights_json(weights_path)
-    x_norm_dict = load_nn_weights_json(x_norm_path)
-    y_norm_dict = load_nn_weights_json(y_norm_path)
-    x_norm = norm(x_norm_dict["mean"], x_norm_dict["std"])
-    y_norm = norm(y_norm_dict["mean"], y_norm_dict["std"])
-    partial_dense = _PartialDense(weights, x_norm, y_norm, input_names, output_names)
-    return partial_dense
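
The deleted I/O helpers encoded weight locations in a filename convention of the form `{prefix}-{inputs}-{sizes}-{outputs}.json`. For reference, a sketch of how the removed `get_dense_weights_path` resolved such a path (argument values are made up):

```python
from sax.nn.io import get_dense_weights_path  # removed in this PR

# With two input names, preprocess() expands them to 6 relative features,
# so the full size chain becomes 6x16x16x2:
path = get_dense_weights_path(
    16, 16, input_names=("wl", "t"), output_names=("amp", "phi")
)
# -> ".../weights/dense-wl-t-6x16x16x2-amp-phi.json"
```
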
diff --git a/sax/nn/utils.py b/sax/nn/utils.py
deleted file mode 100644
index c427258..0000000
--- a/sax/nn/utils.py
+++ /dev/null
@@ -1,56 +0,0 @@
-""" Sax Neural Network Default Utilities """
-
-from __future__ import annotations
-
-from collections import namedtuple
-from typing import Tuple
-
-import jax.numpy as jnp
-import pandas as pd
-
-from ..saxtypes import ComplexArrayND
-
-
-def cartesian_product(*arrays: ComplexArrayND) -> ComplexArrayND:
-    """calculate the n-dimensional cartesian product of an arbitrary number of arrays"""
-    ixarrays = jnp.ix_(*arrays)
-    barrays = jnp.broadcast_arrays(*ixarrays)
-    sarrays = jnp.stack(barrays, -1)
-    assert isinstance(sarrays, jnp.ndarray)
-    product = sarrays.reshape(-1, sarrays.shape[-1])
-    assert isinstance(product, jnp.ndarray)
-    return product
-
-
-def denormalize(
-    x: ComplexArrayND, mean: float = 0.0, std: float = 1.0
-) -> ComplexArrayND:
-    """denormalize an array with a given mean and standard deviation"""
-    return x * std + mean
-
-
-norm = namedtuple("norm", ("mean", "std"))
-
-
-def get_normalization(x: ComplexArrayND):
-    """Get mean and standard deviation for a given array"""
-    if isinstance(x, (complex, float)):
-        return x, 0.0
-    return norm(x.mean(0), x.std(0))
-
-
-def get_df_columns(df: pd.DataFrame, *names: str) -> Tuple[ComplexArrayND, ...]:
-    """Get certain columns from a pandas DataFrame as jax.numpy arrays"""
-    tup = namedtuple("params", names)
-    params_list = []
-    for name in names:
-        column_np = df[name].values
-        column_jnp = jnp.array(column_np)
-        assert isinstance(column_jnp, jnp.ndarray)
-        params_list.append(column_jnp.ravel())
-    return tup(*params_list)
-
-
-def normalize(x: ComplexArrayND, mean: float = 0.0, std: float = 1.0) -> ComplexArrayND:
-    """normalize an array with a given mean and standard deviation"""
-    return (x - mean) / std
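
`cartesian_product`, `normalize`, and `denormalize` survive this deletion by moving into `sax/utils.py` (see the final hunk below), while `get_df_columns` and the pandas dependency are dropped entirely. A quick sketch of `cartesian_product`, whose implementation is unchanged by the move (array values are illustrative):

```python
import jax.numpy as jnp
import sax

wl = jnp.array([1.50, 1.55, 1.60])
t = jnp.array([25.0, 50.0])
grid = sax.cartesian_product(wl, t)  # shape (6, 2): every (wl, t) pair
```
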
given array""" - if isinstance(x, (complex, float)): - return x, 0.0 - return norm(x.mean(0), x.std(0)) - - -def get_df_columns(df: pd.DataFrame, *names: str) -> Tuple[ComplexArrayND, ...]: - """Get certain columns from a pandas DataFrame as jax.numpy arrays""" - tup = namedtuple("params", names) - params_list = [] - for name in names: - column_np = df[name].values - column_jnp = jnp.array(column_np) - assert isinstance(column_jnp, jnp.ndarray) - params_list.append(column_jnp.ravel()) - return tup(*params_list) - - -def normalize(x: ComplexArrayND, mean: float = 0.0, std: float = 1.0) -> ComplexArrayND: - """normalize an array with a given mean and standard deviation""" - return (x - mean) / std diff --git a/sax/utils.py b/sax/utils.py index 931a0a6..51515e8 100644 --- a/sax/utils.py +++ b/sax/utils.py @@ -7,7 +7,18 @@ import warnings from functools import lru_cache, partial, wraps from hashlib import md5 -from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Union, cast, overload +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + NamedTuple, + Tuple, + Union, + cast, + overload, +) import jax import jax.numpy as jnp @@ -598,3 +609,38 @@ def _numpyfy(obj: Any): return np.asarray(obj) else: return {k: _numpyfy(v) for k, v in obj.items()} + + +class Normalization(NamedTuple): + mean: ComplexArrayND + std: ComplexArrayND + + +def normalization(x: ComplexArrayND, axis=None): + if axis is None: + return Normalization(x.mean(), x.std()) + else: + return Normalization(x.mean(axis), x.std(axis)) + + +def cartesian_product(*arrays: ComplexArrayND) -> ComplexArrayND: + """calculate the n-dimensional cartesian product of an arbitrary number of arrays""" + ixarrays = jnp.ix_(*arrays) + barrays = jnp.broadcast_arrays(*ixarrays) + sarrays = jnp.stack(barrays, -1) + assert isinstance(sarrays, jnp.ndarray) + product = sarrays.reshape(-1, sarrays.shape[-1]) + assert isinstance(product, jnp.ndarray) + return product + + +def normalize(x: ComplexArrayND, normalization: Normalization) -> Tuple[ComplexArrayND]: + """normalize an array with a given mean and standard deviation""" + mean, std = normalization + return (x - mean) / std + + +def denormalize(x: ComplexArrayND, normalization: Normalization) -> ComplexArrayND: + """denormalize an array with a given mean and standard deviation""" + mean, std = normalization + return x * std + mean