From f8f50889bb6d30abb0e4214df7361d687d82950a Mon Sep 17 00:00:00 2001 From: Tyler Morrow Date: Thu, 15 Aug 2024 06:40:45 -0600 Subject: [PATCH 1/4] Refactor for Python and dependency upgrades. --- .github/workflows/build-and-run-tests.yml | 10 +- .github/workflows/run-examples.yml | 10 +- README.md | 11 +- examples/data/synthesis/mix_seeds.py | 6 +- examples/modeling/arad.py | 4 +- examples/modeling/arad_latent_prediction.py | 7 +- examples/modeling/classifier_comparison.py | 4 +- .../modeling/label_proportion_estimation.py | 1 - examples/modeling/multi_event_classifier.py | 74 --- .../modeling/neural_network_classifier.py | 2 +- examples/visualization/confusion_matrix.py | 2 +- pyproject.toml | 51 +- riid/data/converters/__init__.py | 25 +- riid/data/sampleset.py | 23 +- riid/data/synthetic/seed.py | 9 +- riid/losses/__init__.py | 19 +- riid/metrics.py | 38 +- riid/models/__init__.py | 4 +- riid/models/bayes.py | 148 +++--- riid/models/layers.py | 82 +++ riid/models/neural_nets/__init__.py | 467 +++++------------- riid/models/neural_nets/arad.py | 148 +++--- tests/anomaly_tests.py | 7 +- tests/model_tests.py | 19 +- tests/seedmixer_tests.py | 29 +- tests/visualize_tests.py | 3 +- 26 files changed, 509 insertions(+), 694 deletions(-) delete mode 100644 examples/modeling/multi_event_classifier.py create mode 100644 riid/models/layers.py diff --git a/.github/workflows/build-and-run-tests.yml b/.github/workflows/build-and-run-tests.yml index b1c8e2a4..96633e40 100644 --- a/.github/workflows/build-and-run-tests.yml +++ b/.github/workflows/build-and-run-tests.yml @@ -7,8 +7,8 @@ jobs: build: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] - os: [ubuntu-latest, windows-latest, macos-13] + python-version: ["3.9", "3.10", "3.11", "3.12"] + os: [ubuntu-latest, windows-latest, macos-latest] runs-on: ${{ matrix.os }} steps: - name: Checkout @@ -19,6 +19,12 @@ jobs: python-version: ${{ matrix.python-version }} cache: "pip" cache-dependency-path: "**/pyproject.toml" + - name: Install HDF5 (macOS only) + if: runner.os == 'macOS' + run: brew install hdf5 + - name: Set HDF5_DIR environment variable (macOS only) + if: runner.os == 'macOS' + run: echo "HDF5_DIR=$(brew --prefix hdf5)" >> $GITHUB_ENV - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/run-examples.yml b/.github/workflows/run-examples.yml index 295143f1..ddf06b6e 100644 --- a/.github/workflows/run-examples.yml +++ b/.github/workflows/run-examples.yml @@ -7,8 +7,8 @@ jobs: build: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] - os: [ubuntu-latest, windows-latest, macos-13] + python-version: ["3.9", "3.10", "3.11", "3.12"] + os: [ubuntu-latest, windows-latest, macos-latest] runs-on: ${{ matrix.os }} steps: - name: Checkout @@ -19,6 +19,12 @@ jobs: python-version: ${{ matrix.python-version }} cache: "pip" cache-dependency-path: "**/pyproject.toml" + - name: Install HDF5 (macOS only) + if: runner.os == 'macOS' + run: brew install hdf5 + - name: Set HDF5_DIR environment variable (macOS only) + if: runner.os == 'macOS' + run: echo "HDF5_DIR=$(brew --prefix hdf5)" >> $GITHUB_ENV - name: Install dependencies run: | python -m pip install --upgrade pip setuptools wheel diff --git a/README.md b/README.md index 7d7c285e..8f01a50c 100644 --- a/README.md +++ b/README.md @@ -11,23 +11,22 @@ PyRIID is a Python package providing modeling and data synthesis utilities for m Requirements: -- Python version: 3.8 to 3.10 +- Python version: 3.9 to 3.12 + - Note: we recommended the highest 
Python version you can manage as anecdotally, we have noticed that everything just tends to get faster. - Operating systems: Windows, Mac, or Ubuntu -A virtual environment is recommended. - Tests and examples are run via Actions on many combinations of Python version and operating system. You can verify support for your platform by checking the workflow files. ### For Use -To use the latest version on PyPI (note: changes are slower to appear here), run: +To use the latest version on PyPI, run: ```sh pip install riid ``` -**For the latest features, run:** +Note that changes are slower to appear on PyPI, so for the latest features, run:** ```sh pip install git+https://github.com/sandialabs/pyriid.git@main @@ -86,7 +85,7 @@ Full copyright details can be found [here](https://github.com/sandialabs/PyRIID/ ## Acknowledgements **Thank you** to the U.S. Department of Energy, National Nuclear Security Administration, -Office of Defense Nuclear Nonproliferation Research and Development (DNN R&D) for funding that has led to version `2.x`. +Office of Defense Nuclear Nonproliferation Research and Development (DNN R&D) for funding that has led to versions `2.0` and `2.1`. Additionally, **thank you** to the following individuals who have provided invaluable subject-matter expertise: diff --git a/examples/data/synthesis/mix_seeds.py b/examples/data/synthesis/mix_seeds.py index 907dfc5a..911fc56d 100644 --- a/examples/data/synthesis/mix_seeds.py +++ b/examples/data/synthesis/mix_seeds.py @@ -2,14 +2,16 @@ # Under the terms of Contract DE-NA0003525 with NTESS, # the U.S. Government retains certain rights in this software. """This example demonstrates how to generate synthetic gamma spectra from seeds.""" +import numpy as np from riid.data.synthetic import get_dummy_seeds from riid.data.synthetic.seed import SeedMixer fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() -mixed_fg_seeds_ss = SeedMixer(fg_seeds_ss, mixture_size=2)\ +rng = np.random.default_rng(3) +mixed_fg_seeds_ss = SeedMixer(fg_seeds_ss, mixture_size=2, rng=rng)\ .generate(n_samples=10) -mixed_bg_seeds_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\ +mixed_bg_seeds_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=rng)\ .generate(n_samples=10) print(mixed_fg_seeds_ss) diff --git a/examples/modeling/arad.py b/examples/modeling/arad.py index 846975fe..f15d2e4e 100644 --- a/examples/modeling/arad.py +++ b/examples/modeling/arad.py @@ -14,7 +14,7 @@ # Config rng = np.random.default_rng(42) OOD_QUANTILE = 0.99 -VERBOSE = True +VERBOSE = False # Some of the following parameters are set low because this example runs on GitHub Actions and # we don't want it taking a bunch of time. # When running this locally, change the values per their corresponding comment, otherwise @@ -54,7 +54,7 @@ arad.predict(gross_train_ss) ood_threshold = np.quantile(gross_train_ss.info.recon_error, OOD_QUANTILE) - reconstructions = arad.predict(test_ss, verbose=True) + reconstructions = arad.predict(test_ss, verbose=VERBOSE) ood = test_ss.info.recon_error.values > ood_threshold false_positive_rate = ood.mean() mean_recon_error = test_ss.info.recon_error.values.mean() diff --git a/examples/modeling/arad_latent_prediction.py b/examples/modeling/arad_latent_prediction.py index fed2d805..9b233954 100644 --- a/examples/modeling/arad_latent_prediction.py +++ b/examples/modeling/arad_latent_prediction.py @@ -5,16 +5,17 @@ from an ARAD latent space. 
""" import numpy as np +from keras.api.metrics import Accuracy, CategoricalCrossentropy from sklearn.metrics import f1_score, mean_squared_error from riid.data.synthetic import get_dummy_seeds from riid.data.synthetic.seed import SeedMixer from riid.data.synthetic.static import StaticSynthesizer -from riid.models.neural_nets.arad import ARADv2, ARADLatentPredictor +from riid.models.neural_nets.arad import ARADLatentPredictor, ARADv2 # Config rng = np.random.default_rng(42) -VERBOSE = True +VERBOSE = False # Some of the following parameters are set low because this example runs on GitHub Actions and # we don't want it taking a bunch of time. # When running this locally, change the values per their corresponding comment, otherwise @@ -66,7 +67,7 @@ print("Training Classifier") arad_classifier = ARADLatentPredictor( loss="categorical_crossentropy", - metrics=("accuracy", "categorical_crossentropy"), + metrics=[Accuracy(), CategoricalCrossentropy()], final_activation="softmax" ) arad_classifier.fit( diff --git a/examples/modeling/classifier_comparison.py b/examples/modeling/classifier_comparison.py index 7b1b4e12..4fce43e7 100644 --- a/examples/modeling/classifier_comparison.py +++ b/examples/modeling/classifier_comparison.py @@ -36,8 +36,8 @@ train_fg_ss, _ = static_synth.generate(fg_seeds_ss, mixed_bg_seed_ss, verbose=False) train_fg_ss.normalize() -model_nn = MLPClassifier(hidden_layers=(5,)) -model_nn.fit(train_fg_ss, epochs=10, patience=5, verbose=1) +model_nn = MLPClassifier() +model_nn.fit(train_fg_ss, epochs=10, patience=5) # Create PB model model_pb = PoissonBayesClassifier() diff --git a/examples/modeling/label_proportion_estimation.py b/examples/modeling/label_proportion_estimation.py index 8f3a4ef5..b4e5e4c4 100644 --- a/examples/modeling/label_proportion_estimation.py +++ b/examples/modeling/label_proportion_estimation.py @@ -55,7 +55,6 @@ batch_size=10, epochs=2, validation_split=0.2, - verbose=True, bg_cps=300 ) diff --git a/examples/modeling/multi_event_classifier.py b/examples/modeling/multi_event_classifier.py deleted file mode 100644 index eac23d44..00000000 --- a/examples/modeling/multi_event_classifier.py +++ /dev/null @@ -1,74 +0,0 @@ -# Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). -# Under the terms of Contract DE-NA0003525 with NTESS, -# the U.S. Government retains certain rights in this software. -"""This example demonstrates how to use the MLP classifier.""" -from copy import deepcopy as copy - -from sklearn.metrics import f1_score - -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.static import StaticSynthesizer -from riid.models.neural_nets import MLPClassifier, MultiEventClassifier - -fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() -static_syn = StaticSynthesizer( - # log10 sampling samples lower SNR values more frequently. - # This makes the SampleSet overall "harder" to classify. 
- snr_function="log10", - samples_per_seed=50, - return_fg=True, - return_gross=True, -) - -# Generate some training data -fg_ss, gross_ss = static_syn.generate(fg_seeds_ss, bg_seeds_ss) -bg_ss = gross_ss - fg_ss -bg_ss.normalize() -gross_ss.normalize() - -# Train two single event classifiers -model1 = MLPClassifier() -model1.fit(gross_ss, bg_ss=bg_ss, verbose=1, epochs=50, patience=20) - -model2 = MLPClassifier() -model2.fit(gross_ss, bg_ss=bg_ss, verbose=1, epochs=50, patience=20) - -# Generate two sample sets (with same sources but predictions from different models) -train2a_fg_ss, train2a_ss = static_syn.generate(fg_seeds_ss, bg_seeds_ss) -train2a_bg_ss = train2a_ss - train2a_fg_ss -train2b_ss = copy(train2a_ss) -train2b_bg_ss = copy(train2a_bg_ss) - -model1.predict(train2a_ss, train2a_bg_ss) -model2.predict(train2b_ss, train2b_bg_ss) - -# Train MultiEvent model -mec = MultiEventClassifier() -mec.fit( - [train2a_ss, train2b_ss], - train2a_ss.sources.groupby(axis=1, level="Isotope", sort=False).sum(), - epochs=50 -) - -# Make predictions on multi model - -multi_preds = mec.predict([train2a_ss, train2b_ss]) - -# Compare performance for single event models and multi-event model -m1_f1_score = f1_score(train2a_ss.get_predictions(), - train2a_ss.get_labels(), - average="weighted") -m2_f1_score = f1_score(train2b_ss.get_predictions(), - train2b_ss.get_labels(), - average="weighted") - -multi_f1_score = f1_score(multi_preds.values.argmax(axis=1), - train2a_ss.get_source_contributions().argmax(axis=1), - average="weighted") - -results_str = ( - f"M1 F1 Score: {m1_f1_score:.2f}\n" - f"M2 F1 Score: {m2_f1_score:.2f}\n" - f"M12 F1 Score: {multi_f1_score:.2f}" -) -print(results_str) diff --git a/examples/modeling/neural_network_classifier.py b/examples/modeling/neural_network_classifier.py index 8ab7d55f..277a2dc5 100644 --- a/examples/modeling/neural_network_classifier.py +++ b/examples/modeling/neural_network_classifier.py @@ -24,7 +24,7 @@ train_ss.normalize() model = MLPClassifier() -model.fit(train_ss, epochs=10, patience=5, verbose=1) +model.fit(train_ss, epochs=10, patience=5) # Generate some test data static_synth.samples_per_seed = 50 diff --git a/examples/visualization/confusion_matrix.py b/examples/visualization/confusion_matrix.py index 704f4d1b..b2e25e72 100644 --- a/examples/visualization/confusion_matrix.py +++ b/examples/visualization/confusion_matrix.py @@ -26,7 +26,7 @@ .generate(fg_seeds_ss, mixed_bg_seed_ss) train_ss.normalize() -model = MLPClassifier(hidden_layers=(8,)) +model = MLPClassifier() model.fit(train_ss, verbose=0, epochs=50) # Generate some test data diff --git a/pyproject.toml b/pyproject.toml index a7af43bc..c20c480e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ namespaces = false [project] name = "riid" description = "Machine learning-based models and utilities for radioisotope identification" -version = "2.1.0" +version = "2.2.0" maintainers = [ {name="Tyler Morrow", email="tmorro@sandia.gov"}, ] @@ -41,44 +41,43 @@ classifiers = [ 'Topic :: Software Development', 'Topic :: Software Development :: Libraries', 'Topic :: Software Development :: Libraries :: Python Modules', + 'Programming Language :: Python', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'Programming Language :: Python :: 3.12', ] keywords = ["pyriid", "riid", 
"machine learning", "radioisotope identification", "gamma spectrum"] -requires-python = ">=3.8,<3.11" +requires-python = ">=3.9,<3.13" dependencies = [ - "jsonschema ==4.17.*", - "matplotlib ==3.7.*", - "pyyaml ==6.0.*", - "seaborn ==0.12.*", - "tf2onnx ==1.14.*", - "onnx ==1.14.1", - "tqdm ==4.65.*", - "numpy ==1.23.*", - "pandas ==2.0.*", - "parmap ==1.6.*", - "pythonnet ==3.0.*; platform_system == 'Windows'", - "tables ==3.8.*", - "tensorflow ==2.12.*", - "tensorflow-io ==0.27.*", - "tensorflow-model-optimization ==0.7.*", - "tensorflow-probability ==0.20.*", - "typeguard ==2.7.*", - "scikit-learn ==1.2.*", + "jsonschema ==4.23.*", # 3.8 - 3.13 + "matplotlib ==3.9.*", # 3.9 - 3.12 + "numpy ==1.26.*", # 3.9 - 3.12, also to be limited by onnx 1.16.2 + "pandas ==2.2.*", # >= 3.9 + "pythonnet ==3.0.3; platform_system == 'Windows'", # 3.7 - 3.12 + "pyyaml ==6.0.*", # >= 3.6 + "tables ==3.9.*", # >= 3.9 + "scikit-learn ==1.5.*", # 3.9 - 3.12 + "scipy ==1.13.*", # >= 3.10 + "seaborn ==0.13.*", # >= 3.8 + "tensorflow ==2.16.*", # 3.9 - 3.12 + "tensorflow-model-optimization ==0.8.*", # 3.7 - 3.12 + "onnx ==1.16.1", # 3.7 - 3.10 + "tf2onnx ==1.16.1", # 3.7 - 3.10 + "tqdm ==4.66.*", # >= 3.7 + "typeguard ==4.3.*", # 3.9 - 3.12 ] [project.optional-dependencies] dev = [ - "flake8", - "flake8-quotes", "coverage", "ipykernel", + "flake8", + "flake8-quotes", "tabulate", ] [project.urls] -repository = "https://github.com/sandialabs/PyRIID" +Documentation = "https://sandialabs.github.io/PyRIID" +Repository = "https://github.com/sandialabs/PyRIID" diff --git a/riid/data/converters/__init__.py b/riid/data/converters/__init__.py index 2959f299..c956e827 100644 --- a/riid/data/converters/__init__.py +++ b/riid/data/converters/__init__.py @@ -6,7 +6,7 @@ from pathlib import Path from typing import Callable -import parmap as pm +from joblib import Parallel, delayed def _validate_and_create_output_dir(output_dir: str): @@ -17,8 +17,8 @@ def _validate_and_create_output_dir(output_dir: str): def convert_directory(input_dir_path: str, conversion_func: Callable, file_ext: str, - pm_processes: int = 8, pm_chunksize: int = 1, **kwargs): - """Convert and save every file in a specified directory in parallel. + n_jobs: int = 8, **kwargs): + """Convert and save every file in a specified directory. Conversion functions can be found in sub-modules: @@ -32,16 +32,14 @@ def convert_directory(input_dir_path: str, conversion_func: Callable, file_ext: convert_directory(...) ``` - Consider setting `pm_processes` to `multiprocessing.cpu_count()`; - unfortunately, `pm_chunksize` requires some experimentation to fully optimize. + Tip: for max utilization, considering setting `n_jobs` to `multiprocessing.cpu_count()`. 
Args: input_dir_path: directory path containing the input files conversion_func: function used to convert a data file to a `SampleSet` file_ext: file extension to read in for conversion - pm_processes: parmap parameter to set the # of processes - pm_chunksize: parmap parameter to set the chunksize - kwargs: keyword args passed to underlying conversion_func operations + n_jobs: `joblib.Parallel` parameter to set the # of jobs + kwargs: additional keyword args passed to conversion_func """ input_path = Path(input_dir_path) if not input_path.exists() or not input_path.is_dir(): @@ -50,13 +48,6 @@ def convert_directory(input_dir_path: str, conversion_func: Callable, file_ext: input_file_paths = sorted(glob.glob(f"{input_dir_path}/*.{file_ext}")) - x = pm.map( - conversion_func, - input_file_paths, - **kwargs, - pm_processes=pm_processes, - pm_chunksize=pm_chunksize, - pm_parallel=True, - pm_pbar=True, + Parallel(n_jobs, verbose=10)( + delayed(conversion_func)(path, **kwargs) for path in input_file_paths ) - return x diff --git a/riid/data/sampleset.py b/riid/data/sampleset.py index 9e520428..47dd0332 100644 --- a/riid/data/sampleset.py +++ b/riid/data/sampleset.py @@ -13,7 +13,7 @@ import random import re import warnings -from datetime import datetime +from datetime import datetime, timezone from enum import Enum from typing import Callable, Iterable, Tuple, Union @@ -759,8 +759,9 @@ def concat(self, ss_list: list): sort=False ) self._sources = self._sources.where(pd.notnull(self._sources), 0) + existing_info_df = self._info if not self._info.empty else None self._info = pd.concat( - [self._info] + [ss.info for ss in ss_list], + [existing_info_df] + [ss.info for ss in ss_list], ignore_index=True, sort=False ) @@ -1015,7 +1016,8 @@ def get_confidences(self, fg_seeds_ss: SampleSet, bg_seed_ss: SampleSet = None, self.spectra.values, **confidence_func_kwargs ) - return np.array(confidences) + confidences = np.array(confidences) + return confidences def _get_spectral_distances(self, distance_func=distance.jensenshannon) -> np.array: n_samples = self.n_samples @@ -1111,7 +1113,7 @@ def get_source_contributions(self, target_level="Isotope") -> np.ndarray: Returns: Array containing the ground truth contributions for each sample """ - collapsed_sources = self.sources.groupby(axis=1, level=target_level).sum() + collapsed_sources = self.sources.T.groupby(target_level).sum().T sources_values = np.nan_to_num(collapsed_sources) return sources_values @@ -1465,7 +1467,8 @@ def _dict_to_bulleted_list(data_dict: dict, level=0, indent=4, bullet="-") -> st def _get_utc_timestamp(): - ts = datetime.utcnow().isoformat(sep=" ", timespec="seconds") + now_utc = datetime.now(timezone.utc) + ts = now_utc.isoformat(sep=" ", timespec="seconds") return ts @@ -1496,10 +1499,10 @@ def _get_row_labels(df: pd.DataFrame, target_level: str = "Isotope", max_only: b """ if max_only: if level_aggregation == "sum": - values = df.groupby(axis=1, level=target_level).sum() + values = df.T.groupby(target_level).sum().T labels = values.idxmax(axis=1) elif level_aggregation == "mean": - values = df.groupby(axis=1, level=target_level).mean() + values = df.T.groupby(target_level).mean().T labels = values.idxmax(axis=1) else: levels_to_drop = [ @@ -1513,9 +1516,9 @@ def _get_row_labels(df: pd.DataFrame, target_level: str = "Isotope", max_only: b labels = [f"{x} ({y:.2f})" for x, y in zip(labels, values)] else: # Much slower if level_aggregation == "sum": - values = df.groupby(axis=1, level=target_level).sum() + values = 
df.T.groupby(target_level).sum().T elif level_aggregation == "mean": - values = df.groupby(axis=1, level=target_level).mean() + values = df.T.groupby(target_level).mean().T else: values = df mask = values.ge(min_value).values @@ -1698,7 +1701,7 @@ def _ss_to_pcf_dict(ss: SampleSet, verbose=False) -> dict: info_timestamps = ss.info.timestamp.fillna("") info_tags = ss.info.tag.fillna("") - info_zero_fill = ss.info.fillna(0).astype(float, errors="ignore") + info_zero_fill = ss.info.infer_objects(copy=False).astype(float, errors="ignore") for i in sample_range: title = isotopes[i] if not isotopes.empty else NO_ISOTOPE description = ss.info.description.fillna("").iloc[i] diff --git a/riid/data/synthetic/seed.py b/riid/data/synthetic/seed.py index 08e19215..9fd1b2c4 100644 --- a/riid/data/synthetic/seed.py +++ b/riid/data/synthetic/seed.py @@ -292,10 +292,13 @@ def __call__(self, n_samples: int, max_batch_size: int = 100, ) sources_df = pd.DataFrame([r], columns=sources_cols) batch_sources_dfs.append(sources_df) - empty_sources_df = pd.DataFrame([], columns=self.seeds_ss.sources.columns) - batch_ss.sources = pd\ - .concat([empty_sources_df] + batch_sources_dfs)\ + sources_df = pd\ + .concat(batch_sources_dfs)\ .fillna(0.0) + batch_ss.sources = sources_df.reindex( + columns=self.seeds_ss.sources.columns, + fill_value=0.0 + ) n_samples_produced += batch_size diff --git a/riid/losses/__init__.py b/riid/losses/__init__.py index ae18ccc8..26ade493 100644 --- a/riid/losses/__init__.py +++ b/riid/losses/__init__.py @@ -4,7 +4,7 @@ """This module contains custom loss functions.""" import numpy as np import tensorflow as tf -from keras import backend as K +from keras.api import ops def negative_log_f1(y_true: np.ndarray, y_pred: np.ndarray): @@ -18,13 +18,13 @@ def negative_log_f1(y_true: np.ndarray, y_pred: np.ndarray): Custom loss score on a log scale """ diff = y_true - y_pred - negs = K.clip(diff, -1.0, 0.0) - false_positive = -K.sum(negs, axis=-1) + negs = ops.clip(diff, -1.0, 0.0) + false_positive = -ops.sum(negs, axis=-1) true_positive = 1.0 - false_positive lower_clip = 1e-20 - true_positive = K.clip(true_positive, lower_clip, 1.0) + true_positive = ops.clip(true_positive, lower_clip, 1.0) - return -K.mean(K.log(true_positive)) + return -ops.mean(ops.log(true_positive)) def negative_f1(y_true, y_pred): @@ -38,13 +38,13 @@ def negative_f1(y_true, y_pred): Custom loss score """ diff = y_true - y_pred - negs = K.clip(diff, -1.0, 0.0) - false_positive = -K.sum(negs, axis=-1) + negs = ops.clip(diff, -1.0, 0.0) + false_positive = -ops.sum(negs, axis=-1) true_positive = 1.0 - false_positive lower_clip = 1e-20 - true_positive = K.clip(true_positive, lower_clip, 1.0) + true_positive = ops.clip(true_positive, lower_clip, 1.0) - return -K.mean(true_positive) + return -ops.mean(true_positive) def build_keras_semisupervised_loss_func(supervised_loss_func, @@ -54,6 +54,7 @@ def build_keras_semisupervised_loss_func(supervised_loss_func, normalize: bool = False, normalize_scaler: float = 1.0, normalize_func=tf.math.tanh): + @tf.keras.utils.register_keras_serializable(package="Addons") def _semisupervised_loss_func(data, y_pred): """ Args: diff --git a/riid/metrics.py b/riid/metrics.py index 8b5c48dd..aa677448 100644 --- a/riid/metrics.py +++ b/riid/metrics.py @@ -20,14 +20,14 @@ def multi_f1(y_true: np.ndarray, y_pred: np.ndarray) -> float: Returns: Multi F1-score value(s) """ - from keras import backend as K + from keras.api import ops diff = y_true - y_pred - negs = K.clip(diff, -1.0, 0.0) - false_positive = 
-K.sum(negs, axis=-1) + negs = ops.clip(diff, -1.0, 0.0) + false_positive = -ops.sum(negs, axis=-1) true_positive = 1.0 - false_positive - return K.mean(true_positive) + return ops.mean(true_positive) def single_f1(y_true: np.ndarray, y_pred: np.ndarray): @@ -43,23 +43,23 @@ def single_f1(y_true: np.ndarray, y_pred: np.ndarray): F1-score value(s) """ import tensorflow as tf - from keras import backend as K + from keras.api import ops - a = tf.dtypes.cast(y_true == K.max(y_true, axis=1)[:, None], tf.float32) - b = tf.dtypes.cast(y_pred == K.max(y_pred, axis=1)[:, None], tf.float32) + a = tf.dtypes.cast(y_true == ops.max(y_true, axis=1)[:, None], tf.float32) + b = tf.dtypes.cast(y_pred == ops.max(y_pred, axis=1)[:, None], tf.float32) - TP_mat = tf.dtypes.cast(K.all(tf.stack([a, b]), axis=0), tf.float32) - FP_mat = tf.dtypes.cast(K.all(tf.stack([a != b, b == 1]), axis=0), tf.float32) - FN_mat = tf.dtypes.cast(K.all(tf.stack([a != b, a == 1]), axis=0), tf.float32) + TP_mat = tf.dtypes.cast(ops.all(tf.stack([a, b]), axis=0), tf.float32) + FP_mat = tf.dtypes.cast(ops.all(tf.stack([a != b, b == 1]), axis=0), tf.float32) + FN_mat = tf.dtypes.cast(ops.all(tf.stack([a != b, a == 1]), axis=0), tf.float32) - TPs = K.sum(TP_mat, axis=0) - FPs = K.sum(FP_mat, axis=0) - FNs = K.sum(FN_mat, axis=0) + TPs = ops.sum(TP_mat, axis=0) + FPs = ops.sum(FP_mat, axis=0) + FNs = ops.sum(FN_mat, axis=0) - F1s = 2 * TPs / (2*TPs + FNs + FPs + tf.fill(tf.shape(TPs), K.epsilon())) + F1s = 2 * TPs / (2*TPs + FNs + FPs + tf.fill(tf.shape(TPs), tf.keras.backend.epsilon())) - support = K.sum(a, axis=0) - f1 = K.sum(F1s * support) / K.sum(support) + support = ops.sum(a, axis=0) + f1 = ops.sum(F1s * support) / ops.sum(support) return f1 @@ -122,10 +122,10 @@ def precision_recall_curve(ss: SampleSet, smooth: bool = True, multiclass: bool https://jonathan-hui.medium.com/map-mean-average-precision-for-object-detection-45c121a31173) """ - y_true = ss.sources.groupby(axis=1, level=target_level, sort=False).sum() + y_true = ss.sources.T.groupby(target_level, sort=False).sum().T if minimum_contribution is not None: y_true = (y_true > minimum_contribution).astype(int) - y_pred = ss.prediction_probas.groupby(axis=1, level=target_level, sort=False).sum() + y_pred = ss.prediction_probas.T.groupby(target_level, sort=False).sum().T # switch from pandas to numpy labels = y_true.columns @@ -250,6 +250,6 @@ def build_keras_semisupervised_metric_func(keras_metric_func, activation_func, n_labels): def metric_func(y_true, y_pred): return keras_metric_func(y_true[:, :n_labels], activation_func(y_pred)) - metric_func.__name__ = keras_metric_func.__name__ + metric_func.__name__ = keras_metric_func.__class__.__name__ return metric_func diff --git a/riid/models/__init__.py b/riid/models/__init__.py index f2c29971..0825b251 100644 --- a/riid/models/__init__.py +++ b/riid/models/__init__.py @@ -11,8 +11,8 @@ import numpy as np import tensorflow as tf import tf2onnx -from keras.models import Model -from keras.utils import get_custom_objects +from keras.api.models import Model +from keras.api.utils import get_custom_objects import riid from riid.data.labeling import label_to_index_element diff --git a/riid/models/bayes.py b/riid/models/bayes.py index 758b0229..6850cbb1 100644 --- a/riid/models/bayes.py +++ b/riid/models/bayes.py @@ -5,10 +5,14 @@ import numpy as np import pandas as pd import tensorflow as tf -import tensorflow_probability as tfp +from keras.api.layers import Add, Input, Multiply, Subtract +from keras.api.models import Model from 
riid.data.sampleset import SampleSet from riid.models import PyRIIDModel +from riid.models.layers import (ClipByValueLayer, DivideLayer, ExpandDimsLayer, + PoissonLogProbabilityLayer, ReduceMaxLayer, + ReduceSumLayer, SeedLayer) class PoissonBayesClassifier(PyRIIDModel): @@ -29,6 +33,14 @@ class PoissonBayesClassifier(PyRIIDModel): def __init__(self): super().__init__() + self._update_custom_objects("ReduceSumLayer", ReduceSumLayer) + self._update_custom_objects("ReduceMaxLayer", ReduceMaxLayer) + self._update_custom_objects("DivideLayer", DivideLayer) + self._update_custom_objects("ExpandDimsLayer", ExpandDimsLayer) + self._update_custom_objects("ClipByValueLayer", ClipByValueLayer) + self._update_custom_objects("PoissonLogProbabilityLayer", PoissonLogProbabilityLayer) + self._update_custom_objects("SeedLayer", SeedLayer) + def fit(self, seeds_ss: SampleSet): """Construct a TF-based implementation of a poisson-bayes classifier in terms of the given seeds. @@ -50,74 +62,98 @@ def fit(self, seeds_ss: SampleSet): msg = "Argument 'seeds_ss' can't contain any spectra with zero total counts." raise ZeroTotalCountsError(msg) - self._seeds = tf.constant(tf.convert_to_tensor( + self._seeds = tf.convert_to_tensor( seeds_ss.spectra.values, dtype=tf.float32 - )) - - # Inputs - gross_spectrum_input = tf.keras.layers.Input( - shape=seeds_ss.n_channels, - name="gross_spectrum" - ) - gross_live_time_input = tf.keras.layers.Input( - shape=(), - name="gross_live_time" - ) - bg_spectrum_input = tf.keras.layers.Input( - shape=seeds_ss.n_channels, - name="bg_spectrum" - ) - bg_live_time_input = tf.keras.layers.Input( - shape=(), - name="bg_live_time" - ) - - # Compute expected_seed_spectrums - gross_total_counts = tf.reduce_sum(gross_spectrum_input, axis=1) - bg_total_counts = tf.reduce_sum(bg_spectrum_input, axis=1) - bg_count_rate = tf.divide(bg_total_counts, bg_live_time_input) - expected_bg_counts = tf.multiply(bg_count_rate, gross_live_time_input) - expected_fg_counts = tf.subtract(gross_total_counts, expected_bg_counts) - normalized_bg_spectrum = tf.divide( - bg_spectrum_input, - tf.expand_dims(bg_total_counts, axis=1) - ) - expected_bg_spectrum = tf.multiply( - normalized_bg_spectrum, - tf.expand_dims(expected_bg_counts, axis=1) - ) - expected_fg_spectrum = tf.multiply( - self._seeds, - tf.expand_dims(tf.expand_dims( - expected_fg_counts, - axis=-1 - ), axis=-1) - ) - max_value = tf.math.reduce_max(expected_fg_spectrum) - expected_fg_spectrum = tf.clip_by_value(expected_fg_spectrum, 1e-8, max_value) - expected_gross_spectrum = tf.add( - expected_fg_spectrum, - tf.expand_dims(expected_bg_spectrum, axis=1) ) - poisson_dist = tfp.distributions.Poisson(expected_gross_spectrum) - all_probas = poisson_dist.log_prob( - tf.expand_dims(gross_spectrum_input, axis=1) - ) - prediction_probas = tf.math.reduce_sum(all_probas, axis=2) - + # Inputs + gross_spectrum_input = Input(shape=(seeds_ss.n_channels,), + name="gross_spectrum") + gross_live_time_input = Input(shape=(), + name="gross_live_time") + bg_spectrum_input = Input(shape=(seeds_ss.n_channels,), + name="bg_spectrum") + bg_live_time_input = Input(shape=(), + name="bg_live_time") model_inputs = ( gross_spectrum_input, gross_live_time_input, bg_spectrum_input, bg_live_time_input, ) - self.model = tf.keras.Model(model_inputs, prediction_probas) + + # Input statistics + gross_total_counts = ReduceSumLayer(name="gross_total_counts")(gross_spectrum_input, axis=1) + bg_total_counts = ReduceSumLayer(name="bg_total_counts")(bg_spectrum_input, axis=1) + bg_count_rate 
= DivideLayer(name="bg_count_rate")([bg_total_counts, bg_live_time_input]) + + gross_spectrum_input_expanded = ExpandDimsLayer( + name="gross_spectrum_input_expanded" + )(gross_spectrum_input, axis=1) + bg_total_counts_expanded = ExpandDimsLayer( + name="bg_total_counts_expanded" + )(bg_total_counts, axis=1) + + # Expectations + seed_layer = SeedLayer(self._seeds)(model_inputs) + seed_layer_expanded = ExpandDimsLayer()(seed_layer, axis=0) + expected_bg_counts = Multiply( + trainable=False, + name="expected_bg_counts" + )([bg_count_rate, gross_live_time_input]) + expected_bg_counts_expanded = ExpandDimsLayer( + name="expected_bg_counts_expanded" + )(expected_bg_counts, axis=1) + normalized_bg_spectrum = DivideLayer( + name="normalized_bg_spectrum" + )([bg_spectrum_input, bg_total_counts_expanded]) + expected_bg_spectrum = Multiply( + trainable=False, + name="expected_bg_spectrum" + )([normalized_bg_spectrum, expected_bg_counts_expanded]) + expected_fg_counts = Subtract( + trainable=False, + name="expected_fg_counts" + )([gross_total_counts, expected_bg_counts]) + expected_fg_counts_expanded = ExpandDimsLayer( + name="expected_fg_counts_expanded" + )(expected_fg_counts, axis=-1) + expected_fg_counts_expanded2 = ExpandDimsLayer( + name="expected_fg_counts_expanded2" + )(expected_fg_counts_expanded, axis=-1) + expected_fg_spectrum = Multiply( + trainable=False, + name="expected_fg_spectrum" + )([seed_layer_expanded, expected_fg_counts_expanded2]) + max_fg_value = ReduceMaxLayer( + name="max_fg_value" + )(expected_fg_spectrum) + expected_fg_spectrum = ClipByValueLayer( + name="clip_expected_fg_spectrum" + )(expected_fg_spectrum, clip_value_min=1e-8, clip_value_max=max_fg_value) + expected_bg_spectrum_expanded = ExpandDimsLayer( + name="expected_bg_spectrum_expanded" + )(expected_bg_spectrum, axis=1) + expected_gross_spectrum = Add( + trainable=False, + name="expected_gross_spectrum" + )([expected_fg_spectrum, expected_bg_spectrum_expanded]) + + # Compute probabilities + log_probabilities = PoissonLogProbabilityLayer( + name="log_probabilities" + )([expected_gross_spectrum, gross_spectrum_input_expanded]) + summed_log_probabilities = ReduceSumLayer( + name="summed_log_probabilities" + )(log_probabilities, axis=2) + + # Assemble model + self.model = Model(model_inputs, summed_log_probabilities) self.model.compile() self.target_level = "Seed" - sources_df = seeds_ss.sources.groupby(axis=1, level=self.target_level, sort=False).sum() + sources_df = seeds_ss.sources.T.groupby(self.target_level, sort=False).sum().T self.model_outputs = sources_df.columns.values.tolist() def predict(self, gross_ss: SampleSet, bg_ss: SampleSet, diff --git a/riid/models/layers.py b/riid/models/layers.py new file mode 100644 index 00000000..c250e09f --- /dev/null +++ b/riid/models/layers.py @@ -0,0 +1,82 @@ +# Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +# Under the terms of Contract DE-NA0003525 with NTESS, +# the U.S. Government retains certain rights in this software. 
+"""This module contains custom Keras layers.""" +import tensorflow as tf +from keras.api.layers import Layer + + +class ReduceSumLayer(Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, x, axis): + return tf.reduce_sum(x, axis=axis) + + +class ReduceMaxLayer(Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, x): + return tf.reduce_max(x) + + +class DivideLayer(Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, x): + return tf.divide(x[0], x[1]) + + +class ExpandDimsLayer(Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, x, axis): + return tf.expand_dims(x, axis=axis) + + +class ClipByValueLayer(Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, x, clip_value_min, clip_value_max): + return tf.clip_by_value(x, clip_value_min=clip_value_min, clip_value_max=clip_value_max) + + +class PoissonLogProbabilityLayer(Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, x): + exp, value = x + log_probas = tf.math.xlogy(value, exp) - exp - tf.math.lgamma(value + 1) + return log_probas + + +class SeedLayer(Layer): + def __init__(self, seeds, **kwargs): + super(SeedLayer, self).__init__(**kwargs) + self.seeds = tf.convert_to_tensor(seeds) + + def get_config(self): + config = super().get_config() + config.update({ + "seeds": self.seeds.numpy().tolist(), + }) + return config + + def call(self, inputs): + return self.seeds + + +class L1NormLayer(Layer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def call(self, inputs): + sums = tf.reduce_sum(inputs, axis=-1) + l1_norm = inputs / tf.reshape(sums, (-1, 1)) + return l1_norm diff --git a/riid/models/neural_nets/__init__.py b/riid/models/neural_nets/__init__.py index e46c700c..42443f88 100644 --- a/riid/models/neural_nets/__init__.py +++ b/riid/models/neural_nets/__init__.py @@ -2,88 +2,90 @@ # Under the terms of Contract DE-NA0003525 with NTESS, # the U.S. Government retains certain rights in this software. 
"""This module contains neural network-based classifiers and regressors.""" -from typing import Any, List - +import keras import numpy as np import pandas as pd import tensorflow as tf -from keras.callbacks import EarlyStopping -from keras.layers import Dense, Dropout -from keras.optimizers import Adam -from keras.regularizers import L1L2, l1, l2 +from keras.api.activations import sigmoid, softmax +from keras.api.callbacks import EarlyStopping +from keras.api.layers import Dense, Dropout, Input +from keras.api.losses import CategoricalCrossentropy, MeanSquaredError +from keras.api.metrics import F1Score, Precision, Recall +from keras.api.models import Model +from keras.api.optimizers import Adam +from keras.api.regularizers import L1L2, l1, l2 +from keras.api.utils import split_dataset from scipy.interpolate import UnivariateSpline -from riid.data.sampleset import SampleSet +from riid.data.sampleset import SampleSet, SpectraType from riid.losses import (build_keras_semisupervised_loss_func, chi_squared_diff, jensen_shannon_divergence, normal_nll_diff, poisson_nll_diff, reconstruction_error, sse_diff, weighted_sse_diff) from riid.losses.sparsemax import SparsemaxLoss, sparsemax -from riid.metrics import (build_keras_semisupervised_metric_func, multi_f1, - single_f1) +from riid.metrics import build_keras_semisupervised_metric_func from riid.models import ModelInput, PyRIIDModel +from riid.models.layers import L1NormLayer class MLPClassifier(PyRIIDModel): """Multi-layer perceptron classifier.""" - def __init__(self, hidden_layers: tuple = (512,), activation: str = "relu", - loss: str = "categorical_crossentropy", - optimizer: Any = Adam(learning_rate=0.01, clipnorm=0.001), - metrics: tuple = ("accuracy", "categorical_crossentropy", multi_f1, single_f1), - l2_alpha: float = 1e-4, activity_regularizer=l1(0), dropout: float = 0.0, - learning_rate: float = 0.01, final_activation: str = "softmax"): + def __init__(self, activation=None, loss=None, optimizer=None, + metrics=None, l2_alpha: float = 1e-4, + activity_regularizer=None, final_activation=None): """ Args: - hidden_layers: tuple defining the number and size of dense layers activation: activate function to use for each dense layer loss: loss function to use for training optimizer: tensorflow optimizer or optimizer name to use for training metrics: list of metrics to be evaluating during training l2_alpha: alpha value for the L2 regularization of each dense layer activity_regularizer: regularizer function applied each dense layer output - dropout: amount of dropout to apply to each dense layer - learning_rate: learning rate to use for an Adam optimizer + final_activation: final activation function to apply to model output """ super().__init__() - self.hidden_layers = hidden_layers self.activation = activation - self.final_activation = final_activation self.loss = loss - if optimizer == "adam": - self.optimizer = Adam(learning_rate=learning_rate) - else: - self.optimizer = optimizer self.optimizer = optimizer + self.final_activation = final_activation self.metrics = metrics self.l2_alpha = l2_alpha self.activity_regularizer = activity_regularizer - self.dropout = dropout + self.final_activation = final_activation + + if self.activation is None: + self.activation = "relu" + if self.loss is None: + self.loss = CategoricalCrossentropy() + if optimizer is None: + self.optimizer = Adam(learning_rate=0.01, clipnorm=0.001) + if self.metrics is None: + self.metrics = [F1Score(), Precision(), Recall()] + if self.activity_regularizer is None: + 
self.activity_regularizer = l1(0.0) + if self.final_activation is None: + self.final_activation = "softmax" self.model = None + self._predict_fn = None - def fit(self, ss: SampleSet, bg_ss: SampleSet = None, - ss_input_type: ModelInput = ModelInput.GrossSpectrum, - bg_ss_input_type: ModelInput = ModelInput.BackgroundSpectrum, - batch_size: int = 200, epochs: int = 20, - validation_split: float = 0.2, callbacks=None, val_ss: SampleSet = None, - val_bg_ss: SampleSet = None, patience: int = 15, es_monitor: str = "val_loss", + def fit(self, ss: SampleSet, batch_size: int = 200, epochs: int = 20, + validation_split: float = 0.2, callbacks=None, + patience: int = 15, es_monitor: str = "val_loss", es_mode: str = "min", es_verbose=0, target_level="Isotope", verbose: bool = False): """Fit a model to the given `SampleSet`(s). Args: ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are either foreground (AKA, "net") or gross. - bg_ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are background batch_size: number of samples per gradient update epochs: maximum number of training iterations validation_split: percentage of the training data to use as validation data callbacks: list of callbacks to be passed to the TensorFlow `Model.fit()` method - val_ss: validation set to be used instead of taking a portion of the training data - val_bg_ss: validation set to be used as background for `val_ss` - patience: number of epochs to wait for `tf.keras.callbacks.EarlyStopping` - es_monitor: quantity to be monitored for `tf.keras.callbacks.EarlyStopping` - es_mode: mode for `tf.keras.callbacks.EarlyStopping` - es_verbose: verbosity level for `tf.keras.callbacks.EarlyStopping` + patience: number of epochs to wait for `EarlyStopping` object + es_monitor: quantity to be monitored for `EarlyStopping` object + es_mode: mode for `EarlyStopping` object + es_verbose: verbosity level for `EarlyStopping` object target_level: `SampleSet.sources` column level to use verbose: whether to show detailed model training output @@ -96,75 +98,44 @@ def fit(self, ss: SampleSet, bg_ss: SampleSet = None, if ss.n_samples <= 0: raise ValueError("No spectr[a|um] provided!") - x_train = ss.get_samples().astype(float) - source_contributions_df = ss.sources.groupby(axis=1, level=target_level, sort=False).sum() - y_train = source_contributions_df.values.astype(float) - if bg_ss: - x_bg_train = bg_ss.get_samples().astype(float) - - if val_ss: - if val_bg_ss: - val_data = ( - [val_ss.get_samples().astype(float), val_bg_ss.get_samples().astype(float)], - val_ss.get_source_contributions().astype(float), - ) - else: - val_data = ( - val_ss.get_samples().astype(float), - val_ss.get_source_contributions().astype(float), - ) - validation_split = None + if ss.spectra_type == SpectraType.Gross: + self.model_inputs = (ModelInput.GrossSpectrum,) + elif ss.spectra_type == SpectraType.Foreground: + self.model_inputs = (ModelInput.ForegroundSpectrum,) + elif ss.spectra_type == SpectraType.Background: + self.model_inputs = (ModelInput.BackgroundSpectrum,) else: - val_data = None - row_order = np.arange(x_train.shape[0]) - np.random.shuffle(row_order) - # Enforce random validation split through shuffling - x_train = x_train[row_order] - y_train = y_train[row_order] - - if bg_ss: - x_bg_train = x_bg_train[row_order] + raise ValueError(f"{ss.spectra_type} is not supported in this model.") + + X = ss.get_samples() + source_contributions_df = ss.sources.T.groupby(target_level, sort=False).sum().T + model_outputs = 
source_contributions_df.columns.values.tolist() + Y = source_contributions_df.values + + spectra_tensor = tf.convert_to_tensor(X, dtype=tf.float32) + labels_tensor = tf.convert_to_tensor(Y, dtype=tf.float32) + training_dataset = tf.data.Dataset.from_tensor_slices((spectra_tensor, labels_tensor)) + training_dataset, validation_dataset = split_dataset( + training_dataset, + left_size=validation_split, + shuffle=True + ) + training_dataset = training_dataset.batch(batch_size=batch_size) + validation_dataset = validation_dataset.batch(batch_size=batch_size) if not self.model: - spectra_input = tf.keras.layers.Input( - x_train.shape[1], - name=ss_input_type.name - ) - inputs = [spectra_input] - - if bg_ss: - background_spectra_input = tf.keras.layers.Input( - x_bg_train.shape[1], - name=bg_ss_input_type.name - ) - inputs.append(background_spectra_input) - - if len(inputs) > 1: - x = tf.keras.layers.Concatenate()(inputs) - else: - x = inputs[0] - - for layer, nodes in enumerate(self.hidden_layers): - if layer == 0: - x = Dense( - nodes, - activation=self.activation, - activity_regularizer=self.activity_regularizer, - kernel_regularizer=l2(self.l2_alpha), - )(x) - else: - x = Dense( - nodes, - activation=self.activation, - activity_regularizer=self.activity_regularizer, - kernel_regularizer=l2(self.l2_alpha), - )(x) - if self.dropout > 0: - x = Dropout(self.dropout)(x) - - output = Dense(y_train.shape[1], activation=self.final_activation)(x) - self.model = tf.keras.models.Model(inputs, output) - self.model.compile(loss=self.loss, optimizer=self.optimizer, metrics=self.metrics) + inputs = Input(shape=(X.shape[1],), name="Spectrum") + dense_layer_size = X.shape[1] // 2 + dense_layer = Dense( + dense_layer_size, + activation=self.activation, + activity_regularizer=self.activity_regularizer, + kernel_regularizer=l2(self.l2_alpha), + )(inputs) + outputs = Dense(Y.shape[1], activation=self.final_activation)(dense_layer) + self.model = Model(inputs, outputs) + self.model.compile(loss=self.loss, optimizer=self.optimizer, + metrics=self.metrics) es = EarlyStopping( monitor=es_monitor, @@ -173,41 +144,39 @@ def fit(self, ss: SampleSet, bg_ss: SampleSet = None, restore_best_weights=True, mode=es_mode, ) - if callbacks: callbacks.append(es) else: callbacks = [es] - if bg_ss: - X_data = [x_train, x_bg_train] - self.model_inputs = (ss_input_type, bg_ss_input_type) - else: - X_data = x_train - self.model_inputs = (ss_input_type,) - history = self.model.fit( - X_data, - y_train, + training_dataset, epochs=epochs, verbose=verbose, - validation_split=validation_split, - validation_data=val_data, + validation_data=validation_dataset, callbacks=callbacks, - shuffle=True, - batch_size=batch_size, - ) + ) # Update model information self._update_info( target_level=target_level, - model_outputs=source_contributions_df.columns.values.tolist(), + model_outputs=model_outputs, normalization=ss.spectra_state, ) + # Define the predict function with tf.function and input_signature + self._predict_fn = tf.function( + self._predict, + # input_signature=[tf.TensorSpec(shape=[None, X.shape[1]], dtype=tf.float32)] + experimental_relax_shapes=True + ) + return history - def predict(self, ss: SampleSet, bg_ss: SampleSet = None, verbose=False): + def _predict(self, input_tensor): + return self.model(input_tensor, training=False) + + def predict(self, ss: SampleSet, bg_ss: SampleSet = None): """Classify the spectra in the provided `SampleSet`(s). Results are stored inside the first SampleSet's prediction-related properties. 
@@ -223,7 +192,8 @@ def predict(self, ss: SampleSet, bg_ss: SampleSet = None, verbose=False): else: X = x_test - results = self.model.predict(X, verbose=verbose) + spectra_tensor = tf.convert_to_tensor(X, dtype=tf.float32) + results = self._predict_fn(spectra_tensor) col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] @@ -238,194 +208,6 @@ def predict(self, ss: SampleSet, bg_ss: SampleSet = None, verbose=False): ss.classified_by = self.model_id -class MultiEventClassifier(PyRIIDModel): - """Classifier for spectra from multiple detectors observing the same event.""" - def __init__(self, hidden_layers: tuple = (512,), activation: str = "relu", - loss: str = "categorical_crossentropy", - optimizer: Any = Adam(learning_rate=0.01, clipnorm=0.001), - metrics: list = ["accuracy", "categorical_crossentropy", multi_f1, single_f1], - l2_alpha: float = 1e-4, activity_regularizer: tf.keras.regularizers = l1(0), - dropout: float = 0.0, learning_rate: float = 0.01): - """ - Args: - hidden_layers: tuple containing the number and size of dense layers - activation: activate function to use for each dense layer - loss: string name of the loss function to use for training - optimizer: string name of the optimizer to use for training - metrics: list of metrics to be evaluating during training - l2_alpha: alpha value for the L2 regularization of each dense layer - activity_regularizer: regularizer function applied each dense layer output - dropout: amount of dropout to apply to each dense layer - learning_rate: learning rate to use for an Adam optimizer - """ - super().__init__() - - self.hidden_layers = hidden_layers - self.activation = activation - self.loss = loss - if optimizer == "adam": - self.optimizer = Adam(learning_rate=learning_rate) - else: - self.optimizer = optimizer - self.metrics = metrics - self.l2_alpha = l2_alpha - self.activity_regularizer = activity_regularizer - self.dropout = dropout - self.model = None - - def fit(self, list_of_ss: List[SampleSet], target_contributions: pd.DataFrame, - batch_size: int = 200, epochs: int = 20, - validation_split: float = 0.2, callbacks: list = None, - val_model_ss_list: SampleSet = None, - val_model_target_contributions: pd.DataFrame = None, - patience: int = 15, es_monitor: str = "val_loss", es_mode: str = "min", - es_verbose: bool = False, target_level="Isotope", verbose: bool = False): - """Fit a model to the given SampleSet(s). - - Args: - list_of_ss: list of `SampleSet`s which have prediction_probas populated from - single-event classifiers - target_contributions: DataFrame of the contributions for each - observation. Column titles are the desired label strings. 
- batch_size: number of samples per gradient update - epochs: maximum number of training iterations - validation_split: percentage of the training data to use as validation data - callbacks: list of callbacks to be passed to TensorFlow Model.fit() method - val_model_ss_list: validation set to be used instead of taking a portion of the - training data - val_model_target_contributions: target contributions to the model for each sample - patience: number of epochs to wait for `tf.keras.callbacks.EarlyStopping` object - es_monitor: quantity to be monitored for `tf.keras.callbacks.EarlyStopping` object - es_mode: mode for `tf.keras.callbacks.EarlyStopping` object - es_verbose: verbosity level for `tf.keras.callbacks.EarlyStopping` object - target_level: source level to target for model output - verbose: whether to show detailed training output - - Returns: - `tf.History` object - - Raises: - `ValueError` when no predictions are provided with `list_of_ss` input - """ - if len(list_of_ss) <= 0: - raise ValueError("No model predictions provided!") - - x_train = [ss.prediction_probas.values for ss in list_of_ss] - y_train = target_contributions.values - - if val_model_ss_list and val_model_target_contributions: - val_data = ( - [ss.prediction_probas.values for ss in val_model_ss_list], - val_model_target_contributions.values, - ) - validation_split = None - else: - val_data = None - row_order = np.arange(x_train[0].shape[0]) - np.random.shuffle(row_order) - # Enforce random validation split through shuffling - x_train = [i[row_order] for i in x_train] - y_train = y_train[row_order] - - if not self.model: - inputs = [] - for ss in list_of_ss: - input_from_single_event_model = tf.keras.layers.Input( - ss.prediction_probas.shape[1], - name=ss.classified_by - ) - inputs.append(input_from_single_event_model) - - if len(inputs) > 1: - x = tf.keras.layers.Concatenate()(inputs) - else: - x = inputs[0] - - for layer, nodes in enumerate(self.hidden_layers): - if layer == 0: - x = Dense( - nodes, - activation=self.activation, - activity_regularizer=self.activity_regularizer, - kernel_regularizer=l2(self.l2_alpha), - )(x) - else: - x = Dense( - nodes, - activation=self.activation, - activity_regularizer=self.activity_regularizer, - kernel_regularizer=l2(self.l2_alpha), - )(x) - if self.dropout > 0: - x = Dropout(self.dropout)(x) - - output = Dense(y_train.shape[1], activation="softmax")(x) - self.model = tf.keras.models.Model(inputs, output) - self.model.compile(loss=self.loss, optimizer=self.optimizer, metrics=self.metrics) - - es = EarlyStopping( - monitor=es_monitor, - patience=patience, - verbose=es_verbose, - restore_best_weights=True, - mode=es_mode, - ) - - if callbacks: - callbacks.append(es) - else: - callbacks = [es] - - history = self.model.fit( - x_train, - y_train, - epochs=epochs, - verbose=verbose, - validation_split=validation_split, - validation_data=val_data, - callbacks=callbacks, - shuffle=True, - batch_size=batch_size, - ) - - # Initialize model info, update output/input information - self._update_info( - target_level=target_level, - model_outputs=target_contributions.columns.values.tolist(), - model_inputs=tuple( - [(ss.classified_by, ss.prediction_probas.shape[1]) for ss in list_of_ss] - ), - normalization=tuple( - [(ss.classified_by, ss.spectra_state) for ss in list_of_ss] - ), - ) - - return history - - def predict(self, list_of_ss: List[SampleSet], verbose=False) -> pd.DataFrame: - """Classify the spectra in the provided `SampleSet`(s) based on each one's results. 
- - Args: - list_of_ss: list of `SampleSet`s which had predictions made by single-event models - - Returns: - `DataFrame` of predicted results for the `Sampleset`(s) - """ - X = [ss.prediction_probas for ss in list_of_ss] - results = self.model.predict(X, verbose=verbose) - - col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) - col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] - results_df = pd.DataFrame( - data=results, - columns=pd.MultiIndex.from_tuples( - self.get_model_outputs_as_label_tuples(), - names=col_level_subset - ) - ) - return results_df - - class LabelProportionEstimator(PyRIIDModel): """Regressor for predicting label proportions that uses a semi-supervised loss. @@ -450,19 +232,19 @@ class LabelProportionEstimator(PyRIIDModel): sparsemax, ), "categorical_crossentropy": ( - tf.keras.losses.CategoricalCrossentropy, + CategoricalCrossentropy, { "from_logits": True, "reduction": tf.keras.losses.Reduction.NONE, }, - tf.keras.activations.softmax, + softmax, ), "mse": ( - tf.keras.losses.MeanSquaredError, + MeanSquaredError, { "reduction": tf.keras.losses.Reduction.NONE, }, - tf.keras.activations.sigmoid, + sigmoid, ) } INFO_KEYS = ( @@ -495,7 +277,7 @@ class LabelProportionEstimator(PyRIIDModel): ) def __init__(self, hidden_layers: tuple = (256,), sup_loss="sparsemax", unsup_loss="sse", - metrics=("mae", "categorical_crossentropy",), beta=0.9, source_dict=None, + metrics: list = ["mae", "categorical_crossentropy"], beta=0.9, source_dict=None, optimizer="adam", optimizer_kwargs=None, learning_rate: float = 1e-3, hidden_layer_activation: str = "mish", kernel_l1_regularization: float = 0.0, kernel_l2_regularization: float = 0.0, @@ -551,7 +333,7 @@ def __init__(self, hidden_layers: tuple = (256,), sup_loss="sparsemax", unsup_lo self.optimizer = optimizer if isinstance(optimizer, str): - self.optimizer = tf.keras.optimizers.get(optimizer) + self.optimizer = keras.optimizers.get(optimizer) if optimizer_kwargs is not None: for key, value in optimizer_kwargs.items(): setattr(self.optimizer, key, value) @@ -603,11 +385,11 @@ def _get_unsup_loss_func(self, loss_func_str): return self.UNSUPERVISED_LOSS_FUNCS[loss_func_str] def _initialize_model(self, input_size, output_size): - spectra_input = tf.keras.layers.Input(input_size, name="input_spectrum") + spectra_input = Input(input_size, name="input_spectrum") spectra_norm = L1NormLayer(name="normalized_input_spectrum")(spectra_input) x = spectra_norm for layer, nodes in enumerate(self.hidden_layers): - x = tf.keras.layers.Dense( + x = Dense( nodes, activation=self.hidden_layer_activation, kernel_regularizer=L1L2( @@ -626,17 +408,14 @@ def _initialize_model(self, input_size, output_size): )(x) if self.dropout > 0: - x = tf.keras.layers.Dropout(self.dropout)(x) - output = tf.keras.layers.Dense( + x = Dropout(self.dropout)(x) + output = Dense( output_size, activation="linear", name="output" )(x) - self.model = tf.keras.models.Model( - inputs=[spectra_input], - outputs=[output] - ) + self.model = Model(inputs=[spectra_input], outputs=[output]) def _get_info_as_dict(self): info_dict = {} @@ -703,10 +482,10 @@ def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: b epochs: maximum number of training iterations validation_split: proportion of training data to use as validation data callbacks: list of callbacks to be passed to TensorFlow Model.fit() method - patience: number of epochs to wait for tf.keras.callbacks.EarlyStopping object - es_monitor: quantity to be monitored 
for tf.keras.callbacks.EarlyStopping object - es_mode: mode for tf.keras.callbacks.EarlyStopping object - es_verbose: verbosity level for tf.keras.callbacks.EarlyStopping object + patience: number of epochs to wait for `EarlyStopping` object + es_monitor: quantity to be monitored for `EarlyStopping` object + es_mode: mode for `EarlyStopping` object + es_verbose: verbosity level for `EarlyStopping` object es_min_delta: minimum change to count as an improvement for early stopping normalize_sup_loss: whether to normalize the supervised loss term normalize_func: normalization function used for supervised loss term @@ -715,7 +494,7 @@ def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: b verbose: whether model training output is printed to the terminal """ spectra = ss.get_samples().astype(float) - sources_df = ss.sources.groupby(axis=1, level=target_level, sort=False).sum() + sources_df = ss.sources.T.groupby(target_level, sort=False).sum().T sources = sources_df.values.astype(float) self.sources_columns = sources_df.columns @@ -734,7 +513,7 @@ def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: b if verbose: print("Initializing model...") self._initialize_model( - ss.n_channels, + (ss.n_channels,), sources.shape[1], ) elif verbose: @@ -771,7 +550,7 @@ def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: b ) else: semisup_metrics.append( - build_keras_semisupervised_loss_func( + build_keras_semisupervised_metric_func( each, self.activation, sources.shape[1] @@ -813,7 +592,7 @@ def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: b if verbose: print("Finding OOD detection threshold function...") - train_logits = self.model.predict(spectra) + train_logits = self.model.predict(spectra, verbose=0) train_lpes = self.activation(tf.convert_to_tensor(train_logits, dtype=tf.float32)) self.spline_recon_errors = reconstruction_error( tf.convert_to_tensor(spectra, dtype=tf.float32), @@ -878,31 +657,11 @@ def predict(self, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False, verb ss.info["recon_error"] = recon_errors -class L1NormLayer(tf.keras.layers.Layer): - """Keras layer applying an L1 norm (dividing by total counts) to input data. - """ - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def call(self, inputs): - """This is where the layer's logic lives. - - Args: - inputs: input tensor, or dict/list/tuple of input tensors. - - Returns: - A tensor or list/tuple of tensors. 
- """ - sums = tf.reduce_sum(inputs, axis=-1) - l1_norm = inputs / tf.reshape(sums, (-1, 1)) - return l1_norm - - def _get_reordered_spectra(old_spectra_df: pd.DataFrame, old_sources_df: pd.DataFrame, new_sources_columns, target_level) -> pd.DataFrame: collapsed_sources_df = old_sources_df\ - .groupby(axis=1, level=target_level)\ - .sum() + .T.groupby(target_level)\ + .sum().T reordered_spectra_df = old_spectra_df.iloc[ collapsed_sources_df[ new_sources_columns diff --git a/riid/models/neural_nets/arad.py b/riid/models/neural_nets/arad.py index 25f4941f..3679bfe6 100644 --- a/riid/models/neural_nets/arad.py +++ b/riid/models/neural_nets/arad.py @@ -4,22 +4,27 @@ """This module contains implementations of the ARAD deep learning architecture.""" from typing import List +import keras import pandas as pd import tensorflow as tf -from keras.activations import sigmoid, softplus -from keras.callbacks import EarlyStopping, ReduceLROnPlateau -from keras.initializers import GlorotNormal, HeNormal -from keras.layers import (BatchNormalization, Concatenate, Conv1D, - Conv1DTranspose, Dense, Dropout, Flatten, Input, - MaxPool1D, Reshape, UpSampling1D) -from keras.models import Model -from keras.regularizers import L1L2, L2 +from keras.api.activations import sigmoid, softplus +from keras.api.callbacks import EarlyStopping, ReduceLROnPlateau +from keras.api.initializers import GlorotNormal, HeNormal +from keras.api.layers import (BatchNormalization, Concatenate, Conv1D, + Conv1DTranspose, Dense, Dropout, Flatten, Input, + MaxPool1D, Reshape, UpSampling1D) +from keras.api.losses import kl_divergence, log_cosh +from keras.api.metrics import MeanSquaredError +from keras.api.models import Model +from keras.api.optimizers import Adam, Nadam +from keras.api.regularizers import L1L2, L2 from scipy.spatial.distance import jensenshannon from scipy.stats import entropy from riid.data.sampleset import SampleSet, SpectraState -from riid.losses import jensen_shannon_distance, mish +from riid.losses import mish from riid.models import PyRIIDModel +from riid.models.bayes import ExpandDimsLayer class ARADv1TF(Model): @@ -31,15 +36,15 @@ class ARADv1TF(Model): - Ghawaly Jr, James M. "A Datacentric Algorithm for Gamma-ray Radiation Anomaly Detection in Unknown Background Environments." (2020). """ - def __init__(self, latent_dim: int = 5): + def __init__(self, latent_dim: int = 5, **kwargs): """ Args: latent_dim: dimension of internal latent represention. 5 was the final one in the paper, but 4 to 8 were found to work well. 
""" - super().__init__() + super().__init__(**kwargs) - input_size = (128, 1) + input_size = (128,) # Encoder b1_config = ( (5, 1, 32), @@ -63,9 +68,10 @@ def __init__(self, latent_dim: int = 5): (1, 1, 2), ) encoder_input = Input(shape=input_size, name="encoder_input") - b1 = self._get_branch(encoder_input, b1_config, 0.1, "softplus", "B1", 5) - b2 = self._get_branch(encoder_input, b2_config, 0.1, "softplus", "B2", 5) - b3 = self._get_branch(encoder_input, b3_config, 0.1, "softplus", "B3", 5) + expanded_encoder_input = ExpandDimsLayer()(encoder_input, axis=-1) + b1 = self._get_branch(expanded_encoder_input, b1_config, 0.1, "softplus", "B1", 5) + b2 = self._get_branch(expanded_encoder_input, b2_config, 0.1, "softplus", "B2", 5) + b3 = self._get_branch(expanded_encoder_input, b3_config, 0.1, "softplus", "B3", 5) x = Concatenate(axis=1)([b1, b2, b3]) x = Reshape((15,), name="reshape")(x) latent_space = Dense( @@ -92,27 +98,11 @@ def __init__(self, latent_dim: int = 5): decoded_spectrum = decoder(encoded_spectrum) autoencoder = Model(encoder_input, decoded_spectrum, name="autoencoder") - def logcosh_with_kld_penalty(input_spectrum, decoded_spectrum, - latent_space, sparsities, - penalty_weight=0.5): - squeezed_input = tf.squeeze(input_spectrum) - logcosh_loss = tf.keras.losses.log_cosh(squeezed_input, decoded_spectrum) - kld_loss = tf.keras.losses.kld(sparsities, latent_space) - loss = logcosh_loss + penalty_weight * kld_loss - return loss - - sparsities = [0.001] * latent_dim - loss_func = logcosh_with_kld_penalty( - encoder_input, - decoded_spectrum, - latent_space, - sparsities, - ) - autoencoder.add_loss(loss_func) - self.encoder = encoder self.decoder = decoder self.autoencoder = autoencoder + self.sparsities = [0.001] * latent_dim + self.penalty_weight = 0.5 def _get_branch(self, input_layer, config, dropout_rate, activation, branch_name, dense_units): x = input_layer @@ -132,7 +122,15 @@ def _get_branch(self, input_layer, config, dropout_rate, activation, branch_name return x def call(self, x): - decoded = self.autoencoder(x) + encoded = self.encoder(x) + decoded = self.decoder(encoded) + + # Compute loss + logcosh_loss = log_cosh(x, decoded) + kld_loss = kl_divergence(self.sparsities, encoded) + loss = logcosh_loss + self.penalty_weight * kld_loss + self.add_loss(loss) + return decoded @@ -145,15 +143,15 @@ class ARADv2TF(Model): - Ghawaly Jr, James M., et al. "Characterization of the Autoencoder Radiation Anomaly Detection (ARAD) model." Engineering Applications of Artificial Intelligence 111 (2022): 104761. """ - def __init__(self, latent_dim: int = 8): + def __init__(self, latent_dim: int = 8, **kwargs): """ Args: latent_dim: dimension of internal latent represention. 5 was the final one in the paper, but 4 to 8 were found to work well. 
""" - super().__init__() + super().__init__(**kwargs) - input_size = (128, 1) + input_size = (128,) # Encoder config = ( (7, 1, 8, 2), @@ -163,7 +161,8 @@ def __init__(self, latent_dim: int = 8): (3, 1, 8, 2), ) encoder_input = Input(shape=input_size, name="encoder_input") - x = encoder_input + expanded_encoder_input = ExpandDimsLayer()(encoder_input, axis=-1) + x = expanded_encoder_input for i, (kernel_size, strides, filters, max_pool_size) in enumerate(config, start=1): conv_name = f"conv{i}" x = Conv1D( @@ -194,7 +193,7 @@ def __init__(self, latent_dim: int = 8): encoder = Model(encoder_input, encoder_output, name="encoder") # Decoder - decoder_input = Input(shape=latent_dim, name="decoder_input") + decoder_input = Input(shape=(latent_dim,), name="decoder_input") x = Dense(units=32, name="D2", activation=mish)(decoder_input) x = BatchNormalization(name="D2_batch_norm")(x) x = Reshape((4, 8), name="reshape")(x) @@ -242,7 +241,25 @@ def __init__(self, latent_dim: int = 8): self.autoencoder = autoencoder def call(self, x): - decoded = self.autoencoder(x) + encoded = self.encoder(x) + decoded = self.decoder(encoded) + + # Compute loss + p_sum = tf.reduce_sum(x, axis=-1) + p_norm = tf.divide( + x, + tf.reshape(p_sum, (-1, 1)) + ) + q_sum = tf.reduce_sum(decoded, axis=-1) + q_norm = tf.divide( + decoded, + tf.reshape(q_sum, (-1, 1)) + ) + m = (p_norm + q_norm) / 2 + js_divergence = (kl_divergence(p_norm, m) + kl_divergence(q_norm, m)) / 2 + loss = tf.math.sqrt(js_divergence) + self.add_loss(loss) + return decoded @@ -278,16 +295,12 @@ def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2, x = ss.get_samples().astype(float) - optimizer = tf.keras.optimizers.Nadam( - learning_rate=1e-4 - ) + optimizer = Nadam(learning_rate=1e-4) if not self.model: self.model = ARADv1TF() - self.model.compile( - loss=None, - optimizer=optimizer - ) + + self.model.compile(optimizer=optimizer) callbacks = [ EarlyStopping( @@ -375,17 +388,12 @@ def fit(self, ss: SampleSet, epochs: int = 300, validation_split=0.2, x = ss.get_samples().astype(float) - optimizer = tf.keras.optimizers.Adam( - learning_rate=0.01, - epsilon=0.05 - ) + optimizer = Adam(learning_rate=0.01, epsilon=0.05) if not self.model: self.model = ARADv2TF() - self.model.compile( - loss=jensen_shannon_distance, - optimizer=optimizer - ) + + self.model.compile(optimizer=optimizer) callbacks = [ EarlyStopping( @@ -448,7 +456,7 @@ class ARADLatentPredictor(PyRIIDModel): def __init__(self, hidden_layers: tuple = (8, 4,), hidden_activation: str = "relu", final_activation: str = "linear", loss: str = "mse", optimizer="adam", optimizer_kwargs=None, - learning_rate: float = 1e-3, metrics: tuple = ("mse", ), + learning_rate: float = 1e-3, metrics=None, kernel_l1_regularization: float = 0.0, kernel_l2_regularization: float = 0.0, bias_l1_regularization: float = 0.0, bias_l2_regularization: float = 0.0, activity_l1_regularization: float = 0.0, activity_l2_regularization: float = 0.0, @@ -479,12 +487,14 @@ def __init__(self, hidden_layers: tuple = (8, 4,), self.loss = loss self.optimizer = optimizer if isinstance(optimizer, str): - self.optimizer = tf.keras.optimizers.get(optimizer) + self.optimizer = keras.optimizers.get(optimizer) if optimizer_kwargs is not None: for key, value in optimizer_kwargs.items(): setattr(self.optimizer, key, value) self.optimizer.learning_rate = learning_rate self.metrics = metrics + if self.metrics is None: + self.metrics = [MeanSquaredError()] self.kernel_l1_regularization = kernel_l1_regularization 
self.kernel_l2_regularization = kernel_l2_regularization self.bias_l1_regularization = bias_l1_regularization @@ -498,15 +508,15 @@ def __init__(self, hidden_layers: tuple = (8, 4,), def _initialize_model(self, arad: Model, output_size: int): """Build Keras MLP model. """ - encoder = arad.get_layer("encoder") - encoder_input = encoder.get_layer(index=0).input - encoder_output = encoder.get_layer(index=-1).output - encoder_output_shape = encoder_output.shape + encoder: Model = arad.get_layer("encoder") + encoder_input = encoder.input + encoder_output = encoder.output + encoder_output_shape = encoder_output.shape[-1] - predictor_input = Input(shape=encoder_output_shape, name="predictor_input") + predictor_input = Input(shape=(encoder_output_shape,), name="inner_predictor_input") x = predictor_input for layer, nodes in enumerate(self.hidden_layers): - x = tf.keras.layers.Dense( + x = Dense( nodes, activation=self.hidden_activation, kernel_regularizer=L1L2( @@ -521,19 +531,19 @@ def _initialize_model(self, arad: Model, output_size: int): l1=self.activity_l1_regularization, l2=self.activity_l2_regularization ), - name=f"dense_{layer}" + name=f"inner_predictor_dense_{layer}" )(x) if self.dropout > 0: - x = tf.keras.layers.Dropout(self.dropout)(x) - predictor_output = tf.keras.layers.Dense( + x = Dropout(self.dropout)(x) + predictor_output = Dense( output_size, activation=self.final_activation, - name="output" + name="inner_predictor_output" )(x) - predictor = Model(predictor_input, predictor_output, name="predictor") + inner_predictor = Model(predictor_input, predictor_output, name="inner_predictor") encoded_spectrum = encoder(encoder_input) - predictions = predictor(encoded_spectrum) + predictions = inner_predictor(encoded_spectrum) self.model = Model(encoder_input, predictions, name="predictor") # Freeze the layers corresponding to the autoencoder # Note: setting trainable to False is recursive to sub-layers per TF docs: diff --git a/tests/anomaly_tests.py b/tests/anomaly_tests.py index 617384df..35143a1d 100644 --- a/tests/anomaly_tests.py +++ b/tests/anomaly_tests.py @@ -19,13 +19,14 @@ def setUp(self): pass def test_event_detector(self): - np.random.seed(42) + random_state = 42 + rng = np.random.default_rng(random_state) SAMPLE_INTERVAL = 0.5 BG_RATE = 300 seeds_ss = get_dummy_seeds(100) fg_seeds_ss, bg_seeds_ss = seeds_ss.split_fg_and_bg() - mixed_bg_seeds_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\ + mixed_bg_seeds_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=rng)\ .generate(1) events = PassbySynthesizer(events_per_seed=1, sample_interval=SAMPLE_INTERVAL, @@ -34,7 +35,7 @@ def test_event_detector(self): dwell_time_function_args=(20, 20), snr_function_args=(20, 20), return_gross=True, - rng=np.random.default_rng(42))\ + rng=rng)\ .generate(fg_seeds_ss, mixed_bg_seeds_ss, verbose=False) _, gross_events = list(zip(*events)) diff --git a/tests/model_tests.py b/tests/model_tests.py index 39357f3f..ea430abc 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -15,8 +15,7 @@ from riid.models import PyRIIDModel from riid.models.bayes import (NegativeSpectrumError, PoissonBayesClassifier, ZeroTotalCountsError) -from riid.models.neural_nets import (LabelProportionEstimator, MLPClassifier, - MultiEventClassifier) +from riid.models.neural_nets import (LabelProportionEstimator, MLPClassifier) from riid.models.neural_nets.arad import ARADLatentPredictor, ARADv1, ARADv2 @@ -86,7 +85,7 @@ def test_pb_predict(self): # Get test samples gss = StaticSynthesizer( - samples_per_seed=1, + 
samples_per_seed=2, live_time_function_args=(4, 4), snr_function_args=(10, 10), rng=np.random.default_rng(42), @@ -99,7 +98,7 @@ def test_pb_predict(self): # Predict pb_model.predict(test_gross_ss, test_bg_ss) - truth_labels = fg_seeds_ss.get_labels() + truth_labels = test_fg_ss.get_labels() predicted_labels = test_gross_ss.get_predictions() assert (truth_labels == predicted_labels).all() @@ -110,18 +109,6 @@ def test_mlp_fit_save_load_predict(self): _test_model_fit_save_load_predict(self, MLPClassifier, self.test_ss, self.train_ss, epochs=1) - def test_mec_fit_save_load_predict(self): - test_copy_ss = self.test_ss[:] - test_copy_ss.prediction_probas = test_copy_ss.sources - _test_model_fit_save_load_predict( - self, - MultiEventClassifier, - [test_copy_ss], - [self.train_ss], - self.train_ss.sources.groupby(axis=1, level="Isotope", sort=False).sum(), - epochs=1 - ) - def test_lpe_fit_save_load_predict(self): _test_model_fit_save_load_predict(self, LabelProportionEstimator, self.test_ss, self.fg_seeds_ss, self.train_ss, epochs=1) diff --git a/tests/seedmixer_tests.py b/tests/seedmixer_tests.py index e56950e2..e986f891 100644 --- a/tests/seedmixer_tests.py +++ b/tests/seedmixer_tests.py @@ -15,8 +15,10 @@ class TestSeedMixer(unittest.TestCase): """Test seed mixing functionality of SampleSet. """ - def setUp(self): - np.random.seed(42) + @classmethod + def setUpClass(self): + random_state = 42 + self.rng = np.random.default_rng(random_state) self.ss, _ = get_dummy_seeds().split_fg_and_bg() self.ss.normalize() @@ -71,24 +73,24 @@ def test_mixture_pdf(self): self.assertAlmostEqual(self.three_mix_seeds_ss.spectra.values[sample, :].sum(), 1.0) def test_spectrum_construction_3seeds_2mix(self): - _, bg_seeds_ss = get_dummy_seeds(n_channels=16).split_fg_and_bg() - mixed_bg_ss = SeedMixer(bg_seeds_ss, mixture_size=2).generate(100) + _, bg_seeds_ss = get_dummy_seeds(n_channels=16, rng=self.rng).split_fg_and_bg() + mixed_bg_ss = SeedMixer(bg_seeds_ss, mixture_size=2, rng=self.rng).generate(100) spectral_distances = _get_spectral_distances(bg_seeds_ss, mixed_bg_ss) - self.assertTrue(all(spectral_distances == 0)) + self.assertTrue(np.isclose(spectral_distances, 0.0).all()) def test_spectrum_construction_3seeds_3mix(self): - _, bg_seeds_ss = get_dummy_seeds(n_channels=16).split_fg_and_bg() - mixed_bg_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(100) + _, bg_seeds_ss = get_dummy_seeds(n_channels=16, rng=self.rng).split_fg_and_bg() + mixed_bg_ss = SeedMixer(bg_seeds_ss, mixture_size=3, rng=self.rng).generate(100) spectral_distances = _get_spectral_distances(bg_seeds_ss, mixed_bg_ss) - self.assertTrue(all(spectral_distances == 0)) + self.assertTrue(np.isclose(spectral_distances, 0.0).all()) def test_spectrum_construction_2seeds_2mix(self): - _, bg_seeds_ss = get_dummy_seeds(n_channels=16).split_fg_and_bg( + _, bg_seeds_ss = get_dummy_seeds(n_channels=16, rng=self.rng).split_fg_and_bg( bg_seed_names=SampleSet.DEFAULT_BG_SEED_NAMES[1:3] ) - mixed_bg_ss = SeedMixer(bg_seeds_ss, mixture_size=2).generate(100) + mixed_bg_ss = SeedMixer(bg_seeds_ss, mixture_size=2, rng=self.rng).generate(100) spectral_distances = _get_spectral_distances(bg_seeds_ss, mixed_bg_ss) - self.assertTrue(all(spectral_distances == 0)) + self.assertTrue(np.isclose(spectral_distances, 0.0).all()) def test_spectrum_construction_2seeds_2mix_error(self): _, bg_seeds_ss = get_dummy_seeds(n_channels=16).split_fg_and_bg( @@ -104,10 +106,11 @@ def _get_spectral_distances(seeds_ss, mixed_ss): mixed_ss.sources.values.T ).T spectral_distances 
= jensenshannon( - recon_spectra.astype(np.float32), - mixed_ss.spectra.values.astype(np.float32), + recon_spectra, + mixed_ss.spectra.values, axis=1 ) + spectral_distances = np.nan_to_num(spectral_distances, nan=0.0) return spectral_distances diff --git a/tests/visualize_tests.py b/tests/visualize_tests.py index e9e83fb3..151e9f48 100644 --- a/tests/visualize_tests.py +++ b/tests/visualize_tests.py @@ -23,7 +23,8 @@ class TestVisualize(unittest.TestCase): """Testing plot functions in the visualize module.""" - def setUp(self): + @classmethod + def setUpClass(self): """Test setup.""" self.fg_seeds_ss, self.bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() self.mixed_bg_seed_ss = SeedMixer(self.bg_seeds_ss, mixture_size=3).generate(10) From e9ab4c1988a262690752eb3b856658b61072348b Mon Sep 17 00:00:00 2001 From: Tyler Morrow Date: Thu, 15 Aug 2024 09:03:41 -0600 Subject: [PATCH 2/4] Move common imports to top-level; fix model saving bugs. --- examples/courses/Primer 1/Primer1.ipynb | 68 +- examples/data/conversion/pcf_to_ss.py | 3 +- examples/data/difficulty_score.py | 4 +- .../data/preprocessing/energy_calibration.py | 4 +- examples/data/synthesis/mix_seeds.py | 4 +- examples/data/synthesis/synthesize_passbys.py | 3 +- examples/data/synthesis/synthesize_seeds.py | 2 +- .../synthesis/synthesize_seeds_advanced.py | 2 +- examples/data/synthesis/synthesize_spectra.py | 4 +- examples/modeling/anomaly_detection.py | 4 +- examples/modeling/arad.py | 6 +- examples/modeling/arad_latent_prediction.py | 6 +- examples/modeling/classifier_comparison.py | 7 +- .../modeling/label_proportion_estimation.py | 6 +- .../modeling/neural_network_classifier.py | 6 +- examples/run_examples.py | 1 + examples/visualization/confusion_matrix.py | 6 +- examples/visualization/distance_matrix.py | 2 +- .../plot_sampleset_compare_to.py | 4 +- examples/visualization/plot_spectra.py | 2 +- riid/__init__.py | 12 +- riid/data/converters/aipt.py | 3 +- riid/data/converters/topcoder.py | 3 +- riid/data/synthetic/__init__.py | 374 +--------- riid/data/synthetic/base.py | 290 ++++++++ riid/data/synthetic/passby.py | 4 +- riid/data/synthetic/seed.py | 88 ++- riid/data/synthetic/static.py | 4 +- riid/gadras/api.py | 12 +- riid/metrics.py | 2 +- riid/models/__init__.py | 245 +------ riid/models/base.py | 256 +++++++ riid/models/bayes.py | 4 +- riid/models/neural_nets/__init__.py | 669 +----------------- riid/models/neural_nets/arad.py | 6 +- riid/models/neural_nets/basic.py | 197 ++++++ riid/models/neural_nets/lpe.py | 489 +++++++++++++ riid/visualize.py | 2 +- tests/anomaly_tests.py | 4 +- tests/data_tests.py | 6 +- tests/gadras_tests.py | 3 +- tests/model_tests.py | 15 +- tests/sampleset_tests.py | 4 +- tests/seedmixer_tests.py | 4 +- tests/staticsynth_tests.py | 9 +- tests/visualize_tests.py | 6 +- 46 files changed, 1442 insertions(+), 1413 deletions(-) create mode 100644 riid/data/synthetic/base.py create mode 100644 riid/models/base.py create mode 100644 riid/models/neural_nets/basic.py create mode 100644 riid/models/neural_nets/lpe.py diff --git a/examples/courses/Primer 1/Primer1.ipynb b/examples/courses/Primer 1/Primer1.ipynb index ffe4714d..459e3ef8 100755 --- a/examples/courses/Primer 1/Primer1.ipynb +++ b/examples/courses/Primer 1/Primer1.ipynb @@ -155,7 +155,7 @@ "from riid.gadras.api import GADRAS_API_SEEMINGLY_AVAILABLE\n", "\n", "if GADRAS_API_SEEMINGLY_AVAILABLE:\n", - " from riid.data.synthetic.seed import SeedSynthesizer\n", + " from riid import SeedSynthesizer\n", " seed_syn = SeedSynthesizer()\n", " # The YAML 
file defining the seed synthesis specification is ultimately parsed into a dictionary.\n", " # You can also load it yourself and pass in the dictionary instead - this is useful for varying detector parameters!\n", @@ -163,7 +163,7 @@ "else:\n", " # If you don't have Windows with GADRAS installed, this will use the dummy seeds below which are not actual gamma spectra.\n", " # Another option would be to load a seeds file obtained elsewhere.\n", - " from riid.data.synthetic import get_dummy_seeds\n", + " from riid import get_dummy_seeds\n", " seeds_ss = get_dummy_seeds()" ] }, @@ -251,7 +251,7 @@ "outputs": [], "source": [ "\"\"\"Seed mixing\"\"\"\n", - "from riid.data.synthetic.seed import SeedMixer\n", + "from riid import SeedMixer\n", "\n", "mixed_bg_seeds_ss = SeedMixer(\n", " bg_seeds_ss,\n", @@ -278,7 +278,7 @@ "outputs": [], "source": [ "\"\"\"Combining SampleSets\"\"\"\n", - "from riid.data.sampleset import SampleSet\n", + "from riid import SampleSet\n", "\n", "combined_ss = SampleSet()\n", "combined_ss.concat([fg_seeds_ss, mixed_bg_seeds_ss])\n", @@ -318,14 +318,14 @@ "outputs": [], "source": [ "\"\"\"Static Synthesis\"\"\"\n", - "from riid.data.synthetic.static import StaticSynthesizer\n", + "from riid import StaticSynthesizer\n", "\n", "static_syn = StaticSynthesizer(\n", " samples_per_seed=100,\n", " bg_cps=300,\n", " live_time_function=\"uniform\",\n", " live_time_function_args=(0.25, 8),\n", - " snr_function=\"uniform\",\n", + " snr_function=\"log10\",\n", " snr_function_args=(0.1, 100),\n", " apply_poisson_noise=True,\n", " return_fg=True,\n", @@ -348,8 +348,7 @@ "outputs": [], "source": [ "\"\"\"Normalization\"\"\"\n", - "gross_ss.normalize()\n", - "bg_ss.normalize()" + "fg_ss.normalize()" ] }, { @@ -371,16 +370,10 @@ "outputs": [], "source": [ "\"\"\"Model fitting\"\"\"\n", - "from riid.models.neural_nets import MLPClassifier\n", - "from riid.metrics import single_f1\n", + "from riid.models import MLPClassifier\n", "\n", - "model = MLPClassifier(\n", - " hidden_layers=(256,),\n", - " learning_rate=4e-3,\n", - " metrics=[single_f1]\n", - ")\n", - "\n", - "history = model.fit(gross_ss, bg_ss, epochs=25, patience=5, verbose=True)" + "model = MLPClassifier()\n", + "history = model.fit(fg_ss, epochs=10, verbose=True)" ] }, { @@ -402,12 +395,9 @@ "outputs": [], "source": [ "\"\"\"Generate some in-distribution data the model has not seen.\"\"\"\n", - "test_bg_ss, test_gross_ss = static_syn.generate(fg_seeds_ss, bg_seeds_ss)\n", - "test_bg_ss.normalize()\n", - "test_gross_ss.normalize()\n", - "# Adjust ground truth\n", - "#test_gross_ss.sources.drop(test_bg_ss.sources.columns, axis=1, inplace=True)\n", - "#test_gross_ss.normalize_sources()" + "test_fg_ss, test_gross_ss = static_syn.generate(fg_seeds_ss, bg_seeds_ss)\n", + "test_fg_ss.normalize()\n", + "test_gross_ss.normalize()" ] }, { @@ -417,7 +407,7 @@ "outputs": [], "source": [ "\"\"\"Use the model!\"\"\"\n", - "model.predict(test_gross_ss, test_bg_ss) # Saved in your SampleSet containing non-background sources (the gross spectra)" + "model.predict(test_fg_ss) # Results are saved in the SampleSet's prediction_probas DataFrame" ] }, { @@ -429,8 +419,8 @@ "\"\"\"Calculate performance metric\"\"\"\n", "from sklearn.metrics import f1_score\n", "\n", - "labels = test_gross_ss.get_labels()\n", - "predictions = test_gross_ss.get_predictions()\n", + "labels = test_fg_ss.get_labels()\n", + "predictions = test_fg_ss.get_predictions()\n", "f1_score(labels, predictions, average=\"micro\")" ] }, @@ -443,7 +433,7 @@ "\"\"\"Confusion 
Matrix\"\"\"\n", "from riid.visualize import confusion_matrix\n", "\n", - "_ = confusion_matrix(test_gross_ss)" + "_ = confusion_matrix(test_fg_ss)" ] }, { @@ -455,7 +445,7 @@ "\"\"\"SNR vs. Model Score\"\"\"\n", "from riid.visualize import plot_snr_vs_score\n", "\n", - "_ = plot_snr_vs_score(test_gross_ss, xscale=\"log\")" + "_ = plot_snr_vs_score(test_fg_ss, xscale=\"log\")" ] }, { @@ -465,13 +455,23 @@ "outputs": [], "source": [ "\"\"\"Save model\"\"\"\n", - "import os\n", + "from pathlib import Path\n", + "\n", + "\n", + "def _delete_if_exists(path: Path):\n", + " if path.exists():\n", + " path.unlink()\n", "\n", - "model_path = \"./model.h5\"\n", - "if os.path.exists(model_path):\n", - " os.remove(model_path)\n", + "model_path_json = Path(\"./model.json\")\n", + "model_path_tflite = model_path_json.with_suffix(\".tflite\")\n", + "model_path_onnx = model_path_json.with_suffix(\".onnx\")\n", + "_delete_if_exists(model_path_json)\n", + "_delete_if_exists(model_path_tflite)\n", + "_delete_if_exists(model_path_onnx)\n", "\n", - "model.save(model_path)" + "model.save(str(model_path_json))\n", + "model.to_tflite(str(model_path_tflite))\n", + "model.to_onnx(str(model_path_onnx))" ] }, { @@ -571,7 +571,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.12.4" }, "orig_nbformat": 4, "vscode": { diff --git a/examples/data/conversion/pcf_to_ss.py b/examples/data/conversion/pcf_to_ss.py index 0a5b6d65..e94e0148 100644 --- a/examples/data/conversion/pcf_to_ss.py +++ b/examples/data/conversion/pcf_to_ss.py @@ -12,10 +12,9 @@ import os from pathlib import Path -from riid import SAMPLESET_HDF_FILE_EXTENSION +from riid import SAMPLESET_HDF_FILE_EXTENSION, read_pcf from riid.data.converters import (_validate_and_create_output_dir, convert_directory) -from riid.data.sampleset import read_pcf def convert_and_save(input_file_path: str, output_dir: str = None, diff --git a/examples/data/difficulty_score.py b/examples/data/difficulty_score.py index 103d9e1b..6f838ea6 100644 --- a/examples/data/difficulty_score.py +++ b/examples/data/difficulty_score.py @@ -2,9 +2,7 @@ # Under the terms of Contract DE-NA0003525 with NTESS, # the U.S. Government retains certain rights in this software. """This example demonstrates how to compute the difficulty of a given SampleSet.""" -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static import StaticSynthesizer +from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() mixed_bg_seed_ss = SeedMixer(bg_seeds_ss, mixture_size=3)\ diff --git a/examples/data/preprocessing/energy_calibration.py b/examples/data/preprocessing/energy_calibration.py index 25061fea..03cefc28 100644 --- a/examples/data/preprocessing/energy_calibration.py +++ b/examples/data/preprocessing/energy_calibration.py @@ -8,9 +8,7 @@ import matplotlib.pyplot as plt import numpy as np -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static import StaticSynthesizer +from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds SYNTHETIC_DATA_CONFIG = { "samples_per_seed": 10, diff --git a/examples/data/synthesis/mix_seeds.py b/examples/data/synthesis/mix_seeds.py index 911fc56d..a2a7bbac 100644 --- a/examples/data/synthesis/mix_seeds.py +++ b/examples/data/synthesis/mix_seeds.py @@ -3,8 +3,8 @@ # the U.S. 
Government retains certain rights in this software. """This example demonstrates how to generate synthetic gamma spectra from seeds.""" import numpy as np -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer + +from riid import SeedMixer, get_dummy_seeds fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() diff --git a/examples/data/synthesis/synthesize_passbys.py b/examples/data/synthesis/synthesize_passbys.py index 4c9ca023..7b2bbb3e 100644 --- a/examples/data/synthesis/synthesize_passbys.py +++ b/examples/data/synthesis/synthesize_passbys.py @@ -7,8 +7,7 @@ import matplotlib.pyplot as plt import numpy as np -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.passby import PassbySynthesizer +from riid import PassbySynthesizer, get_dummy_seeds if len(sys.argv) == 2: import matplotlib diff --git a/examples/data/synthesis/synthesize_seeds.py b/examples/data/synthesis/synthesize_seeds.py index ab837dca..a33eabfc 100644 --- a/examples/data/synthesis/synthesize_seeds.py +++ b/examples/data/synthesis/synthesize_seeds.py @@ -4,7 +4,7 @@ """This example demonstrates how to generate synthetic seeds from GADRAS.""" import yaml -from riid.data.synthetic.seed import SeedSynthesizer +from riid import SeedSynthesizer seed_synth_config = """ --- diff --git a/examples/data/synthesis/synthesize_seeds_advanced.py b/examples/data/synthesis/synthesize_seeds_advanced.py index 390ac72b..d259b242 100644 --- a/examples/data/synthesis/synthesize_seeds_advanced.py +++ b/examples/data/synthesis/synthesize_seeds_advanced.py @@ -5,7 +5,7 @@ configuration expansion features.""" import yaml -from riid.data.synthetic.seed import SeedSynthesizer +from riid import SeedSynthesizer seed_synth_config = """ --- diff --git a/examples/data/synthesis/synthesize_spectra.py b/examples/data/synthesis/synthesize_spectra.py index 95530244..ce150123 100644 --- a/examples/data/synthesis/synthesize_spectra.py +++ b/examples/data/synthesis/synthesize_spectra.py @@ -2,9 +2,7 @@ # Under the terms of Contract DE-NA0003525 with NTESS, # the U.S. Government retains certain rights in this software. 
"""This example demonstrates how to generate synthetic gamma spectra from seeds.""" -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static import StaticSynthesizer +from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds SYNTHETIC_DATA_CONFIG = { "samples_per_seed": 10000, diff --git a/examples/modeling/anomaly_detection.py b/examples/modeling/anomaly_detection.py index e3b27a13..503470e2 100644 --- a/examples/modeling/anomaly_detection.py +++ b/examples/modeling/anomaly_detection.py @@ -10,10 +10,8 @@ import numpy as np from matplotlib import cm +from riid import PassbySynthesizer, SeedMixer, get_dummy_seeds from riid.anomaly import PoissonNChannelEventDetector -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.passby import PassbySynthesizer -from riid.data.synthetic.seed import SeedMixer if len(sys.argv) == 2: import matplotlib diff --git a/examples/modeling/arad.py b/examples/modeling/arad.py index f15d2e4e..120b0d07 100644 --- a/examples/modeling/arad.py +++ b/examples/modeling/arad.py @@ -6,10 +6,8 @@ import numpy as np import pandas as pd -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static import StaticSynthesizer -from riid.models.neural_nets.arad import ARADv1, ARADv2 +from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds +from riid.models import ARADv1, ARADv2 # Config rng = np.random.default_rng(42) diff --git a/examples/modeling/arad_latent_prediction.py b/examples/modeling/arad_latent_prediction.py index 9b233954..4d3b199b 100644 --- a/examples/modeling/arad_latent_prediction.py +++ b/examples/modeling/arad_latent_prediction.py @@ -8,10 +8,8 @@ from keras.api.metrics import Accuracy, CategoricalCrossentropy from sklearn.metrics import f1_score, mean_squared_error -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static import StaticSynthesizer -from riid.models.neural_nets.arad import ARADLatentPredictor, ARADv2 +from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds +from riid.models import ARADLatentPredictor, ARADv2 # Config rng = np.random.default_rng(42) diff --git a/examples/modeling/classifier_comparison.py b/examples/modeling/classifier_comparison.py index 4fce43e7..d51af336 100644 --- a/examples/modeling/classifier_comparison.py +++ b/examples/modeling/classifier_comparison.py @@ -7,12 +7,9 @@ import matplotlib.pyplot as plt from sklearn.metrics import f1_score -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static import StaticSynthesizer -from riid.models.bayes import PoissonBayesClassifier +from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds from riid.metrics import precision_recall_curve -from riid.models.neural_nets import MLPClassifier +from riid.models import MLPClassifier, PoissonBayesClassifier from riid.visualize import plot_precision_recall if len(sys.argv) == 2: diff --git a/examples/modeling/label_proportion_estimation.py b/examples/modeling/label_proportion_estimation.py index b4e5e4c4..1cfe30c7 100644 --- a/examples/modeling/label_proportion_estimation.py +++ b/examples/modeling/label_proportion_estimation.py @@ -4,10 +4,8 @@ from sklearn.metrics import mean_absolute_error -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static 
import StaticSynthesizer -from riid.models.neural_nets import LabelProportionEstimator +from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds +from riid.models import LabelProportionEstimator # Generate some mixture training data. fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() diff --git a/examples/modeling/neural_network_classifier.py b/examples/modeling/neural_network_classifier.py index 277a2dc5..831230c2 100644 --- a/examples/modeling/neural_network_classifier.py +++ b/examples/modeling/neural_network_classifier.py @@ -5,10 +5,8 @@ import numpy as np from sklearn.metrics import f1_score -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static import StaticSynthesizer -from riid.models.neural_nets import MLPClassifier +from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds +from riid.models import MLPClassifier # Generate some training data fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg() diff --git a/examples/run_examples.py b/examples/run_examples.py index aaa29c0f..1424e7de 100644 --- a/examples/run_examples.py +++ b/examples/run_examples.py @@ -6,6 +6,7 @@ import subprocess import sys from pathlib import Path + import pandas as pd from tabulate import tabulate diff --git a/examples/visualization/confusion_matrix.py b/examples/visualization/confusion_matrix.py index b2e25e72..27dd2705 100644 --- a/examples/visualization/confusion_matrix.py +++ b/examples/visualization/confusion_matrix.py @@ -4,10 +4,8 @@ """This example demonstrates how to obtain confusion matrices.""" import sys -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static import StaticSynthesizer -from riid.models.neural_nets import MLPClassifier +from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds +from riid.models import MLPClassifier from riid.visualize import confusion_matrix if len(sys.argv) == 2: diff --git a/examples/visualization/distance_matrix.py b/examples/visualization/distance_matrix.py index 5f229f25..49e5af33 100644 --- a/examples/visualization/distance_matrix.py +++ b/examples/visualization/distance_matrix.py @@ -9,7 +9,7 @@ import matplotlib.pyplot as plt import seaborn as sns -from riid.data.synthetic import get_dummy_seeds +from riid import get_dummy_seeds if len(sys.argv) == 2: import matplotlib diff --git a/examples/visualization/plot_sampleset_compare_to.py b/examples/visualization/plot_sampleset_compare_to.py index 9ccf94bb..5a99f05d 100644 --- a/examples/visualization/plot_sampleset_compare_to.py +++ b/examples/visualization/plot_sampleset_compare_to.py @@ -4,9 +4,7 @@ """This example demonstrates how to compare sample sets.""" import sys -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static import StaticSynthesizer +from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds from riid.visualize import plot_ss_comparison if len(sys.argv) == 2: diff --git a/examples/visualization/plot_spectra.py b/examples/visualization/plot_spectra.py index d0f1e1e7..04f71bd1 100644 --- a/examples/visualization/plot_spectra.py +++ b/examples/visualization/plot_spectra.py @@ -4,7 +4,7 @@ """This example demonstrates how to plot gamma spectra.""" import sys -from riid.data.synthetic import get_dummy_seeds +from riid import get_dummy_seeds from riid.visualize import plot_spectra if len(sys.argv) == 2: diff --git a/riid/__init__.py 
b/riid/__init__.py index d19e9880..44fe4e4b 100644 --- a/riid/__init__.py +++ b/riid/__init__.py @@ -7,9 +7,15 @@ import logging import os import sys - from importlib.metadata import version +from riid.data.sampleset import (SampleSet, SpectraState, SpectraType, + read_hdf, read_json, read_pcf) +from riid.data.synthetic.passby import PassbySynthesizer +from riid.data.synthetic.seed import (SeedMixer, SeedSynthesizer, + get_dummy_seeds) +from riid.data.synthetic.static import StaticSynthesizer + HANDLER = logging.StreamHandler(sys.stdout) logging.root.addHandler(HANDLER) logging.root.setLevel(logging.DEBUG) @@ -31,3 +37,7 @@ "riid.data.synthetic.passby.PassbySynthesizer._generate_single_passby": True, "riid.data.sampleset.SampleSet._channels_to_energies": True, } + +__all__ = ["SampleSet", "SpectraState", "SpectraType", + "read_hdf", "read_json", "read_pcf", "get_dummy_seeds", + "PassbySynthesizer", "SeedSynthesizer", "StaticSynthesizer", "SeedMixer"] diff --git a/riid/data/converters/aipt.py b/riid/data/converters/aipt.py index d890a440..06f7c91a 100644 --- a/riid/data/converters/aipt.py +++ b/riid/data/converters/aipt.py @@ -10,9 +10,8 @@ import pandas as pd -from riid import SAMPLESET_HDF_FILE_EXTENSION +from riid import SAMPLESET_HDF_FILE_EXTENSION, SampleSet from riid.data.converters import _validate_and_create_output_dir -from riid.data.sampleset import SampleSet ELEMENT_IDS_PER_FILE = [0, 1, 2, 3] DEFAULT_ECAL = [ diff --git a/riid/data/converters/topcoder.py b/riid/data/converters/topcoder.py index e75b07a7..fffd8ada 100644 --- a/riid/data/converters/topcoder.py +++ b/riid/data/converters/topcoder.py @@ -13,10 +13,9 @@ import numpy as np import pandas as pd -from riid import SAMPLESET_HDF_FILE_EXTENSION +from riid import SAMPLESET_HDF_FILE_EXTENSION, SampleSet from riid.data.converters import _validate_and_create_output_dir from riid.data.labeling import label_to_index_element -from riid.data.sampleset import SampleSet SOURCE_ID_TO_LABEL = { 0: "Background", diff --git a/riid/data/synthetic/__init__.py b/riid/data/synthetic/__init__.py index 015c491d..bc07729c 100644 --- a/riid/data/synthetic/__init__.py +++ b/riid/data/synthetic/__init__.py @@ -1,373 +1,9 @@ # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). # Under the terms of Contract DE-NA0003525 with NTESS, # the U.S. Government retains certain rights in this software. 
-"""This modules contains utilities for synthesizing gamma spectra.""" -from collections import Counter -from typing import Any +"""This module contains utilities for synthesizing gamma spectra.""" +# The following imports are left to not break previous imports; remove in v3 +from riid.data.synthetic.base import Synthesizer, get_distribution_values +from riid.data.synthetic.seed import get_dummy_seeds -import numpy as np -import pandas as pd -from numpy.random import Generator - -from riid.data import get_expected_spectra -from riid.data.sampleset import (SampleSet, SpectraState, SpectraType, - _get_utc_timestamp) - - -class Synthesizer(): - """Base class for synthesizers.""" - - SYNTHETIC_STR = "synthetic" - SUPPORTED_SAMPLING_FUNCTIONS = ["uniform", "log10", "discrete", "list"] - - def __init__(self, bg_cps: float = 300.0, long_bg_live_time: float = 120.0, - apply_poisson_noise: bool = True, - normalize_sources: bool = True, - return_fg: bool = True, - return_gross: bool = False, - rng: Generator = np.random.default_rng()): - """ - Args: - bg_cps: constant rate of gammas from background - long_bg_live_time: live time on which to base background subtractions - apply_poisson_noise: whether to apply Poisson noise to spectra - normalize_sources: whether to normalize ground truth proportions to sum to 1 - return_fg: whether to compute and return background subtracted spectra - return_gross: whether to return gross spectra (always computed) - rng: NumPy random number generator, useful for experiment repeatability - """ - self.bg_cps = bg_cps - self.long_bg_live_time = long_bg_live_time - self.apply_poisson_noise = apply_poisson_noise - self.normalize_sources = normalize_sources - self.return_fg = return_fg - self.return_gross = return_gross - self._rng = rng - self._synthesis_start_dt = None - self._n_samples_synthesized = 0 - - def __str__(self): - output = "SynthesizerConfig" - for k, v in sorted(vars(self).items()): - output += " {}: {}".format(k, str(v)) - return output - - def _reset_progress(self): - self._n_samples_synthesized = 0 - self._synthesis_start_dt = _get_utc_timestamp() - - def _report_progress(self, n_samples_expected, batch_name): - percent_complete = 100 * self._n_samples_synthesized / n_samples_expected - msg = ( - f"Synthesizing ... {percent_complete:.0f}% " - f"(currently on {batch_name}" - ) - MAX_MSG_LEN = 80 - msg = (msg[:MAX_MSG_LEN] + "...") if len(msg) > MAX_MSG_LEN else msg - msg += ")" - print("\033[K" + msg, end="\r") - - def _report_completion(self, delay): - summary = ( - f"Synthesis complete!\n" - f"Generated {self._n_samples_synthesized} samples in ~{delay:.2f}s " - f"(~{(self._n_samples_synthesized / delay):.2f} samples/sec)." - ) - print("\033[K" + summary) - - def _verify_n_samples_synthesized(self, actual: int, expected: int): - assert expected == actual, ( - f"{actual} generated, but {expected} were expected. 
" - "Be sure to remove any columns from your seeds' sources DataFrame that " - "contain all zeroes.") - - def _get_batch(self, fg_seed, fg_sources, bg_seed, bg_sources, ecal, - lt_targets, snr_targets, rt_targets=None, distance_cm=None): - if not (self.return_fg or self.return_gross): - raise ValueError("Computing to return nothing.") - - bg_counts_expected = lt_targets * self.bg_cps - fg_counts_expected = snr_targets * np.sqrt(bg_counts_expected) - - fg_spectra = get_expected_spectra(fg_seed.values, fg_counts_expected) - bg_spectra = get_expected_spectra(bg_seed.values, bg_counts_expected) - - long_bg_counts_expected = self.long_bg_live_time * self.bg_cps - long_bg_spectrum_expected = bg_seed.values * long_bg_counts_expected - - gross_spectra = None - long_bg_spectra = None - fg_counts = 0 - bg_counts = 0 - long_bg_counts = 0 - fg_ss = None - gross_ss = None - - # Spectra - if self.apply_poisson_noise: - gross_spectra = self._rng.poisson(fg_spectra + bg_spectra) - if self.return_fg: - long_bg_spectrum = self._rng.poisson(long_bg_spectrum_expected) - long_bg_seed = long_bg_spectrum / long_bg_spectrum.sum() - long_bg_spectra = get_expected_spectra(long_bg_seed, bg_counts_expected) - fg_spectra = gross_spectra - long_bg_spectra - else: - gross_spectra = fg_spectra + bg_spectra - if self.return_fg: - long_bg_spectra = bg_spectra - fg_spectra = gross_spectra - long_bg_spectra - - # Counts - fg_counts = fg_spectra.sum(axis=1, dtype=float) - if self.return_fg: - long_bg_counts = long_bg_spectra.sum(axis=1, dtype=float) - if self.return_gross: - bg_counts = bg_spectra.sum(axis=1, dtype=float) - - # Sample sets - if self.return_fg: - snrs = fg_counts / np.sqrt(long_bg_counts.clip(1)) - fg_ss = get_fg_sample_set(fg_spectra, fg_sources, ecal, lt_targets, - snrs=snrs, total_counts=fg_counts, - real_times=rt_targets, distance_cm=distance_cm, - timestamps=self._synthesis_start_dt) - self._n_samples_synthesized += fg_ss.n_samples - if self.return_gross: - tiled_fg_sources = _tile_sources_and_scale( - fg_sources, - gross_spectra.shape[0], - fg_counts, - ) - tiled_bg_sources = _tile_sources_and_scale( - bg_sources, - gross_spectra.shape[0], - bg_counts, - ) - gross_sources = get_merged_sources_samplewise(tiled_fg_sources, tiled_bg_sources) - gross_counts = gross_spectra.sum(axis=1) - snrs = fg_counts / np.sqrt(bg_counts.clip(1)) - gross_ss = get_gross_sample_set(gross_spectra, gross_sources, ecal, - lt_targets, snrs, gross_counts, - real_times=rt_targets, distance_cm=distance_cm, - timestamps=self._synthesis_start_dt) - self._n_samples_synthesized += gross_ss.n_samples - - return fg_ss, gross_ss - - -def get_sample_set(spectra, sources, ecal, live_times, snrs, total_counts=None, - real_times=None, distance_cm=None, timestamps=None, - descriptions=None) -> SampleSet: - n_samples = spectra.shape[0] - - ss = SampleSet() - ss.spectra_state = SpectraState.Counts - ss.spectra = pd.DataFrame(spectra) - ss.sources = sources - ss.info.description = np.full(n_samples, "") # Ensures the length of info equal n_samples - if descriptions: - ss.info.description = descriptions - ss.info.snr = snrs - ss.info.timestamp = timestamps - ss.info.total_counts = total_counts if total_counts is not None else spectra.sum(axis=1) - ss.info.ecal_order_0 = ecal[0] - ss.info.ecal_order_1 = ecal[1] - ss.info.ecal_order_2 = ecal[2] - ss.info.ecal_order_3 = ecal[3] - ss.info.ecal_low_e = ecal[4] - ss.info.live_time = live_times - ss.info.real_time = real_times if real_times is not None else live_times - ss.info.distance_cm = distance_cm 
- ss.info.occupancy_flag = 0 - ss.info.tag = " " # TODO: test if this can be empty string - - return ss - - -def _tile_sources_and_scale(sources, n_samples, scalars) -> pd.DataFrame: - tiled_sources = pd.DataFrame( - np.tile(sources.values, (n_samples, 1)), - columns=sources.index - ) - # Multiplying normalized source values by spectrum counts. - # This is REQUIRED for properly merging sources DataFrames later when synthesizing - # multiple isotopes. - tiled_sources = tiled_sources.multiply(scalars, axis="index") - return tiled_sources - - -def get_fg_sample_set(spectra, sources, ecal, live_times, snrs, total_counts, - real_times=None, distance_cm=None, timestamps=None, - descriptions=None) -> SampleSet: - tiled_sources = _tile_sources_and_scale( - sources, - spectra.shape[0], - spectra.sum(axis=1) - ) - ss = get_sample_set( - spectra=spectra, - sources=tiled_sources, - ecal=ecal, - live_times=live_times, - snrs=snrs, - total_counts=total_counts, - real_times=real_times, - distance_cm=distance_cm, - timestamps=timestamps, - descriptions=descriptions - ) - ss.spectra_type = SpectraType.Foreground - return ss - - -def get_gross_sample_set(spectra, sources, ecal, live_times, snrs, total_counts, - real_times=None, distance_cm=None, timestamps=None, - descriptions=None) -> SampleSet: - ss = get_sample_set( - spectra=spectra, - sources=sources, - ecal=ecal, - live_times=live_times, - snrs=snrs, - total_counts=total_counts, - real_times=real_times, - distance_cm=distance_cm, - timestamps=timestamps, - descriptions=descriptions - ) - ss.spectra_type = SpectraType.Gross - return ss - - -def get_distribution_values(function: str, function_args: Any, n_values: int, - rng: Generator = np.random.default_rng()): - """Randomly sample a list of values based one of many distributions. - - Args: - function: name of the distribution function - function_args: argument or collection of arguments to be - passed to the function, if any. 
- n_values: size of the distribution - rng: NumPy random number generator, useful for experiment repeatability - - Returns: - Value or collection of sampled values - - Raises: - `ValueError` when an unsupported function type is provided - """ - values = None - if function == "uniform": - values = rng.uniform(*function_args, size=n_values) - elif function == "log10": - log10_args = tuple(map(np.log10, function_args)) - values = np.power(10, rng.uniform(*log10_args, size=n_values)) - elif function == "discrete": - values = rng.choice(function_args, size=n_values) - elif function == "list": - values = np.array(function_args) - else: - raise ValueError(f"{function} function not supported for sampling.") - - return values - - -def get_merged_sources_samplewise(sources1: pd.DataFrame, sources2: pd.DataFrame) -> pd.DataFrame: - merged_sources_df = sources1.add(sources2, axis=1, fill_value=0) - return merged_sources_df - - -def get_samples_per_seed(columns: pd.MultiIndex, min_samples_per_seed: int, balance_level: int): - level_values = columns.get_level_values(level=balance_level) - level_value_to_n_seeds = Counter(level_values) - unique_level_values = list(level_value_to_n_seeds.keys()) - occurences = np.array(list(level_value_to_n_seeds.values())) - max_samples_per_level_value = occurences.max() * min_samples_per_seed - samples_per_level_value = np.ceil(max_samples_per_level_value / occurences).astype(int) - lv_to_samples_per_seed = {k: v for (k, v) in zip(unique_level_values, samples_per_level_value)} - total_samples_expected = sum([x * y for x, y in zip(occurences, samples_per_level_value)]) - - return lv_to_samples_per_seed, total_samples_expected - - -def get_dummy_seeds(n_channels: int = 512, live_time: float = 600.0, - count_rate: float = 1000.0, normalize: bool = True, - rng: Generator = np.random.default_rng()) -> SampleSet: - """Get a random, dummy `SampleSet` of ideal seeds. - - WARNING: the spectra returned by this function each contain one gaussian peak that does - not overlap with the peaks of other spectra. Such data is about as *ideal* as one - could hope to be working with and does not represent anything real. - Therefore, **do not** use this data for any purpose other than testing, debugging, or - examples where code, not results, is being demonstrated. Any use in scientific studies - does not make sense. 
- - Args: - n_channels: number of channels in the spectra DataFrame - live_time: collection time on which to base seeds - (higher creates a less noisy shape) - count_rate: count rate on which to base seeds - (higher creates a less noisy shape) - normalize: whether to apply an L1-norm to the spectra - rng: NumPy random number generator, useful for experiment repeatability - - Returns: - `SampleSet` with randomly generated spectra - """ - ss = SampleSet() - ss.measured_or_synthetic = "synthetic" - ss.spectra_state = SpectraState.Counts - ss.spectra_type = SpectraType.BackgroundForeground - ss.synthesis_info = { - "subtract_background": True, - } - sources = [ - ("Industrial", "Am241", "Unshielded Am241"), - ("Industrial", "Ba133", "Unshielded Ba133"), - ("NORM", "K40", "PotassiumInSoil"), - ("NORM", "K40", "Moderately Shielded K40"), - ("NORM", "Ra226", "UraniumInSoil"), - ("NORM", "Th232", "ThoriumInSoil"), - ("SNM", "U238", "Unshielded U238"), - ("SNM", "Pu239", "Unshielded Pu239"), - ("SNM", "Pu239", "Moderately Shielded Pu239"), - ("SNM", "Pu239", "Heavily Shielded Pu239"), - ] - n_sources = len(sources) - n_fg_sources = n_sources - sources_cols = pd.MultiIndex.from_tuples( - sources, - names=SampleSet.SOURCES_MULTI_INDEX_NAMES - ) - sources_data = np.identity(n_sources) - ss.sources = pd.DataFrame(data=sources_data, columns=sources_cols) - - histograms = [] - N_FG_COUNTS = int(count_rate * live_time) - fg_std = np.sqrt(n_channels / n_sources) - channels_per_sources = n_channels / n_fg_sources - for i in range(n_fg_sources): - mu = i * channels_per_sources + channels_per_sources / 2 - counts = rng.normal(mu, fg_std, size=N_FG_COUNTS) - fg_histogram, _ = np.histogram(counts, bins=n_channels, range=(0, n_channels)) - histograms.append(fg_histogram) - histograms = np.array(histograms) - - ss.spectra = pd.DataFrame(data=histograms) - - ss.info.total_counts = ss.spectra.sum(axis=1) - ss.info.live_time = live_time - ss.info.real_time = live_time - ss.info.snr = None - ss.info.ecal_order_0 = 0 - ss.info.ecal_order_1 = 3000 - ss.info.ecal_order_2 = 100 - ss.info.ecal_order_3 = 0 - ss.info.ecal_low_e = 0 - ss.info.description = "" - ss.update_timestamp() - - if normalize: - ss.normalize() - - return ss +__all__ = ["get_dummy_seeds", "Synthesizer", "get_distribution_values"] diff --git a/riid/data/synthetic/base.py b/riid/data/synthetic/base.py new file mode 100644 index 00000000..9209c7f3 --- /dev/null +++ b/riid/data/synthetic/base.py @@ -0,0 +1,290 @@ +# Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +# Under the terms of Contract DE-NA0003525 with NTESS, +# the U.S. Government retains certain rights in this software. 
+"""This module contains utilities for synthesizing gamma spectra.""" +from collections import Counter +from typing import Any + +import numpy as np +import pandas as pd +from numpy.random import Generator + +from riid.data import get_expected_spectra +from riid.data.sampleset import (SampleSet, SpectraState, SpectraType, + _get_utc_timestamp) + + +class Synthesizer(): + """Base class for synthesizers.""" + + SYNTHETIC_STR = "synthetic" + SUPPORTED_SAMPLING_FUNCTIONS = ["uniform", "log10", "discrete", "list"] + + def __init__(self, bg_cps: float = 300.0, long_bg_live_time: float = 120.0, + apply_poisson_noise: bool = True, + normalize_sources: bool = True, + return_fg: bool = True, + return_gross: bool = False, + rng: Generator = np.random.default_rng()): + """ + Args: + bg_cps: constant rate of gammas from background + long_bg_live_time: live time on which to base background subtractions + apply_poisson_noise: whether to apply Poisson noise to spectra + normalize_sources: whether to normalize ground truth proportions to sum to 1 + return_fg: whether to compute and return background subtracted spectra + return_gross: whether to return gross spectra (always computed) + rng: NumPy random number generator, useful for experiment repeatability + """ + self.bg_cps = bg_cps + self.long_bg_live_time = long_bg_live_time + self.apply_poisson_noise = apply_poisson_noise + self.normalize_sources = normalize_sources + self.return_fg = return_fg + self.return_gross = return_gross + self._rng = rng + self._synthesis_start_dt = None + self._n_samples_synthesized = 0 + + def __str__(self): + output = "SynthesizerConfig" + for k, v in sorted(vars(self).items()): + output += " {}: {}".format(k, str(v)) + return output + + def _reset_progress(self): + self._n_samples_synthesized = 0 + self._synthesis_start_dt = _get_utc_timestamp() + + def _report_progress(self, n_samples_expected, batch_name): + percent_complete = 100 * self._n_samples_synthesized / n_samples_expected + msg = ( + f"Synthesizing ... {percent_complete:.0f}% " + f"(currently on {batch_name}" + ) + MAX_MSG_LEN = 80 + msg = (msg[:MAX_MSG_LEN] + "...") if len(msg) > MAX_MSG_LEN else msg + msg += ")" + print("\033[K" + msg, end="\r") + + def _report_completion(self, delay): + summary = ( + f"Synthesis complete!\n" + f"Generated {self._n_samples_synthesized} samples in ~{delay:.2f}s " + f"(~{(self._n_samples_synthesized / delay):.2f} samples/sec)." + ) + print("\033[K" + summary) + + def _verify_n_samples_synthesized(self, actual: int, expected: int): + assert expected == actual, ( + f"{actual} generated, but {expected} were expected. 
" + "Be sure to remove any columns from your seeds' sources DataFrame that " + "contain all zeroes.") + + def _get_batch(self, fg_seed, fg_sources, bg_seed, bg_sources, ecal, + lt_targets, snr_targets, rt_targets=None, distance_cm=None): + if not (self.return_fg or self.return_gross): + raise ValueError("Computing to return nothing.") + + bg_counts_expected = lt_targets * self.bg_cps + fg_counts_expected = snr_targets * np.sqrt(bg_counts_expected) + + fg_spectra = get_expected_spectra(fg_seed.values, fg_counts_expected) + bg_spectra = get_expected_spectra(bg_seed.values, bg_counts_expected) + + long_bg_counts_expected = self.long_bg_live_time * self.bg_cps + long_bg_spectrum_expected = bg_seed.values * long_bg_counts_expected + + gross_spectra = None + long_bg_spectra = None + fg_counts = 0 + bg_counts = 0 + long_bg_counts = 0 + fg_ss = None + gross_ss = None + + # Spectra + if self.apply_poisson_noise: + gross_spectra = self._rng.poisson(fg_spectra + bg_spectra) + if self.return_fg: + long_bg_spectrum = self._rng.poisson(long_bg_spectrum_expected) + long_bg_seed = long_bg_spectrum / long_bg_spectrum.sum() + long_bg_spectra = get_expected_spectra(long_bg_seed, bg_counts_expected) + fg_spectra = gross_spectra - long_bg_spectra + else: + gross_spectra = fg_spectra + bg_spectra + if self.return_fg: + long_bg_spectra = bg_spectra + fg_spectra = gross_spectra - long_bg_spectra + + # Counts + fg_counts = fg_spectra.sum(axis=1, dtype=float) + if self.return_fg: + long_bg_counts = long_bg_spectra.sum(axis=1, dtype=float) + if self.return_gross: + bg_counts = bg_spectra.sum(axis=1, dtype=float) + + # Sample sets + if self.return_fg: + snrs = fg_counts / np.sqrt(long_bg_counts.clip(1)) + fg_ss = get_fg_sample_set(fg_spectra, fg_sources, ecal, lt_targets, + snrs=snrs, total_counts=fg_counts, + real_times=rt_targets, distance_cm=distance_cm, + timestamps=self._synthesis_start_dt) + self._n_samples_synthesized += fg_ss.n_samples + if self.return_gross: + tiled_fg_sources = _tile_sources_and_scale( + fg_sources, + gross_spectra.shape[0], + fg_counts, + ) + tiled_bg_sources = _tile_sources_and_scale( + bg_sources, + gross_spectra.shape[0], + bg_counts, + ) + gross_sources = get_merged_sources_samplewise(tiled_fg_sources, tiled_bg_sources) + gross_counts = gross_spectra.sum(axis=1) + snrs = fg_counts / np.sqrt(bg_counts.clip(1)) + gross_ss = get_gross_sample_set(gross_spectra, gross_sources, ecal, + lt_targets, snrs, gross_counts, + real_times=rt_targets, distance_cm=distance_cm, + timestamps=self._synthesis_start_dt) + self._n_samples_synthesized += gross_ss.n_samples + + return fg_ss, gross_ss + + +def get_sample_set(spectra, sources, ecal, live_times, snrs, total_counts=None, + real_times=None, distance_cm=None, timestamps=None, + descriptions=None) -> SampleSet: + n_samples = spectra.shape[0] + + ss = SampleSet() + ss.spectra_state = SpectraState.Counts + ss.spectra = pd.DataFrame(spectra) + ss.sources = sources + ss.info.description = np.full(n_samples, "") # Ensures the length of info equal n_samples + if descriptions: + ss.info.description = descriptions + ss.info.snr = snrs + ss.info.timestamp = timestamps + ss.info.total_counts = total_counts if total_counts is not None else spectra.sum(axis=1) + ss.info.ecal_order_0 = ecal[0] + ss.info.ecal_order_1 = ecal[1] + ss.info.ecal_order_2 = ecal[2] + ss.info.ecal_order_3 = ecal[3] + ss.info.ecal_low_e = ecal[4] + ss.info.live_time = live_times + ss.info.real_time = real_times if real_times is not None else live_times + ss.info.distance_cm = distance_cm 
+ ss.info.occupancy_flag = 0 + ss.info.tag = " " # TODO: test if this can be empty string + + return ss + + +def _tile_sources_and_scale(sources, n_samples, scalars) -> pd.DataFrame: + tiled_sources = pd.DataFrame( + np.tile(sources.values, (n_samples, 1)), + columns=sources.index + ) + # Multiplying normalized source values by spectrum counts. + # This is REQUIRED for properly merging sources DataFrames later when synthesizing + # multiple isotopes. + tiled_sources = tiled_sources.multiply(scalars, axis="index") + return tiled_sources + + +def get_fg_sample_set(spectra, sources, ecal, live_times, snrs, total_counts, + real_times=None, distance_cm=None, timestamps=None, + descriptions=None) -> SampleSet: + tiled_sources = _tile_sources_and_scale( + sources, + spectra.shape[0], + spectra.sum(axis=1) + ) + ss = get_sample_set( + spectra=spectra, + sources=tiled_sources, + ecal=ecal, + live_times=live_times, + snrs=snrs, + total_counts=total_counts, + real_times=real_times, + distance_cm=distance_cm, + timestamps=timestamps, + descriptions=descriptions + ) + ss.spectra_type = SpectraType.Foreground + return ss + + +def get_gross_sample_set(spectra, sources, ecal, live_times, snrs, total_counts, + real_times=None, distance_cm=None, timestamps=None, + descriptions=None) -> SampleSet: + ss = get_sample_set( + spectra=spectra, + sources=sources, + ecal=ecal, + live_times=live_times, + snrs=snrs, + total_counts=total_counts, + real_times=real_times, + distance_cm=distance_cm, + timestamps=timestamps, + descriptions=descriptions + ) + ss.spectra_type = SpectraType.Gross + return ss + + +def get_distribution_values(function: str, function_args: Any, n_values: int, + rng: Generator = np.random.default_rng()): + """Randomly sample a list of values based one of many distributions. + + Args: + function: name of the distribution function + function_args: argument or collection of arguments to be + passed to the function, if any. 
+ n_values: size of the distribution + rng: NumPy random number generator, useful for experiment repeatability + + Returns: + Value or collection of sampled values + + Raises: + `ValueError` when an unsupported function type is provided + """ + values = None + if function == "uniform": + values = rng.uniform(*function_args, size=n_values) + elif function == "log10": + log10_args = tuple(map(np.log10, function_args)) + values = np.power(10, rng.uniform(*log10_args, size=n_values)) + elif function == "discrete": + values = rng.choice(function_args, size=n_values) + elif function == "list": + values = np.array(function_args) + else: + raise ValueError(f"{function} function not supported for sampling.") + + return values + + +def get_merged_sources_samplewise(sources1: pd.DataFrame, sources2: pd.DataFrame) -> pd.DataFrame: + merged_sources_df = sources1.add(sources2, axis=1, fill_value=0) + return merged_sources_df + + +def get_samples_per_seed(columns: pd.MultiIndex, min_samples_per_seed: int, balance_level: int): + level_values = columns.get_level_values(level=balance_level) + level_value_to_n_seeds = Counter(level_values) + unique_level_values = list(level_value_to_n_seeds.keys()) + occurences = np.array(list(level_value_to_n_seeds.values())) + max_samples_per_level_value = occurences.max() * min_samples_per_seed + samples_per_level_value = np.ceil(max_samples_per_level_value / occurences).astype(int) + lv_to_samples_per_seed = {k: v for (k, v) in zip(unique_level_values, samples_per_level_value)} + total_samples_expected = sum([x * y for x, y in zip(occurences, samples_per_level_value)]) + + return lv_to_samples_per_seed, total_samples_expected diff --git a/riid/data/synthetic/passby.py b/riid/data/synthetic/passby.py index 12b1a68f..84234618 100644 --- a/riid/data/synthetic/passby.py +++ b/riid/data/synthetic/passby.py @@ -11,8 +11,8 @@ import pandas as pd from numpy.random import Generator -from riid.data.sampleset import SampleSet -from riid.data.synthetic import Synthesizer, get_distribution_values +from riid import SampleSet +from riid.data.synthetic.base import Synthesizer, get_distribution_values class PassbySynthesizer(Synthesizer): diff --git a/riid/data/synthetic/seed.py b/riid/data/synthetic/seed.py index 9fd1b2c4..02f70412 100644 --- a/riid/data/synthetic/seed.py +++ b/riid/data/synthetic/seed.py @@ -1,7 +1,7 @@ # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). # Under the terms of Contract DE-NA0003525 with NTESS, # the U.S. Government retains certain rights in this software. 
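For reference, a minimal sketch of how the `get_distribution_values` helper defined earlier in this patch can be exercised; the import path mirrors the `riid.data.synthetic.base` module used by the synthesizers below, and the distributions and bounds are purely illustrative:

import numpy as np

from riid.data.synthetic.base import get_distribution_values

rng = np.random.default_rng(42)
# Five live-time targets drawn uniformly between 0.25 and 8 seconds
live_time_targets = get_distribution_values("uniform", (0.25, 8.0), 5, rng=rng)
# Five SNR targets drawn log-uniformly across two decades
snr_targets = get_distribution_values("log10", (0.1, 10.0), 5, rng=rng)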
-"""This modules contains utilities for generating synthetic gamma spectrum templates from GADRAS.""" +"""This module contains utilities for generating synthetic gamma spectrum templates from GADRAS.""" import os from contextlib import contextmanager from copy import deepcopy @@ -12,7 +12,8 @@ import yaml from numpy.random import Generator -from riid.data.sampleset import SampleSet, _get_utc_timestamp, read_pcf +from riid import SampleSet, SpectraState, SpectraType, read_pcf +from riid.data.sampleset import _get_utc_timestamp from riid.gadras.api import (DETECTOR_PARAMS, GADRAS_ASSEMBLY_PATH, INJECT_PARAMS, SourceInjector, get_gadras_api, get_inject_setups, validate_inject_config) @@ -405,3 +406,86 @@ def get_choices(choices_so_far: list, options: list, options_probas: np.array, n_choices_remaining -= 1 return get_choices(choices_so_far, options, options_probas, restricted_pairs, n_choices_remaining, rng) + + +def get_dummy_seeds(n_channels: int = 512, live_time: float = 600.0, + count_rate: float = 1000.0, normalize: bool = True, + rng: Generator = np.random.default_rng()) -> SampleSet: + """Get a random, dummy `SampleSet` of ideal seeds. + + WARNING: the spectra returned by this function each contain one gaussian peak that does + not overlap with the peaks of other spectra. Such data is about as *ideal* as one + could hope to be working with and does not represent anything real. + Therefore, **do not** use this data for any purpose other than testing, debugging, or + examples where code, not results, is being demonstrated. Any use in scientific studies + does not make sense. + + Args: + n_channels: number of channels in the spectra DataFrame + live_time: collection time on which to base seeds + (higher creates a less noisy shape) + count_rate: count rate on which to base seeds + (higher creates a less noisy shape) + normalize: whether to apply an L1-norm to the spectra + rng: NumPy random number generator, useful for experiment repeatability + + Returns: + `SampleSet` with randomly generated spectra + """ + ss = SampleSet() + ss.measured_or_synthetic = "synthetic" + ss.spectra_state = SpectraState.Counts + ss.spectra_type = SpectraType.BackgroundForeground + ss.synthesis_info = { + "subtract_background": True, + } + sources = [ + ("Industrial", "Am241", "Unshielded Am241"), + ("Industrial", "Ba133", "Unshielded Ba133"), + ("NORM", "K40", "PotassiumInSoil"), + ("NORM", "K40", "Moderately Shielded K40"), + ("NORM", "Ra226", "UraniumInSoil"), + ("NORM", "Th232", "ThoriumInSoil"), + ("SNM", "U238", "Unshielded U238"), + ("SNM", "Pu239", "Unshielded Pu239"), + ("SNM", "Pu239", "Moderately Shielded Pu239"), + ("SNM", "Pu239", "Heavily Shielded Pu239"), + ] + n_sources = len(sources) + n_fg_sources = n_sources + sources_cols = pd.MultiIndex.from_tuples( + sources, + names=SampleSet.SOURCES_MULTI_INDEX_NAMES + ) + sources_data = np.identity(n_sources) + ss.sources = pd.DataFrame(data=sources_data, columns=sources_cols) + + histograms = [] + N_FG_COUNTS = int(count_rate * live_time) + fg_std = np.sqrt(n_channels / n_sources) + channels_per_sources = n_channels / n_fg_sources + for i in range(n_fg_sources): + mu = i * channels_per_sources + channels_per_sources / 2 + counts = rng.normal(mu, fg_std, size=N_FG_COUNTS) + fg_histogram, _ = np.histogram(counts, bins=n_channels, range=(0, n_channels)) + histograms.append(fg_histogram) + histograms = np.array(histograms) + + ss.spectra = pd.DataFrame(data=histograms) + + ss.info.total_counts = ss.spectra.sum(axis=1) + ss.info.live_time = live_time + 
ss.info.real_time = live_time + ss.info.snr = None + ss.info.ecal_order_0 = 0 + ss.info.ecal_order_1 = 3000 + ss.info.ecal_order_2 = 100 + ss.info.ecal_order_3 = 0 + ss.info.ecal_low_e = 0 + ss.info.description = "" + ss.update_timestamp() + + if normalize: + ss.normalize() + + return ss diff --git a/riid/data/synthetic/static.py b/riid/data/synthetic/static.py index f0fd82dd..cdf0aed0 100644 --- a/riid/data/synthetic/static.py +++ b/riid/data/synthetic/static.py @@ -10,8 +10,8 @@ import numpy as np from numpy.random import Generator -from riid.data.sampleset import SampleSet, SpectraState, SpectraType -from riid.data.synthetic import Synthesizer, get_distribution_values +from riid import SampleSet, SpectraState, SpectraType +from riid.data.synthetic.base import Synthesizer, get_distribution_values class StaticSynthesizer(Synthesizer): diff --git a/riid/gadras/api.py b/riid/gadras/api.py index 187dc95a..c7af0e94 100644 --- a/riid/gadras/api.py +++ b/riid/gadras/api.py @@ -2,20 +2,20 @@ # Under the terms of Contract DE-NA0003525 with NTESS, # the U.S. Government retains certain rights in this software. """This module contains utilities for working with the GADRAS API.""" +import copy +import itertools import json import os import sys -import numpy as np -from numpy.random import Generator from typing import List -import copy -import itertools +import numpy as np import tqdm from jsonschema import validate +from numpy.random import Generator -from riid.data.sampleset import SampleSet, read_pcf -from riid.data.synthetic import get_distribution_values +from riid import SampleSet, read_pcf +from riid.data.synthetic.base import get_distribution_values GADRAS_API_SEEMINGLY_AVAILABLE = False GADRAS_DIR_ENV_VAR_KEY = "GADRAS_DIR" diff --git a/riid/metrics.py b/riid/metrics.py index aa677448..81b5a9c9 100644 --- a/riid/metrics.py +++ b/riid/metrics.py @@ -5,7 +5,7 @@ import numpy as np import sklearn -from riid.data.sampleset import SampleSet +from riid import SampleSet def multi_f1(y_true: np.ndarray, y_pred: np.ndarray) -> float: diff --git a/riid/models/__init__.py b/riid/models/__init__.py index 0825b251..3a73b7a0 100644 --- a/riid/models/__init__.py +++ b/riid/models/__init__.py @@ -1,243 +1,10 @@ # Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). # Under the terms of Contract DE-NA0003525 with NTESS, # the U.S. Government retains certain rights in this software. 
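As a quick illustration of the `get_dummy_seeds` helper added above, a minimal sketch; the channel count and live time are arbitrary, and the import path matches the one used by the package's examples:

from riid.data.synthetic import get_dummy_seeds

# Build a small set of idealized seed spectra for testing/demo purposes only
seeds_ss = get_dummy_seeds(n_channels=256, live_time=120.0)
fg_seeds_ss, bg_seeds_ss = seeds_ss.split_fg_and_bg()
print(fg_seeds_ss, bg_seeds_ss)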
-"""This module contains functionality shared across all PyRIID models.""" -import json -import os -import uuid -from abc import abstractmethod -from enum import Enum +"""This module contains PyRIID models.""" +from riid.models.bayes import PoissonBayesClassifier +from riid.models.neural_nets import LabelProportionEstimator, MLPClassifier +from riid.models.neural_nets.arad import ARADLatentPredictor, ARADv1, ARADv2 -import numpy as np -import tensorflow as tf -import tf2onnx -from keras.api.models import Model -from keras.api.utils import get_custom_objects - -import riid -from riid.data.labeling import label_to_index_element -from riid.data.sampleset import SampleSet, SpectraState -from riid.losses import mish -from riid.metrics import multi_f1, single_f1 - -get_custom_objects().update({ - "multi_f1": multi_f1, - "single_f1": single_f1, - "mish": mish, -}) - - -class ModelInput(int, Enum): - """Enumerates the potential input sources for a model.""" - GrossSpectrum = 0 - BackgroundSpectrum = 1 - ForegroundSpectrum = 2 - - -class PyRIIDModel: - """Base class for PyRIID models.""" - - def __init__(self, *args, **kwargs): - self._info = {} - self._temp_file_path = "temp_model.json" - self._custom_objects = {} - self._initialize_info() - - @property - def seeds(self): - return self._info["seeds"] - - @seeds.setter - def seeds(self, value): - self._info["seeds"] = value - - @property - def info(self): - return self._info - - @info.setter - def info(self, value): - self._info = value - - @property - def target_level(self): - return self._info["target_level"] - - @target_level.setter - def target_level(self, value): - if value in SampleSet.SOURCES_MULTI_INDEX_NAMES: - self._info["target_level"] = value - else: - msg = ( - f"Target level '{value}' is invalid. " - f"Acceptable levels: {SampleSet.SOURCES_MULTI_INDEX_NAMES}" - ) - raise ValueError(msg) - - @property - def model(self) -> Model: - return self._model - - @model.setter - def model(self, value: Model): - self._model = value - - @property - def model_id(self): - return self._info["model_id"] - - @model_id.setter - def model_id(self, value): - self._info["model_id"] = value - - @property - def model_inputs(self): - return self._info["model_inputs"] - - @model_inputs.setter - def model_inputs(self, value): - self._info["model_inputs"] = value - - @property - def model_outputs(self): - return self._info["model_outputs"] - - @model_outputs.setter - def model_outputs(self, value): - self._info["model_outputs"] = value - - def get_model_outputs_as_label_tuples(self): - return [ - label_to_index_element(v, self.target_level) for v in self.model_outputs - ] - - def _get_model_dict(self) -> dict: - model_json = self.model.to_json() - model_dict = json.loads(model_json) - model_weights = self.model.get_weights() - model_dict = { - "info": self._info, - "model": model_dict, - "weights": model_weights, - } - return model_dict - - def _get_model_str(self) -> str: - model_dict = self._get_model_dict() - model_str = json.dumps(model_dict, indent=4, cls=PyRIIDModelJsonEncoder) - return model_str - - def _initialize_info(self): - init_info = { - "model_id": str(uuid.uuid4()), - "model_type": self.__class__.__name__, - "normalization": SpectraState.Unknown, - "pyriid_version": riid.__version__, - } - self._update_info(**init_info) - - def _update_info(self, **kwargs): - self._info.update(kwargs) - - def _update_custom_objects(self, key, value): - self._custom_objects.update({key: value}) - - def load(self, model_path: str): - """Load the model from a path. 
- - Args: - model_path: path from which to load the model. - """ - if not os.path.exists(model_path): - raise ValueError("Model file does not exist.") - - with open(model_path) as fin: - model = json.load(fin) - - model_str = json.dumps(model["model"]) - self.model = tf.keras.models.model_from_json(model_str, custom_objects=self._custom_objects) - self.model.set_weights([np.array(x) for x in model["weights"]]) - self.info = model["info"] - - def save(self, model_path: str, overwrite=False): - """Save the model to a path. - - Args: - model_path: path at which to save the model. - overwrite: whether to overwrite an existing file if it already exists. - - Raises: - `ValueError` when the given path already exists - """ - if os.path.exists(model_path) and not overwrite: - raise ValueError("Model file already exists.") - - model_str = self._get_model_str() - with open(model_path, "w") as fout: - fout.write(model_str) - - def to_onnx(self, model_path: str = None, **tf2onnx_kwargs: dict): - """Convert the model to an ONNX model. - - Args: - model_path: path at which to save the model - tf2onnx_kwargs: additional kwargs to pass to the conversion - """ - if not model_path.endswith(riid.ONNX_MODEL_FILE_EXTENSION): - raise ValueError(f"ONNX file path must end with {riid.ONNX_MODEL_FILE_EXTENSION}") - if os.path.exists(model_path): - raise ValueError("Model file already exists.") - - tf2onnx.convert.from_keras( - self.model, - output_path=model_path, - **tf2onnx_kwargs - ) - - def to_tflite(self, model_path: str, quantize: bool = False, prune: bool = False): - """Convert the model to a TFLite model and optionally applying quantization or pruning. - - Args: - model_path: file path at which to save the model - quantize: whether to apply quantization - prune: whether to apply pruning - """ - if not model_path.endswith(riid.TFLITE_MODEL_FILE_EXTENSION): - raise ValueError(f"TFLite file path must end with {riid.TFLITE_MODEL_FILE_EXTENSION}") - if os.path.exists(model_path): - raise ValueError("Model file already exists.") - - optimizations = [] - if quantize: - optimizations.append(tf.lite.Optimize.DEFAULT) - if prune: - optimizations.append(tf.lite.Optimize.EXPERIMENTAL_SPARSITY) - - converter = tf.lite.TFLiteConverter.from_keras_model(self.model) - converter.optimizations = optimizations - tflite_model = converter.convert() - - with open(model_path, "wb") as fout: - fout.write(tflite_model) - - @abstractmethod - def fit(self): - pass - - @abstractmethod - def predict(self): - pass - - -class PyRIIDModelJsonEncoder(json.JSONEncoder): - """Custom JSON encoder for saving models. - """ - def default(self, o): - """Converts certain types to JSON-compatible types. - """ - if isinstance(o, np.ndarray): - return o.tolist() - elif isinstance(o, np.float32): - return o.astype(float) - - return super().default(o) +__all__ = ["PoissonBayesClassifier", "LabelProportionEstimator", "MLPClassifier", + "ARADLatentPredictor", "ARADv1", "ARADv2"] diff --git a/riid/models/base.py b/riid/models/base.py new file mode 100644 index 00000000..8f4e0329 --- /dev/null +++ b/riid/models/base.py @@ -0,0 +1,256 @@ +# Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +# Under the terms of Contract DE-NA0003525 with NTESS, +# the U.S. Government retains certain rights in this software. 
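Because the shared base functionality now lives in `riid.models.base` while `riid.models.__init__` only re-exports the concrete models, downstream imports change accordingly; a minimal sketch of the new import surface (names taken from the `__all__` and file paths introduced in this patch):

from riid.models import ARADv2, LabelProportionEstimator, MLPClassifier, PoissonBayesClassifier
from riid.models.base import ModelInput, PyRIIDModel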
+"""This module contains functionality shared across all PyRIID models.""" +import json +import os +from pathlib import Path +import uuid +from abc import abstractmethod +from enum import Enum + +import numpy as np +import tensorflow as tf +import tf2onnx +from keras.api.models import Model +from keras.api.utils import get_custom_objects + +import riid +from riid import SampleSet, SpectraState +from riid.data.labeling import label_to_index_element +from riid.losses import mish +from riid.metrics import multi_f1, single_f1 + +get_custom_objects().update({ + "multi_f1": multi_f1, + "single_f1": single_f1, + "mish": mish, +}) + + +class ModelInput(int, Enum): + """Enumerates the potential input sources for a model.""" + GrossSpectrum = 0 + BackgroundSpectrum = 1 + ForegroundSpectrum = 2 + + +class PyRIIDModel: + """Base class for PyRIID models.""" + + def __init__(self, *args, **kwargs): + self._info = {} + self._temp_file_path = "temp_model.json" + self._custom_objects = {} + self._initialize_info() + + @property + def seeds(self): + return self._info["seeds"] + + @seeds.setter + def seeds(self, value): + self._info["seeds"] = value + + @property + def info(self): + return self._info + + @info.setter + def info(self, value): + self._info = value + + @property + def target_level(self): + return self._info["target_level"] + + @target_level.setter + def target_level(self, value): + if value in SampleSet.SOURCES_MULTI_INDEX_NAMES: + self._info["target_level"] = value + else: + msg = ( + f"Target level '{value}' is invalid. " + f"Acceptable levels: {SampleSet.SOURCES_MULTI_INDEX_NAMES}" + ) + raise ValueError(msg) + + @property + def model(self) -> Model: + return self._model + + @model.setter + def model(self, value: Model): + self._model = value + + @property + def model_id(self): + return self._info["model_id"] + + @model_id.setter + def model_id(self, value): + self._info["model_id"] = value + + @property + def model_inputs(self): + return self._info["model_inputs"] + + @model_inputs.setter + def model_inputs(self, value): + self._info["model_inputs"] = value + + @property + def model_outputs(self): + return self._info["model_outputs"] + + @model_outputs.setter + def model_outputs(self, value): + self._info["model_outputs"] = value + + def get_model_outputs_as_label_tuples(self): + return [ + label_to_index_element(v, self.target_level) for v in self.model_outputs + ] + + def _get_model_dict(self) -> dict: + model_json = self.model.to_json() + model_dict = json.loads(model_json) + model_weights = self.model.get_weights() + model_dict = { + "info": self._info, + "model": model_dict, + "weights": model_weights, + } + return model_dict + + def _get_model_str(self) -> str: + model_dict = self._get_model_dict() + model_str = json.dumps(model_dict, indent=4, cls=PyRIIDModelJsonEncoder) + return model_str + + def _initialize_info(self): + init_info = { + "model_id": str(uuid.uuid4()), + "model_type": self.__class__.__name__, + "normalization": SpectraState.Unknown, + "pyriid_version": riid.__version__, + } + self._update_info(**init_info) + + def _update_info(self, **kwargs): + self._info.update(kwargs) + + def _update_custom_objects(self, key, value): + self._custom_objects.update({key: value}) + + def load(self, model_path: str): + """Load the model from a path. + + Args: + model_path: path from which to load the model. 
+ """ + if not os.path.exists(model_path): + raise ValueError("Model file does not exist.") + + with open(model_path) as fin: + model = json.load(fin) + + model_str = json.dumps(model["model"]) + self.model = tf.keras.models.model_from_json(model_str, custom_objects=self._custom_objects) + self.model.set_weights([np.array(x) for x in model["weights"]]) + self.info = model["info"] + + def save(self, model_path: str, overwrite=False): + """Save the model to a path. + + Args: + model_path: path at which to save the model. + overwrite: whether to overwrite an existing file if it already exists. + + Raises: + `ValueError` when the given path already exists + """ + if os.path.exists(model_path) and not overwrite: + raise ValueError("Model file already exists.") + + model_str = self._get_model_str() + with open(model_path, "w") as fout: + fout.write(model_str) + + def to_onnx(self, model_path, **tf2onnx_kwargs: dict): + """Convert the model to an ONNX model. + + Args: + model_path: path at which to save the model + tf2onnx_kwargs: additional kwargs to pass to the conversion + """ + model_path = Path(model_path) + if not str(model_path).endswith(riid.ONNX_MODEL_FILE_EXTENSION): + raise ValueError(f"ONNX file path must end with {riid.ONNX_MODEL_FILE_EXTENSION}") + if model_path.exists(): + raise ValueError("Model file already exists.") + + tf2onnx.convert.from_keras( + self.model, + input_signature=[ + tf.TensorSpec( + shape=input_tensor.shape, + dtype=input_tensor.dtype, + name=input_tensor.name + ) + for input_tensor in self.model.inputs + ], + output_path=str(model_path), + **tf2onnx_kwargs + ) + + def to_tflite(self, model_path, quantize: bool = False, prune: bool = False): + """Convert the model to a TFLite model and optionally applying quantization or pruning. + + Args: + model_path: file path at which to save the model + quantize: whether to apply quantization + prune: whether to apply pruning + """ + model_path = Path(model_path) + if not str(model_path).endswith(riid.TFLITE_MODEL_FILE_EXTENSION): + raise ValueError(f"TFLite file path must end with {riid.TFLITE_MODEL_FILE_EXTENSION}") + if model_path.exists(): + raise ValueError("Model file already exists.") + + optimizations = [] + if quantize: + optimizations.append(tf.lite.Optimize.DEFAULT) + if prune: + optimizations.append(tf.lite.Optimize.EXPERIMENTAL_SPARSITY) + + saved_model_dir = model_path.stem + self.model.export(saved_model_dir) + converter = tf.lite.TFLiteConverter.from_saved_model(str(saved_model_dir)) + converter.optimizations = optimizations + tflite_model = converter.convert() + + with open(model_path, "wb") as fout: + fout.write(tflite_model) + + @abstractmethod + def fit(self): + pass + + @abstractmethod + def predict(self): + pass + + +class PyRIIDModelJsonEncoder(json.JSONEncoder): + """Custom JSON encoder for saving models. + """ + def default(self, o): + """Converts certain types to JSON-compatible types. 
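+
+        NumPy arrays become lists and NumPy `float32` values are upcast so they
+        serialize as plain floats; all other types fall back to the default encoder.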
+ """ + if isinstance(o, np.ndarray): + return o.tolist() + elif isinstance(o, np.float32): + return o.astype(float) + + return super().default(o) diff --git a/riid/models/bayes.py b/riid/models/bayes.py index 6850cbb1..762d38c8 100644 --- a/riid/models/bayes.py +++ b/riid/models/bayes.py @@ -8,8 +8,8 @@ from keras.api.layers import Add, Input, Multiply, Subtract from keras.api.models import Model -from riid.data.sampleset import SampleSet -from riid.models import PyRIIDModel +from riid import SampleSet +from riid.models.base import PyRIIDModel from riid.models.layers import (ClipByValueLayer, DivideLayer, ExpandDimsLayer, PoissonLogProbabilityLayer, ReduceMaxLayer, ReduceSumLayer, SeedLayer) diff --git a/riid/models/neural_nets/__init__.py b/riid/models/neural_nets/__init__.py index 42443f88..1dd06420 100644 --- a/riid/models/neural_nets/__init__.py +++ b/riid/models/neural_nets/__init__.py @@ -2,670 +2,7 @@ # Under the terms of Contract DE-NA0003525 with NTESS, # the U.S. Government retains certain rights in this software. """This module contains neural network-based classifiers and regressors.""" -import keras -import numpy as np -import pandas as pd -import tensorflow as tf -from keras.api.activations import sigmoid, softmax -from keras.api.callbacks import EarlyStopping -from keras.api.layers import Dense, Dropout, Input -from keras.api.losses import CategoricalCrossentropy, MeanSquaredError -from keras.api.metrics import F1Score, Precision, Recall -from keras.api.models import Model -from keras.api.optimizers import Adam -from keras.api.regularizers import L1L2, l1, l2 -from keras.api.utils import split_dataset -from scipy.interpolate import UnivariateSpline +from riid.models.neural_nets.basic import MLPClassifier +from riid.models.neural_nets.lpe import LabelProportionEstimator -from riid.data.sampleset import SampleSet, SpectraType -from riid.losses import (build_keras_semisupervised_loss_func, - chi_squared_diff, jensen_shannon_divergence, - normal_nll_diff, poisson_nll_diff, - reconstruction_error, sse_diff, weighted_sse_diff) -from riid.losses.sparsemax import SparsemaxLoss, sparsemax -from riid.metrics import build_keras_semisupervised_metric_func -from riid.models import ModelInput, PyRIIDModel -from riid.models.layers import L1NormLayer - - -class MLPClassifier(PyRIIDModel): - """Multi-layer perceptron classifier.""" - def __init__(self, activation=None, loss=None, optimizer=None, - metrics=None, l2_alpha: float = 1e-4, - activity_regularizer=None, final_activation=None): - """ - Args: - activation: activate function to use for each dense layer - loss: loss function to use for training - optimizer: tensorflow optimizer or optimizer name to use for training - metrics: list of metrics to be evaluating during training - l2_alpha: alpha value for the L2 regularization of each dense layer - activity_regularizer: regularizer function applied each dense layer output - final_activation: final activation function to apply to model output - """ - super().__init__() - - self.activation = activation - self.loss = loss - self.optimizer = optimizer - self.final_activation = final_activation - self.metrics = metrics - self.l2_alpha = l2_alpha - self.activity_regularizer = activity_regularizer - self.final_activation = final_activation - - if self.activation is None: - self.activation = "relu" - if self.loss is None: - self.loss = CategoricalCrossentropy() - if optimizer is None: - self.optimizer = Adam(learning_rate=0.01, clipnorm=0.001) - if self.metrics is None: - self.metrics = 
[F1Score(), Precision(), Recall()] - if self.activity_regularizer is None: - self.activity_regularizer = l1(0.0) - if self.final_activation is None: - self.final_activation = "softmax" - self.model = None - self._predict_fn = None - - def fit(self, ss: SampleSet, batch_size: int = 200, epochs: int = 20, - validation_split: float = 0.2, callbacks=None, - patience: int = 15, es_monitor: str = "val_loss", - es_mode: str = "min", es_verbose=0, target_level="Isotope", verbose: bool = False): - """Fit a model to the given `SampleSet`(s). - - Args: - ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are either - foreground (AKA, "net") or gross. - batch_size: number of samples per gradient update - epochs: maximum number of training iterations - validation_split: percentage of the training data to use as validation data - callbacks: list of callbacks to be passed to the TensorFlow `Model.fit()` method - patience: number of epochs to wait for `EarlyStopping` object - es_monitor: quantity to be monitored for `EarlyStopping` object - es_mode: mode for `EarlyStopping` object - es_verbose: verbosity level for `EarlyStopping` object - target_level: `SampleSet.sources` column level to use - verbose: whether to show detailed model training output - - Returns: - `tf.History` object. - - Raises: - `ValueError` when no spectra are provided as input - """ - if ss.n_samples <= 0: - raise ValueError("No spectr[a|um] provided!") - - if ss.spectra_type == SpectraType.Gross: - self.model_inputs = (ModelInput.GrossSpectrum,) - elif ss.spectra_type == SpectraType.Foreground: - self.model_inputs = (ModelInput.ForegroundSpectrum,) - elif ss.spectra_type == SpectraType.Background: - self.model_inputs = (ModelInput.BackgroundSpectrum,) - else: - raise ValueError(f"{ss.spectra_type} is not supported in this model.") - - X = ss.get_samples() - source_contributions_df = ss.sources.T.groupby(target_level, sort=False).sum().T - model_outputs = source_contributions_df.columns.values.tolist() - Y = source_contributions_df.values - - spectra_tensor = tf.convert_to_tensor(X, dtype=tf.float32) - labels_tensor = tf.convert_to_tensor(Y, dtype=tf.float32) - training_dataset = tf.data.Dataset.from_tensor_slices((spectra_tensor, labels_tensor)) - training_dataset, validation_dataset = split_dataset( - training_dataset, - left_size=validation_split, - shuffle=True - ) - training_dataset = training_dataset.batch(batch_size=batch_size) - validation_dataset = validation_dataset.batch(batch_size=batch_size) - - if not self.model: - inputs = Input(shape=(X.shape[1],), name="Spectrum") - dense_layer_size = X.shape[1] // 2 - dense_layer = Dense( - dense_layer_size, - activation=self.activation, - activity_regularizer=self.activity_regularizer, - kernel_regularizer=l2(self.l2_alpha), - )(inputs) - outputs = Dense(Y.shape[1], activation=self.final_activation)(dense_layer) - self.model = Model(inputs, outputs) - self.model.compile(loss=self.loss, optimizer=self.optimizer, - metrics=self.metrics) - - es = EarlyStopping( - monitor=es_monitor, - patience=patience, - verbose=es_verbose, - restore_best_weights=True, - mode=es_mode, - ) - if callbacks: - callbacks.append(es) - else: - callbacks = [es] - - history = self.model.fit( - training_dataset, - epochs=epochs, - verbose=verbose, - validation_data=validation_dataset, - callbacks=callbacks, - ) - - # Update model information - self._update_info( - target_level=target_level, - model_outputs=model_outputs, - normalization=ss.spectra_state, - ) - - # Define the predict function with 
tf.function and input_signature - self._predict_fn = tf.function( - self._predict, - # input_signature=[tf.TensorSpec(shape=[None, X.shape[1]], dtype=tf.float32)] - experimental_relax_shapes=True - ) - - return history - - def _predict(self, input_tensor): - return self.model(input_tensor, training=False) - - def predict(self, ss: SampleSet, bg_ss: SampleSet = None): - """Classify the spectra in the provided `SampleSet`(s). - - Results are stored inside the first SampleSet's prediction-related properties. - - Args: - ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are either - foreground (AKA, "net") or gross - bg_ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are background - """ - x_test = ss.get_samples().astype(float) - if bg_ss: - X = [x_test, bg_ss.get_samples().astype(float)] - else: - X = x_test - - spectra_tensor = tf.convert_to_tensor(X, dtype=tf.float32) - results = self._predict_fn(spectra_tensor) - - col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) - col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] - ss.prediction_probas = pd.DataFrame( - data=results, - columns=pd.MultiIndex.from_tuples( - self.get_model_outputs_as_label_tuples(), - names=col_level_subset - ) - ) - - ss.classified_by = self.model_id - - -class LabelProportionEstimator(PyRIIDModel): - """Regressor for predicting label proportions that uses a semi-supervised loss. - - Optionally, a U-spline-based out-of-distribution detection model can be fit to target a desired - false positive rate. - """ - UNSUPERVISED_LOSS_FUNCS = { - "poisson_nll": poisson_nll_diff, - "normal_nll": normal_nll_diff, - "sse": sse_diff, - "weighted_sse": weighted_sse_diff, - "jsd": jensen_shannon_divergence, - "chi_squared": chi_squared_diff - } - SUPERVISED_LOSS_FUNCS = { - "sparsemax": ( - SparsemaxLoss, - { - "from_logits": True, - "reduction": tf.keras.losses.Reduction.NONE, - }, - sparsemax, - ), - "categorical_crossentropy": ( - CategoricalCrossentropy, - { - "from_logits": True, - "reduction": tf.keras.losses.Reduction.NONE, - }, - softmax, - ), - "mse": ( - MeanSquaredError, - { - "reduction": tf.keras.losses.Reduction.NONE, - }, - sigmoid, - ) - } - INFO_KEYS = ( - # model architecture - "hidden_layers", - "learning_rate", - "epsilon", - "sup_loss", - "unsup_loss", - "metrics", - "beta", - "hidden_layer_activation", - "kernel_l1_regularization", - "kernel_l2_regularization", - "bias_l1_regularization", - "bias_l2_regularization", - "activity_l1_regularization", - "activity_l2_regularization", - "dropout", - "ood_fp_rate", - "fit_spline", - "spline_bins", - "spline_k", - "spline_s", - # dictionaries - "source_dict", - # populated when loading model - "spline_snrs", - "spline_recon_errors", - ) - - def __init__(self, hidden_layers: tuple = (256,), sup_loss="sparsemax", unsup_loss="sse", - metrics: list = ["mae", "categorical_crossentropy"], beta=0.9, source_dict=None, - optimizer="adam", optimizer_kwargs=None, learning_rate: float = 1e-3, - hidden_layer_activation: str = "mish", - kernel_l1_regularization: float = 0.0, kernel_l2_regularization: float = 0.0, - bias_l1_regularization: float = 0.0, bias_l2_regularization: float = 0.0, - activity_l1_regularization: float = 0.0, activity_l2_regularization: float = 0.0, - dropout: float = 0.0, ood_fp_rate: float = 0.05, - fit_spline: bool = True, spline_bins: int = 15, spline_k: int = 3, - spline_s: int = 0, spline_snrs=None, spline_recon_errors=None): - """ - Args: - hidden_layers: tuple defining the number and 
size of dense layers - sup_loss: supervised loss function to use for training - unsup_loss: unsupervised loss function to use for training the - foreground branch of the network (options: "sse", "poisson_nll", - "normal_nll", "weighted_sse", "jsd", or "chi_squared") - metrics: list of metrics to be evaluating during training - beta: tradeoff parameter between the supervised and unsupervised foreground loss - source_dict: 2D array of pure, long-collect foreground spectra - optimizer: tensorflow optimizer or optimizer name to use for training - optimizer_kwargs: kwargs for optimizer - learning_rate: learning rate for the optimizer - hidden_layer_activation: activation function to use for each dense layer - kernel_l1_regularization: l1 regularization value for the kernel regularizer - kernel_l2_regularization: l2 regularization value for the kernel regularizer - bias_l1_regularization: l1 regularization value for the bias regularizer - bias_l2_regularization: l2 regularization value for the bias regularizer - activity_l1_regularization: l1 regularization value for the activity regularizer - activity_l2_regularization: l2 regularization value for the activity regularizer - dropout: amount of dropout to apply to each dense layer - ood_fp_rate: false positive rate used to determine threshold for - out-of-distribution (OOD) detection - fit_spline: whether or not to fit UnivariateSpline for OOD threshold function - spline_bins: number of bins used when fitting the UnivariateSpline threshold - function for OOD detection - spline_k: degree of smoothing for the UnivariateSpline - spline_s: positive smoothing factor used to choose the number of knots in the - UnivariateSpline (s=0 forces the spline through all the datapoints, equivalent to - InterpolatedUnivariateSpline) - spline_snrs: SNRs from training used as the x-values to fit the UnivariateSpline - spline_recon_errors: reconstruction errors from training used as the y-values to - fit the UnivariateSpline - """ - super().__init__() - - self.hidden_layers = hidden_layers - self.sup_loss = sup_loss - self.unsup_loss = unsup_loss - self.sup_loss_func, self.activation = self._get_sup_loss_func( - sup_loss, - prefix="sup" - ) - self.sup_loss_func_name = self.sup_loss_func.name - - self.optimizer = optimizer - if isinstance(optimizer, str): - self.optimizer = keras.optimizers.get(optimizer) - if optimizer_kwargs is not None: - for key, value in optimizer_kwargs.items(): - setattr(self.optimizer, key, value) - self.optimizer.learning_rate = learning_rate - - self.unsup_loss_func = self._get_unsup_loss_func(unsup_loss) - self.unsup_loss_func_name = f"unsup_{unsup_loss}_loss" - self.metrics = metrics - self.beta = beta - self.source_dict = source_dict - self.semisup_loss_func_name = "semisup_loss" - self.hidden_layer_activation = hidden_layer_activation - self.kernel_l1_regularization = kernel_l1_regularization - self.kernel_l2_regularization = kernel_l2_regularization - self.bias_l1_regularization = bias_l1_regularization - self.bias_l2_regularization = bias_l2_regularization - self.activity_l1_regularization = activity_l1_regularization - self.activity_l2_regularization = activity_l2_regularization - self.dropout = dropout - self.ood_fp_rate = ood_fp_rate - self.fit_spline = fit_spline - self.spline_bins = spline_bins - self.spline_k = spline_k - self.spline_s = spline_s - self.spline_snrs = spline_snrs - self.spline_recon_errors = spline_recon_errors - self.model = None - - self._update_custom_objects("L1NormLayer", L1NormLayer) - - @property - def 
source_dict(self) -> dict: - return self.info["source_dict"] - - @source_dict.setter - def source_dict(self, value: dict): - self.info["source_dict"] = value - - def _get_sup_loss_func(self, loss_func_str, prefix): - if loss_func_str not in self.SUPERVISED_LOSS_FUNCS: - raise KeyError(f"'{loss_func_str}' is not a supported supervised loss function.") - func, kwargs, activation = self.SUPERVISED_LOSS_FUNCS[loss_func_str] - loss_func_name = f"{prefix}_{loss_func_str}_loss" - return func(name=loss_func_name, **kwargs), activation - - def _get_unsup_loss_func(self, loss_func_str): - if loss_func_str not in self.UNSUPERVISED_LOSS_FUNCS: - raise KeyError(f"'{loss_func_str}' is not a supported unsupervised loss function.") - return self.UNSUPERVISED_LOSS_FUNCS[loss_func_str] - - def _initialize_model(self, input_size, output_size): - spectra_input = Input(input_size, name="input_spectrum") - spectra_norm = L1NormLayer(name="normalized_input_spectrum")(spectra_input) - x = spectra_norm - for layer, nodes in enumerate(self.hidden_layers): - x = Dense( - nodes, - activation=self.hidden_layer_activation, - kernel_regularizer=L1L2( - l1=self.kernel_l1_regularization, - l2=self.kernel_l2_regularization - ), - bias_regularizer=L1L2( - l1=self.bias_l1_regularization, - l2=self.bias_l2_regularization - ), - activity_regularizer=L1L2( - l1=self.activity_l1_regularization, - l2=self.activity_l2_regularization - ), - name=f"dense_{layer}" - )(x) - - if self.dropout > 0: - x = Dropout(self.dropout)(x) - output = Dense( - output_size, - activation="linear", - name="output" - )(x) - - self.model = Model(inputs=[spectra_input], outputs=[output]) - - def _get_info_as_dict(self): - info_dict = {} - for k, v in vars(self).items(): - if k not in self.INFO_KEYS: - continue - if isinstance(v, np.ndarray): - info_dict[k] = v.tolist() - else: - info_dict[k] = v - return info_dict - - def _get_spline_threshold_func(self): - return UnivariateSpline( - self.info["avg_snrs"], - self.info["thresholds"], - k=self.spline_k, - s=self.spline_s - ) - - def _fit_spline_threshold_func(self): - out = pd.qcut( - np.array(self.spline_snrs), - self.spline_bins, - labels=False, - ) - thresholds = [ - np.quantile(np.array(self.spline_recon_errors)[out == int(i)], 1-self.ood_fp_rate) - for i in range(self.spline_bins) - ] - avg_snrs = [ - np.mean(np.array(self.spline_snrs)[out == int(i)]) for i in range(self.spline_bins) - ] - self._update_info( - avg_snrs=avg_snrs, - thresholds=thresholds, - spline_k=self.spline_k, - spline_s=self.spline_s, - ) - - def _get_snrs(self, ss: SampleSet, bg_cps: float, is_gross: bool) -> np.ndarray: - fg_counts = ss.info.total_counts.values.astype("float64") - bg_counts = ss.info.live_time.values * bg_cps - if is_gross: - fg_counts = fg_counts - bg_counts - snrs = fg_counts / np.sqrt(bg_counts) - return snrs - - def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False, - batch_size: int = 10, epochs: int = 20, validation_split: float = 0.2, - callbacks=None, patience: int = 15, es_monitor: str = "val_loss", - es_mode: str = "min", es_verbose=0, es_min_delta: float = 0.0, - normalize_sup_loss: bool = True, normalize_func=tf.math.tanh, - normalize_scaler: float = 1.0, target_level="Isotope", verbose: bool = False): - """Fit a model to the given SampleSet(s). 
- - Args: - seeds_ss: `SampleSet` of pure, long-collect spectra - ss: `SampleSet` of `n` gross or foreground spectra where `n` >= 1 - bg_cps: background rate assumption used for calculating SNR in spline function - using in OOD detection - is_gross: whether `ss` contains gross spectra - batch_size: number of samples per gradient update - epochs: maximum number of training iterations - validation_split: proportion of training data to use as validation data - callbacks: list of callbacks to be passed to TensorFlow Model.fit() method - patience: number of epochs to wait for `EarlyStopping` object - es_monitor: quantity to be monitored for `EarlyStopping` object - es_mode: mode for `EarlyStopping` object - es_verbose: verbosity level for `EarlyStopping` object - es_min_delta: minimum change to count as an improvement for early stopping - normalize_sup_loss: whether to normalize the supervised loss term - normalize_func: normalization function used for supervised loss term - normalize_scaler: scalar that sets the steepness of the normalization function - target_level: source level to target for model output - verbose: whether model training output is printed to the terminal - """ - spectra = ss.get_samples().astype(float) - sources_df = ss.sources.T.groupby(target_level, sort=False).sum().T - sources = sources_df.values.astype(float) - self.sources_columns = sources_df.columns - - if verbose: - print("Building dictionary...") - - if self.source_dict is None: - self.source_dict = _get_reordered_spectra( - seeds_ss.spectra, - seeds_ss.sources, - self.sources_columns, - target_level=target_level - ).values - - if not self.model: - if verbose: - print("Initializing model...") - self._initialize_model( - (ss.n_channels,), - sources.shape[1], - ) - elif verbose: - print("Model already initialized.") - - if verbose: - print("Building loss functions...") - - self.semisup_loss_func = build_keras_semisupervised_loss_func( - self.sup_loss_func, - self.unsup_loss_func, - self.source_dict, - self.beta, - self.activation, - n_labels=sources.shape[1], - normalize=normalize_sup_loss, - normalize_func=normalize_func, - normalize_scaler=normalize_scaler - ) - - semisup_metrics = None - if self.metrics: - if verbose: - print("Building metric functions...") - semisup_metrics = [] - for each in self.metrics: - if isinstance(each, str): - semisup_metrics.append( - build_keras_semisupervised_metric_func( - tf.keras.metrics.get(each), - self.activation, - sources.shape[1] - ) - ) - else: - semisup_metrics.append( - build_keras_semisupervised_metric_func( - each, - self.activation, - sources.shape[1] - ) - ) - - self.model.compile( - loss=self.semisup_loss_func, - optimizer=self.optimizer, - metrics=semisup_metrics - ) - - es = EarlyStopping( - monitor=es_monitor, - patience=patience, - verbose=es_verbose, - restore_best_weights=True, - mode=es_mode, - min_delta=es_min_delta, - ) - - if callbacks: - callbacks.append(es) - else: - callbacks = [es] - - history = self.model.fit( - spectra, - np.append(sources, spectra, axis=1), - epochs=epochs, - verbose=verbose, - validation_split=validation_split, - callbacks=callbacks, - shuffle=True, - batch_size=batch_size - ) - - if self.fit_spline: - if verbose: - print("Finding OOD detection threshold function...") - - train_logits = self.model.predict(spectra, verbose=0) - train_lpes = self.activation(tf.convert_to_tensor(train_logits, dtype=tf.float32)) - self.spline_recon_errors = reconstruction_error( - tf.convert_to_tensor(spectra, dtype=tf.float32), - train_lpes, - 
self.source_dict, - self.unsup_loss_func - ).numpy() - self.spline_snrs = self._get_snrs(ss, bg_cps, is_gross) - self._fit_spline_threshold_func() - - info = self._get_info_as_dict() - self._update_info( - target_level=target_level, - model_outputs=sources_df.columns.values.tolist(), - normalization=ss.spectra_state, - **info, - ) - - return history - - def predict(self, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False, verbose=False): - """Estimate the proportions of counts present in each sample of the provided SampleSet. - - Results are stored inside the SampleSet's prediction_probas property. - - Args: - ss: `SampleSet` of `n` foreground or gross spectra where `n` >= 1 - bg_cps: background rate used for estimating sample SNRs. - If background rate varies to a significant degree, split up sampleset - by SNR and make multiple calls to this method. - is_gross: whether `ss` contains gross spectra - """ - test_spectra = ss.get_samples().astype(float) - - logits = self.model.predict(test_spectra, verbose=verbose) - lpes = self.activation(tf.convert_to_tensor(logits, dtype=tf.float32)) - - col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) - col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] - ss.prediction_probas = pd.DataFrame( - data=lpes, - columns=pd.MultiIndex.from_tuples( - self.get_model_outputs_as_label_tuples(), - names=col_level_subset - ) - ) - - # Fill in unsupervised losses - recon_errors = reconstruction_error( - tf.convert_to_tensor(test_spectra, dtype=tf.float32), - lpes, - self.source_dict, - self.unsup_loss_func - ).numpy() - - if self.fit_spline: - snrs = self._get_snrs(ss, bg_cps, is_gross) - thresholds = self._get_spline_threshold_func()(snrs) - is_ood = recon_errors > thresholds - ss.info["ood"] = is_ood - - ss.info["recon_error"] = recon_errors - - -def _get_reordered_spectra(old_spectra_df: pd.DataFrame, old_sources_df: pd.DataFrame, - new_sources_columns, target_level) -> pd.DataFrame: - collapsed_sources_df = old_sources_df\ - .T.groupby(target_level)\ - .sum().T - reordered_spectra_df = old_spectra_df.iloc[ - collapsed_sources_df[ - new_sources_columns - ].idxmax() - ].reset_index(drop=True) - - return reordered_spectra_df +__all__ = ["LabelProportionEstimator", "MLPClassifier"] diff --git a/riid/models/neural_nets/arad.py b/riid/models/neural_nets/arad.py index 3679bfe6..0fd6b321 100644 --- a/riid/models/neural_nets/arad.py +++ b/riid/models/neural_nets/arad.py @@ -21,10 +21,10 @@ from scipy.spatial.distance import jensenshannon from scipy.stats import entropy -from riid.data.sampleset import SampleSet, SpectraState +from riid import SampleSet, SpectraState from riid.losses import mish -from riid.models import PyRIIDModel -from riid.models.bayes import ExpandDimsLayer +from riid.models.base import PyRIIDModel +from riid.models.layers import ExpandDimsLayer class ARADv1TF(Model): diff --git a/riid/models/neural_nets/basic.py b/riid/models/neural_nets/basic.py new file mode 100644 index 00000000..81db0962 --- /dev/null +++ b/riid/models/neural_nets/basic.py @@ -0,0 +1,197 @@ +# Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +# Under the terms of Contract DE-NA0003525 with NTESS, +# the U.S. Government retains certain rights in this software. 
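For orientation before the module body, a minimal end-to-end sketch of the `MLPClassifier` defined below; `train_ss` and `test_ss` are assumed to be pre-existing foreground or gross `SampleSet`s built elsewhere (e.g., via the synthesizers), so this is illustrative rather than runnable as-is:

from riid.models import MLPClassifier

model = MLPClassifier()
# train_ss / test_ss: assumed pre-built SampleSets of gross or foreground spectra
history = model.fit(train_ss, epochs=10, target_level="Isotope", verbose=True)
model.predict(test_ss)
print(test_ss.prediction_probas.head())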
+"""This module contains a simple neural network.""" +import pandas as pd +import tensorflow as tf +from keras.api.callbacks import EarlyStopping +from keras.api.layers import Dense, Input +from keras.api.losses import CategoricalCrossentropy +from keras.api.metrics import F1Score, Precision, Recall +from keras.api.models import Model +from keras.api.optimizers import Adam +from keras.api.regularizers import l1, l2 +from keras.api.utils import split_dataset + +from riid import SampleSet, SpectraType +from riid.models.base import ModelInput, PyRIIDModel + + +class MLPClassifier(PyRIIDModel): + """Multi-layer perceptron classifier.""" + def __init__(self, activation=None, loss=None, optimizer=None, + metrics=None, l2_alpha: float = 1e-4, + activity_regularizer=None, final_activation=None): + """ + Args: + activation: activate function to use for each dense layer + loss: loss function to use for training + optimizer: tensorflow optimizer or optimizer name to use for training + metrics: list of metrics to be evaluating during training + l2_alpha: alpha value for the L2 regularization of each dense layer + activity_regularizer: regularizer function applied each dense layer output + final_activation: final activation function to apply to model output + """ + super().__init__() + + self.activation = activation + self.loss = loss + self.optimizer = optimizer + self.final_activation = final_activation + self.metrics = metrics + self.l2_alpha = l2_alpha + self.activity_regularizer = activity_regularizer + self.final_activation = final_activation + + if self.activation is None: + self.activation = "relu" + if self.loss is None: + self.loss = CategoricalCrossentropy() + if optimizer is None: + self.optimizer = Adam(learning_rate=0.01, clipnorm=0.001) + if self.metrics is None: + self.metrics = [F1Score(), Precision(), Recall()] + if self.activity_regularizer is None: + self.activity_regularizer = l1(0.0) + if self.final_activation is None: + self.final_activation = "softmax" + self.model = None + self._predict_fn = None + + def fit(self, ss: SampleSet, batch_size: int = 200, epochs: int = 20, + validation_split: float = 0.2, callbacks=None, + patience: int = 15, es_monitor: str = "val_loss", + es_mode: str = "min", es_verbose=0, target_level="Isotope", verbose: bool = False): + """Fit a model to the given `SampleSet`(s). + + Args: + ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are either + foreground (AKA, "net") or gross. + batch_size: number of samples per gradient update + epochs: maximum number of training iterations + validation_split: percentage of the training data to use as validation data + callbacks: list of callbacks to be passed to the TensorFlow `Model.fit()` method + patience: number of epochs to wait for `EarlyStopping` object + es_monitor: quantity to be monitored for `EarlyStopping` object + es_mode: mode for `EarlyStopping` object + es_verbose: verbosity level for `EarlyStopping` object + target_level: `SampleSet.sources` column level to use + verbose: whether to show detailed model training output + + Returns: + `tf.History` object. 
+ + Raises: + `ValueError` when no spectra are provided as input + """ + if ss.n_samples <= 0: + raise ValueError("No spectr[a|um] provided!") + + if ss.spectra_type == SpectraType.Gross: + self.model_inputs = (ModelInput.GrossSpectrum,) + elif ss.spectra_type == SpectraType.Foreground: + self.model_inputs = (ModelInput.ForegroundSpectrum,) + elif ss.spectra_type == SpectraType.Background: + self.model_inputs = (ModelInput.BackgroundSpectrum,) + else: + raise ValueError(f"{ss.spectra_type} is not supported in this model.") + + X = ss.get_samples() + source_contributions_df = ss.sources.T.groupby(target_level, sort=False).sum().T + model_outputs = source_contributions_df.columns.values.tolist() + Y = source_contributions_df.values + + spectra_tensor = tf.convert_to_tensor(X, dtype=tf.float32) + labels_tensor = tf.convert_to_tensor(Y, dtype=tf.float32) + training_dataset = tf.data.Dataset.from_tensor_slices((spectra_tensor, labels_tensor)) + training_dataset, validation_dataset = split_dataset( + training_dataset, + left_size=validation_split, + shuffle=True + ) + training_dataset = training_dataset.batch(batch_size=batch_size) + validation_dataset = validation_dataset.batch(batch_size=batch_size) + + if not self.model: + inputs = Input(shape=(X.shape[1],), name="Spectrum") + dense_layer_size = X.shape[1] // 2 + dense_layer = Dense( + dense_layer_size, + activation=self.activation, + activity_regularizer=self.activity_regularizer, + kernel_regularizer=l2(self.l2_alpha), + )(inputs) + outputs = Dense(Y.shape[1], activation=self.final_activation)(dense_layer) + self.model = Model(inputs, outputs) + self.model.compile(loss=self.loss, optimizer=self.optimizer, + metrics=self.metrics) + + es = EarlyStopping( + monitor=es_monitor, + patience=patience, + verbose=es_verbose, + restore_best_weights=True, + mode=es_mode, + ) + if callbacks: + callbacks.append(es) + else: + callbacks = [es] + + history = self.model.fit( + training_dataset, + epochs=epochs, + verbose=verbose, + validation_data=validation_dataset, + callbacks=callbacks, + ) + + # Update model information + self._update_info( + target_level=target_level, + model_outputs=model_outputs, + normalization=ss.spectra_state, + ) + + # Define the predict function with tf.function and input_signature + self._predict_fn = tf.function( + self._predict, + # input_signature=[tf.TensorSpec(shape=[None, X.shape[1]], dtype=tf.float32)] + experimental_relax_shapes=True + ) + + return history + + def _predict(self, input_tensor): + return self.model(input_tensor, training=False) + + def predict(self, ss: SampleSet, bg_ss: SampleSet = None): + """Classify the spectra in the provided `SampleSet`(s). + + Results are stored inside the first SampleSet's prediction-related properties. 
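+        Specifically, `prediction_probas` is populated with the per-output probabilities
+        and `classified_by` is set to this model's ID.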
+ + Args: + ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are either + foreground (AKA, "net") or gross + bg_ss: `SampleSet` of `n` spectra where `n` >= 1 and the spectra are background + """ + x_test = ss.get_samples().astype(float) + if bg_ss: + X = [x_test, bg_ss.get_samples().astype(float)] + else: + X = x_test + + spectra_tensor = tf.convert_to_tensor(X, dtype=tf.float32) + results = self._predict_fn(spectra_tensor) + + col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) + col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] + ss.prediction_probas = pd.DataFrame( + data=results, + columns=pd.MultiIndex.from_tuples( + self.get_model_outputs_as_label_tuples(), + names=col_level_subset + ) + ) + + ss.classified_by = self.model_id diff --git a/riid/models/neural_nets/lpe.py b/riid/models/neural_nets/lpe.py new file mode 100644 index 00000000..346d5b3f --- /dev/null +++ b/riid/models/neural_nets/lpe.py @@ -0,0 +1,489 @@ +# Copyright 2021 National Technology & Engineering Solutions of Sandia, LLC (NTESS). +# Under the terms of Contract DE-NA0003525 with NTESS, +# the U.S. Government retains certain rights in this software. +"""This module contains the label proportion estimator.""" + +import keras +import numpy as np +import pandas as pd +import tensorflow as tf +from keras.api.activations import sigmoid, softmax +from keras.api.callbacks import EarlyStopping +from keras.api.layers import Dense, Dropout, Input +from keras.api.losses import CategoricalCrossentropy, MeanSquaredError +from keras.api.models import Model +from keras.api.regularizers import L1L2 +from scipy.interpolate import UnivariateSpline + +from riid import SampleSet +from riid.losses import (build_keras_semisupervised_loss_func, + chi_squared_diff, jensen_shannon_divergence, + normal_nll_diff, poisson_nll_diff, + reconstruction_error, sse_diff, weighted_sse_diff) +from riid.losses.sparsemax import SparsemaxLoss, sparsemax +from riid.metrics import build_keras_semisupervised_metric_func +from riid.models.base import PyRIIDModel +from riid.models.layers import L1NormLayer + + +class LabelProportionEstimator(PyRIIDModel): + """Regressor for predicting label proportions that uses a semi-supervised loss. + + Optionally, a U-spline-based out-of-distribution detection model can be fit to target a desired + false positive rate. 
+ """ + UNSUPERVISED_LOSS_FUNCS = { + "poisson_nll": poisson_nll_diff, + "normal_nll": normal_nll_diff, + "sse": sse_diff, + "weighted_sse": weighted_sse_diff, + "jsd": jensen_shannon_divergence, + "chi_squared": chi_squared_diff + } + SUPERVISED_LOSS_FUNCS = { + "sparsemax": ( + SparsemaxLoss, + { + "from_logits": True, + "reduction": tf.keras.losses.Reduction.NONE, + }, + sparsemax, + ), + "categorical_crossentropy": ( + CategoricalCrossentropy, + { + "from_logits": True, + "reduction": tf.keras.losses.Reduction.NONE, + }, + softmax, + ), + "mse": ( + MeanSquaredError, + { + "reduction": tf.keras.losses.Reduction.NONE, + }, + sigmoid, + ) + } + INFO_KEYS = ( + # model architecture + "hidden_layers", + "learning_rate", + "epsilon", + "sup_loss", + "unsup_loss", + "metrics", + "beta", + "hidden_layer_activation", + "kernel_l1_regularization", + "kernel_l2_regularization", + "bias_l1_regularization", + "bias_l2_regularization", + "activity_l1_regularization", + "activity_l2_regularization", + "dropout", + "ood_fp_rate", + "fit_spline", + "spline_bins", + "spline_k", + "spline_s", + # dictionaries + "source_dict", + # populated when loading model + "spline_snrs", + "spline_recon_errors", + ) + + def __init__(self, hidden_layers: tuple = (256,), sup_loss="sparsemax", unsup_loss="sse", + metrics: list = ["mae", "categorical_crossentropy"], beta=0.9, source_dict=None, + optimizer="adam", optimizer_kwargs=None, learning_rate: float = 1e-3, + hidden_layer_activation: str = "mish", + kernel_l1_regularization: float = 0.0, kernel_l2_regularization: float = 0.0, + bias_l1_regularization: float = 0.0, bias_l2_regularization: float = 0.0, + activity_l1_regularization: float = 0.0, activity_l2_regularization: float = 0.0, + dropout: float = 0.0, ood_fp_rate: float = 0.05, + fit_spline: bool = True, spline_bins: int = 15, spline_k: int = 3, + spline_s: int = 0, spline_snrs=None, spline_recon_errors=None): + """ + Args: + hidden_layers: tuple defining the number and size of dense layers + sup_loss: supervised loss function to use for training + unsup_loss: unsupervised loss function to use for training the + foreground branch of the network (options: "sse", "poisson_nll", + "normal_nll", "weighted_sse", "jsd", or "chi_squared") + metrics: list of metrics to be evaluating during training + beta: tradeoff parameter between the supervised and unsupervised foreground loss + source_dict: 2D array of pure, long-collect foreground spectra + optimizer: tensorflow optimizer or optimizer name to use for training + optimizer_kwargs: kwargs for optimizer + learning_rate: learning rate for the optimizer + hidden_layer_activation: activation function to use for each dense layer + kernel_l1_regularization: l1 regularization value for the kernel regularizer + kernel_l2_regularization: l2 regularization value for the kernel regularizer + bias_l1_regularization: l1 regularization value for the bias regularizer + bias_l2_regularization: l2 regularization value for the bias regularizer + activity_l1_regularization: l1 regularization value for the activity regularizer + activity_l2_regularization: l2 regularization value for the activity regularizer + dropout: amount of dropout to apply to each dense layer + ood_fp_rate: false positive rate used to determine threshold for + out-of-distribution (OOD) detection + fit_spline: whether or not to fit UnivariateSpline for OOD threshold function + spline_bins: number of bins used when fitting the UnivariateSpline threshold + function for OOD detection + spline_k: degree of smoothing 
for the UnivariateSpline + spline_s: positive smoothing factor used to choose the number of knots in the + UnivariateSpline (s=0 forces the spline through all the datapoints, equivalent to + InterpolatedUnivariateSpline) + spline_snrs: SNRs from training used as the x-values to fit the UnivariateSpline + spline_recon_errors: reconstruction errors from training used as the y-values to + fit the UnivariateSpline + """ + super().__init__() + + self.hidden_layers = hidden_layers + self.sup_loss = sup_loss + self.unsup_loss = unsup_loss + self.sup_loss_func, self.activation = self._get_sup_loss_func( + sup_loss, + prefix="sup" + ) + self.sup_loss_func_name = self.sup_loss_func.name + + self.optimizer = optimizer + if isinstance(optimizer, str): + self.optimizer = keras.optimizers.get(optimizer) + if optimizer_kwargs is not None: + for key, value in optimizer_kwargs.items(): + setattr(self.optimizer, key, value) + self.optimizer.learning_rate = learning_rate + + self.unsup_loss_func = self._get_unsup_loss_func(unsup_loss) + self.unsup_loss_func_name = f"unsup_{unsup_loss}_loss" + self.metrics = metrics + self.beta = beta + self.source_dict = source_dict + self.semisup_loss_func_name = "semisup_loss" + self.hidden_layer_activation = hidden_layer_activation + self.kernel_l1_regularization = kernel_l1_regularization + self.kernel_l2_regularization = kernel_l2_regularization + self.bias_l1_regularization = bias_l1_regularization + self.bias_l2_regularization = bias_l2_regularization + self.activity_l1_regularization = activity_l1_regularization + self.activity_l2_regularization = activity_l2_regularization + self.dropout = dropout + self.ood_fp_rate = ood_fp_rate + self.fit_spline = fit_spline + self.spline_bins = spline_bins + self.spline_k = spline_k + self.spline_s = spline_s + self.spline_snrs = spline_snrs + self.spline_recon_errors = spline_recon_errors + self.model = None + + self._update_custom_objects("L1NormLayer", L1NormLayer) + + @property + def source_dict(self) -> dict: + return self.info["source_dict"] + + @source_dict.setter + def source_dict(self, value: dict): + self.info["source_dict"] = value + + def _get_sup_loss_func(self, loss_func_str, prefix): + if loss_func_str not in self.SUPERVISED_LOSS_FUNCS: + raise KeyError(f"'{loss_func_str}' is not a supported supervised loss function.") + func, kwargs, activation = self.SUPERVISED_LOSS_FUNCS[loss_func_str] + loss_func_name = f"{prefix}_{loss_func_str}_loss" + return func(name=loss_func_name, **kwargs), activation + + def _get_unsup_loss_func(self, loss_func_str): + if loss_func_str not in self.UNSUPERVISED_LOSS_FUNCS: + raise KeyError(f"'{loss_func_str}' is not a supported unsupervised loss function.") + return self.UNSUPERVISED_LOSS_FUNCS[loss_func_str] + + def _initialize_model(self, input_size, output_size): + spectra_input = Input(input_size, name="input_spectrum") + spectra_norm = L1NormLayer(name="normalized_input_spectrum")(spectra_input) + x = spectra_norm + for layer, nodes in enumerate(self.hidden_layers): + x = Dense( + nodes, + activation=self.hidden_layer_activation, + kernel_regularizer=L1L2( + l1=self.kernel_l1_regularization, + l2=self.kernel_l2_regularization + ), + bias_regularizer=L1L2( + l1=self.bias_l1_regularization, + l2=self.bias_l2_regularization + ), + activity_regularizer=L1L2( + l1=self.activity_l1_regularization, + l2=self.activity_l2_regularization + ), + name=f"dense_{layer}" + )(x) + + if self.dropout > 0: + x = Dropout(self.dropout)(x) + output = Dense( + output_size, + activation="linear", + 
name="output" + )(x) + + self.model = Model(inputs=[spectra_input], outputs=[output]) + + def _get_info_as_dict(self): + info_dict = {} + for k, v in vars(self).items(): + if k not in self.INFO_KEYS: + continue + if isinstance(v, np.ndarray): + info_dict[k] = v.tolist() + else: + info_dict[k] = v + return info_dict + + def _get_spline_threshold_func(self): + return UnivariateSpline( + self.info["avg_snrs"], + self.info["thresholds"], + k=self.spline_k, + s=self.spline_s + ) + + def _fit_spline_threshold_func(self): + out = pd.qcut( + np.array(self.spline_snrs), + self.spline_bins, + labels=False, + ) + thresholds = [ + np.quantile(np.array(self.spline_recon_errors)[out == int(i)], 1-self.ood_fp_rate) + for i in range(self.spline_bins) + ] + avg_snrs = [ + np.mean(np.array(self.spline_snrs)[out == int(i)]) for i in range(self.spline_bins) + ] + self._update_info( + avg_snrs=avg_snrs, + thresholds=thresholds, + spline_k=self.spline_k, + spline_s=self.spline_s, + ) + + def _get_snrs(self, ss: SampleSet, bg_cps: float, is_gross: bool) -> np.ndarray: + fg_counts = ss.info.total_counts.values.astype("float64") + bg_counts = ss.info.live_time.values * bg_cps + if is_gross: + fg_counts = fg_counts - bg_counts + snrs = fg_counts / np.sqrt(bg_counts) + return snrs + + def fit(self, seeds_ss: SampleSet, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False, + batch_size: int = 10, epochs: int = 20, validation_split: float = 0.2, + callbacks=None, patience: int = 15, es_monitor: str = "val_loss", + es_mode: str = "min", es_verbose=0, es_min_delta: float = 0.0, + normalize_sup_loss: bool = True, normalize_func=tf.math.tanh, + normalize_scaler: float = 1.0, target_level="Isotope", verbose: bool = False): + """Fit a model to the given SampleSet(s). + + Args: + seeds_ss: `SampleSet` of pure, long-collect spectra + ss: `SampleSet` of `n` gross or foreground spectra where `n` >= 1 + bg_cps: background rate assumption used for calculating SNR in spline function + using in OOD detection + is_gross: whether `ss` contains gross spectra + batch_size: number of samples per gradient update + epochs: maximum number of training iterations + validation_split: proportion of training data to use as validation data + callbacks: list of callbacks to be passed to TensorFlow Model.fit() method + patience: number of epochs to wait for `EarlyStopping` object + es_monitor: quantity to be monitored for `EarlyStopping` object + es_mode: mode for `EarlyStopping` object + es_verbose: verbosity level for `EarlyStopping` object + es_min_delta: minimum change to count as an improvement for early stopping + normalize_sup_loss: whether to normalize the supervised loss term + normalize_func: normalization function used for supervised loss term + normalize_scaler: scalar that sets the steepness of the normalization function + target_level: source level to target for model output + verbose: whether model training output is printed to the terminal + """ + spectra = ss.get_samples().astype(float) + sources_df = ss.sources.T.groupby(target_level, sort=False).sum().T + sources = sources_df.values.astype(float) + self.sources_columns = sources_df.columns + + if verbose: + print("Building dictionary...") + + if self.source_dict is None: + self.source_dict = _get_reordered_spectra( + seeds_ss.spectra, + seeds_ss.sources, + self.sources_columns, + target_level=target_level + ).values + + if not self.model: + if verbose: + print("Initializing model...") + self._initialize_model( + (ss.n_channels,), + sources.shape[1], + ) + elif verbose: + 
print("Model already initialized.") + + if verbose: + print("Building loss functions...") + + self.semisup_loss_func = build_keras_semisupervised_loss_func( + self.sup_loss_func, + self.unsup_loss_func, + self.source_dict, + self.beta, + self.activation, + n_labels=sources.shape[1], + normalize=normalize_sup_loss, + normalize_func=normalize_func, + normalize_scaler=normalize_scaler + ) + + semisup_metrics = None + if self.metrics: + if verbose: + print("Building metric functions...") + semisup_metrics = [] + for each in self.metrics: + if isinstance(each, str): + semisup_metrics.append( + build_keras_semisupervised_metric_func( + tf.keras.metrics.get(each), + self.activation, + sources.shape[1] + ) + ) + else: + semisup_metrics.append( + build_keras_semisupervised_metric_func( + each, + self.activation, + sources.shape[1] + ) + ) + + self.model.compile( + loss=self.semisup_loss_func, + optimizer=self.optimizer, + metrics=semisup_metrics + ) + + es = EarlyStopping( + monitor=es_monitor, + patience=patience, + verbose=es_verbose, + restore_best_weights=True, + mode=es_mode, + min_delta=es_min_delta, + ) + + if callbacks: + callbacks.append(es) + else: + callbacks = [es] + + history = self.model.fit( + spectra, + np.append(sources, spectra, axis=1), + epochs=epochs, + verbose=verbose, + validation_split=validation_split, + callbacks=callbacks, + shuffle=True, + batch_size=batch_size + ) + + if self.fit_spline: + if verbose: + print("Finding OOD detection threshold function...") + + train_logits = self.model.predict(spectra, verbose=0) + train_lpes = self.activation(tf.convert_to_tensor(train_logits, dtype=tf.float32)) + self.spline_recon_errors = reconstruction_error( + tf.convert_to_tensor(spectra, dtype=tf.float32), + train_lpes, + self.source_dict, + self.unsup_loss_func + ).numpy() + self.spline_snrs = self._get_snrs(ss, bg_cps, is_gross) + self._fit_spline_threshold_func() + + info = self._get_info_as_dict() + self._update_info( + target_level=target_level, + model_outputs=sources_df.columns.values.tolist(), + normalization=ss.spectra_state, + **info, + ) + + return history + + def predict(self, ss: SampleSet, bg_cps: int = 300, is_gross: bool = False, verbose=False): + """Estimate the proportions of counts present in each sample of the provided SampleSet. + + Results are stored inside the SampleSet's prediction_probas property. + + Args: + ss: `SampleSet` of `n` foreground or gross spectra where `n` >= 1 + bg_cps: background rate used for estimating sample SNRs. + If background rate varies to a significant degree, split up sampleset + by SNR and make multiple calls to this method. 
+ is_gross: whether `ss` contains gross spectra + """ + test_spectra = ss.get_samples().astype(float) + + logits = self.model.predict(test_spectra, verbose=verbose) + lpes = self.activation(tf.convert_to_tensor(logits, dtype=tf.float32)) + + col_level_idx = SampleSet.SOURCES_MULTI_INDEX_NAMES.index(self.target_level) + col_level_subset = SampleSet.SOURCES_MULTI_INDEX_NAMES[:col_level_idx+1] + ss.prediction_probas = pd.DataFrame( + data=lpes, + columns=pd.MultiIndex.from_tuples( + self.get_model_outputs_as_label_tuples(), + names=col_level_subset + ) + ) + + # Fill in unsupervised losses + recon_errors = reconstruction_error( + tf.convert_to_tensor(test_spectra, dtype=tf.float32), + lpes, + self.source_dict, + self.unsup_loss_func + ).numpy() + + if self.fit_spline: + snrs = self._get_snrs(ss, bg_cps, is_gross) + thresholds = self._get_spline_threshold_func()(snrs) + is_ood = recon_errors > thresholds + ss.info["ood"] = is_ood + + ss.info["recon_error"] = recon_errors + + +def _get_reordered_spectra(old_spectra_df: pd.DataFrame, old_sources_df: pd.DataFrame, + new_sources_columns, target_level) -> pd.DataFrame: + collapsed_sources_df = old_sources_df\ + .T.groupby(target_level)\ + .sum().T + reordered_spectra_df = old_spectra_df.iloc[ + collapsed_sources_df[ + new_sources_columns + ].idxmax() + ].reset_index(drop=True) + + return reordered_spectra_df diff --git a/riid/visualize.py b/riid/visualize.py index 81dc3c3d..ef72aa93 100644 --- a/riid/visualize.py +++ b/riid/visualize.py @@ -15,7 +15,7 @@ from seaborn import heatmap from sklearn.metrics import confusion_matrix as confusion_matrix_sklearn -from riid.data.sampleset import SampleSet +from riid import SampleSet # DO NOT TOUCH what is set below nor override them inside a function. plt.style.use("default") diff --git a/tests/anomaly_tests.py b/tests/anomaly_tests.py index 35143a1d..daddb95d 100644 --- a/tests/anomaly_tests.py +++ b/tests/anomaly_tests.py @@ -6,10 +6,8 @@ import numpy as np +from riid import PassbySynthesizer, SeedMixer, get_dummy_seeds from riid.anomaly import PoissonNChannelEventDetector -from riid.data.synthetic.passby import PassbySynthesizer -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic import get_dummy_seeds class TestAnomaly(unittest.TestCase): diff --git a/tests/data_tests.py b/tests/data_tests.py index 022bedb7..cc8a1c90 100644 --- a/tests/data_tests.py +++ b/tests/data_tests.py @@ -6,10 +6,10 @@ import tempfile import unittest -from riid import SAMPLESET_HDF_FILE_EXTENSION +from riid import (SAMPLESET_HDF_FILE_EXTENSION, SampleSet, get_dummy_seeds, + read_hdf) from riid.data.labeling import label_to_index_element -from riid.data.sampleset import SampleSet, _write_hdf, read_hdf -from riid.data.synthetic import get_dummy_seeds +from riid.data.sampleset import _write_hdf class TestData(unittest.TestCase): diff --git a/tests/gadras_tests.py b/tests/gadras_tests.py index 9210312d..680048eb 100644 --- a/tests/gadras_tests.py +++ b/tests/gadras_tests.py @@ -5,7 +5,8 @@ import unittest import pandas as pd -from riid.data.synthetic import get_dummy_seeds + +from riid import get_dummy_seeds from riid.gadras.pcf import (_pack_compressed_text_buffer, _unpack_compressed_text_buffer) diff --git a/tests/model_tests.py b/tests/model_tests.py index ea430abc..95ff6897 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -8,15 +8,12 @@ import numpy as np import pandas as pd -from riid.data.sampleset import SampleSet -from riid.data.synthetic import get_dummy_seeds -from 
riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static import StaticSynthesizer -from riid.models import PyRIIDModel -from riid.models.bayes import (NegativeSpectrumError, PoissonBayesClassifier, - ZeroTotalCountsError) -from riid.models.neural_nets import (LabelProportionEstimator, MLPClassifier) -from riid.models.neural_nets.arad import ARADLatentPredictor, ARADv1, ARADv2 +from riid import SampleSet, SeedMixer, StaticSynthesizer, get_dummy_seeds +from riid.models import (ARADLatentPredictor, ARADv1, ARADv2, + LabelProportionEstimator, MLPClassifier, + PoissonBayesClassifier) +from riid.models.base import PyRIIDModel +from riid.models.bayes import NegativeSpectrumError, ZeroTotalCountsError class TestModels(unittest.TestCase): diff --git a/tests/sampleset_tests.py b/tests/sampleset_tests.py index 63a5a969..887f5308 100644 --- a/tests/sampleset_tests.py +++ b/tests/sampleset_tests.py @@ -8,13 +8,11 @@ import numpy as np import pandas as pd +from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds from riid.data.sampleset import (ChannelCountMismatchError, InvalidSampleCountError, SampleSet, SpectraState, SpectraStateMismatchError, SpectraType, _get_row_labels) -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static import StaticSynthesizer class TestSampleSet(unittest.TestCase): diff --git a/tests/seedmixer_tests.py b/tests/seedmixer_tests.py index e986f891..00ac6280 100644 --- a/tests/seedmixer_tests.py +++ b/tests/seedmixer_tests.py @@ -6,10 +6,8 @@ import numpy as np from scipy.spatial.distance import jensenshannon -from riid.data.sampleset import SampleSet -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer +from riid import SampleSet, SeedMixer, get_dummy_seeds class TestSeedMixer(unittest.TestCase): diff --git a/tests/staticsynth_tests.py b/tests/staticsynth_tests.py index 91684708..afbd297a 100644 --- a/tests/staticsynth_tests.py +++ b/tests/staticsynth_tests.py @@ -7,12 +7,11 @@ import numpy as np import pandas as pd +from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds from riid.data import InvalidSeedError, get_expected_spectra -from riid.data.synthetic import (Synthesizer, get_dummy_seeds, - get_merged_sources_samplewise, - get_samples_per_seed) -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static import StaticSynthesizer +from riid.data.synthetic.base import (Synthesizer, + get_merged_sources_samplewise, + get_samples_per_seed) class TestStaticSynthesis(unittest.TestCase): diff --git a/tests/visualize_tests.py b/tests/visualize_tests.py index 151e9f48..ec318227 100644 --- a/tests/visualize_tests.py +++ b/tests/visualize_tests.py @@ -6,11 +6,9 @@ import numpy as np -from riid.data.synthetic import get_dummy_seeds -from riid.data.synthetic.seed import SeedMixer -from riid.data.synthetic.static import StaticSynthesizer +from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds from riid.metrics import precision_recall_curve -from riid.models.neural_nets import MLPClassifier +from riid.models import MLPClassifier from riid.visualize import (plot_correlation_between_all_labels, plot_count_rate_history, plot_label_and_prediction_distributions, From 7646e94991d1d15ce4708bacdcb1e77ad0d5fbcf Mon Sep 17 00:00:00 2001 From: Tyler Morrow Date: Mon, 26 Aug 2024 15:56:31 -0600 Subject: [PATCH 3/4] All source injector output requires verbosity. 
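Note: the progress bar and per-isotope messages from the batch inject worker are now emitted only when `verbose=True`, so callers that relied on the old always-on output need to opt in. For reference, the gating pattern applied below is shown here as a standalone sketch (illustrative only, not code from this patch; `run_injects` and its `sources` argument are hypothetical stand-ins for the real inject loop):

    import tqdm

    def run_injects(sources, verbose=False):
        # Build a tqdm progress bar only when verbose output was requested;
        # otherwise iterate the plain list so nothing is printed.
        pbar = tqdm.tqdm(sources, desc="Running injects") if verbose else sources
        for fg in pbar:
            if verbose:
                pbar.set_description(f"Injecting '{fg['isotope']}'")
            ...  # run the inject for this source here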
--- riid/gadras/api.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/riid/gadras/api.py b/riid/gadras/api.py index c7af0e94..a4a038bb 100644 --- a/riid/gadras/api.py +++ b/riid/gadras/api.py @@ -7,7 +7,6 @@ import json import os import sys -from typing import List import numpy as np import tqdm @@ -53,14 +52,14 @@ clr.AddReference("Sandia.Gadras.Utilities") clr.AddReference("System.Collections") - from Sandia.Gadras.API import GadrasAPIWrapper, LocationInfo # noqa + from Sandia.Gadras.API import GadrasAPIWrapper, LocationInfo # type: ignore try: - from Sandia.Gadras.API.Inject import InjectSetup # noqa + from Sandia.Gadras.API.Inject import InjectSetup # type: ignore IS_GADRAS19 = True except ModuleNotFoundError: - from Sandia.Gadras.API import InjectSetup # noqa - from Sandia.Gadras.Utilities import Configs # noqa - from System.Collections.Generic import List # noqa + from Sandia.Gadras.API import InjectSetup # type: ignore + from Sandia.Gadras.Utilities import Configs # type: ignore + from System.Collections.Generic import List # type: ignore INJECT_PARAMS = { @@ -236,9 +235,13 @@ def generate(self, config: dict, rel_output_path: str, verbose: bool = False) -> worker = self.gadras_api.GetBatchInjectWorker() injects_exist = False - pbar = tqdm.tqdm(config["sources"], desc="Running injects") + if verbose: + pbar = tqdm.tqdm(config["sources"], desc="Running injects") + else: + pbar = config["sources"] for fg in pbar: - pbar.set_description(f"Running inject for '{fg['isotope']}'") + if verbose: + pbar.set_description(f"Injecting '{fg['isotope']}'") inject_setups = self._get_inject_setups_for_sources( self.gadras_api, config["gamma_detector"]["parameters"], From c42155fedb9734a7669437f14d483b9cc5284c21 Mon Sep 17 00:00:00 2001 From: Tyler Morrow Date: Thu, 29 Aug 2024 13:06:31 -0600 Subject: [PATCH 4/4] Refactor synthesizers to better preserve seed info. 
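Note: seed-level acquisition metadata (ecal terms, distance, areal density, atomic number, neutron counts, the dead-time-derived real time, and the synthesis timestamp) is now copied from each foreground seed onto the synthesized foreground and gross sample sets rather than being dropped. A rough usage sketch of the intent (not a test from this patch; the `samples_per_seed` argument and the two-SampleSet return of `generate` are assumptions based on existing usage):

    from riid import SeedMixer, StaticSynthesizer, get_dummy_seeds

    fg_seeds_ss, bg_seeds_ss = get_dummy_seeds().split_fg_and_bg()
    mixed_bg_seeds_ss = SeedMixer(bg_seeds_ss, mixture_size=3).generate(n_samples=10)

    static_synth = StaticSynthesizer(samples_per_seed=5)
    fg_ss, gross_ss = static_synth.generate(fg_seeds_ss, mixed_bg_seeds_ss)

    # Seed info should now survive synthesis, e.g. each synthesized sample
    # carries its seed's distance, areal density, and atomic number.
    print(gross_ss.info[["distance_cm", "areal_density", "atomic_number"]].head())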
--- riid/data/sampleset.py | 2 +- riid/data/synthetic/base.py | 55 +++++++++-------------------------- riid/data/synthetic/passby.py | 49 +++++++++++++++++++++++++++---- riid/data/synthetic/static.py | 35 +++++++++++++++++----- tests/staticsynth_tests.py | 5 ---- 5 files changed, 85 insertions(+), 61 deletions(-) diff --git a/riid/data/sampleset.py b/riid/data/sampleset.py index 47dd0332..c9e5547c 100644 --- a/riid/data/sampleset.py +++ b/riid/data/sampleset.py @@ -1800,13 +1800,13 @@ def _pcf_dict_to_ss(pcf_dict: dict, verbose=True): "total_counts": sum(spectrum["spectrum"]), "neutron_counts": spectrum["header"]["Total_Neutron_Counts"], "distance_cm": distance, - "areal_density": ad, "ecal_order_0": order_0, "ecal_order_1": order_1, "ecal_order_2": order_2, "ecal_order_3": order_3, "ecal_low_e": low_E, "atomic_number": an, + "areal_density": ad, "occupancy_flag": spectrum["header"]["Occupancy_Flag"], "tag": spectrum["header"]["Tag"], } diff --git a/riid/data/synthetic/base.py b/riid/data/synthetic/base.py index 9209c7f3..6b8f0376 100644 --- a/riid/data/synthetic/base.py +++ b/riid/data/synthetic/base.py @@ -81,8 +81,7 @@ def _verify_n_samples_synthesized(self, actual: int, expected: int): "Be sure to remove any columns from your seeds' sources DataFrame that " "contain all zeroes.") - def _get_batch(self, fg_seed, fg_sources, bg_seed, bg_sources, ecal, - lt_targets, snr_targets, rt_targets=None, distance_cm=None): + def _get_batch(self, fg_seed, fg_sources, bg_seed, bg_sources, lt_targets, snr_targets): if not (self.return_fg or self.return_gross): raise ValueError("Computing to return nothing.") @@ -127,10 +126,8 @@ def _get_batch(self, fg_seed, fg_sources, bg_seed, bg_sources, ecal, # Sample sets if self.return_fg: snrs = fg_counts / np.sqrt(long_bg_counts.clip(1)) - fg_ss = get_fg_sample_set(fg_spectra, fg_sources, ecal, lt_targets, - snrs=snrs, total_counts=fg_counts, - real_times=rt_targets, distance_cm=distance_cm, - timestamps=self._synthesis_start_dt) + fg_ss = get_fg_sample_set(fg_spectra, fg_sources, lt_targets, + snrs=snrs, total_counts=fg_counts) self._n_samples_synthesized += fg_ss.n_samples if self.return_gross: tiled_fg_sources = _tile_sources_and_scale( @@ -146,40 +143,28 @@ def _get_batch(self, fg_seed, fg_sources, bg_seed, bg_sources, ecal, gross_sources = get_merged_sources_samplewise(tiled_fg_sources, tiled_bg_sources) gross_counts = gross_spectra.sum(axis=1) snrs = fg_counts / np.sqrt(bg_counts.clip(1)) - gross_ss = get_gross_sample_set(gross_spectra, gross_sources, ecal, - lt_targets, snrs, gross_counts, - real_times=rt_targets, distance_cm=distance_cm, - timestamps=self._synthesis_start_dt) + gross_ss = get_gross_sample_set(gross_spectra, gross_sources, + lt_targets, snrs, gross_counts) self._n_samples_synthesized += gross_ss.n_samples return fg_ss, gross_ss -def get_sample_set(spectra, sources, ecal, live_times, snrs, total_counts=None, - real_times=None, distance_cm=None, timestamps=None, - descriptions=None) -> SampleSet: +def _get_minimal_ss(spectra, sources, live_times, snrs, total_counts=None) -> SampleSet: n_samples = spectra.shape[0] + if n_samples <= 0: + raise ValueError(f"Can't build SampleSet with {n_samples} samples.") ss = SampleSet() ss.spectra_state = SpectraState.Counts ss.spectra = pd.DataFrame(spectra) ss.sources = sources ss.info.description = np.full(n_samples, "") # Ensures the length of info equal n_samples - if descriptions: - ss.info.description = descriptions ss.info.snr = snrs - ss.info.timestamp = timestamps ss.info.total_counts = 
total_counts if total_counts is not None else spectra.sum(axis=1) - ss.info.ecal_order_0 = ecal[0] - ss.info.ecal_order_1 = ecal[1] - ss.info.ecal_order_2 = ecal[2] - ss.info.ecal_order_3 = ecal[3] - ss.info.ecal_low_e = ecal[4] ss.info.live_time = live_times - ss.info.real_time = real_times if real_times is not None else live_times - ss.info.distance_cm = distance_cm ss.info.occupancy_flag = 0 - ss.info.tag = " " # TODO: test if this can be empty string + ss.info.tag = " " # TODO: test if this can be an empty string return ss @@ -196,44 +181,30 @@ def _tile_sources_and_scale(sources, n_samples, scalars) -> pd.DataFrame: return tiled_sources -def get_fg_sample_set(spectra, sources, ecal, live_times, snrs, total_counts, - real_times=None, distance_cm=None, timestamps=None, - descriptions=None) -> SampleSet: +def get_fg_sample_set(spectra, sources, live_times, snrs, total_counts) -> SampleSet: tiled_sources = _tile_sources_and_scale( sources, spectra.shape[0], spectra.sum(axis=1) ) - ss = get_sample_set( + ss = _get_minimal_ss( spectra=spectra, sources=tiled_sources, - ecal=ecal, live_times=live_times, snrs=snrs, total_counts=total_counts, - real_times=real_times, - distance_cm=distance_cm, - timestamps=timestamps, - descriptions=descriptions ) ss.spectra_type = SpectraType.Foreground return ss -def get_gross_sample_set(spectra, sources, ecal, live_times, snrs, total_counts, - real_times=None, distance_cm=None, timestamps=None, - descriptions=None) -> SampleSet: - ss = get_sample_set( +def get_gross_sample_set(spectra, sources, live_times, snrs, total_counts) -> SampleSet: + ss = _get_minimal_ss( spectra=spectra, sources=sources, - ecal=ecal, live_times=live_times, snrs=snrs, total_counts=total_counts, - real_times=real_times, - distance_cm=distance_cm, - timestamps=timestamps, - descriptions=descriptions ) ss.spectra_type = SpectraType.Gross return ss diff --git a/riid/data/synthetic/passby.py b/riid/data/synthetic/passby.py index 84234618..ae637593 100644 --- a/riid/data/synthetic/passby.py +++ b/riid/data/synthetic/passby.py @@ -177,7 +177,7 @@ def _calculate_passby_shape(self, fwhm: float): return 1 / (np.power(samples, 2) + 1) def _generate_single_passby(self, fwhm: float, snr: float, dwell_time: float, - fg_seed: np.array, bg_seed: np.array, fg_ecal: np.array, + fg_seed: np.array, bg_seed: np.array, fg_sources: pd.Series, bg_sources: pd.Series): """Generate a `SampleSet` with a sequence of spectra representative of a single pass-by. 
@@ -212,7 +212,6 @@ def _generate_single_passby(self, fwhm: float, snr: float, dwell_time: float, fg_sources, bg_seed, bg_sources, - fg_ecal, live_times, snr_targets ) @@ -275,17 +274,55 @@ def generate(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet, for fg_i in range(fg_seeds_ss.n_samples): fg_pmf = fg_seeds_ss.spectra.iloc[fg_i] fg_sources = fg_seeds_ss.sources.iloc[fg_i] - fg_ecal = fg_seeds_ss.ecal[fg_i] for t_i in range(self.events_per_seed): fwhm = fwhm_targets[t_i] snr = snr_targets[t_i] dwell_time = dwell_time_targets[t_i] - pb_args = (fwhm, snr, dwell_time, fg_pmf, bg_pmf, - fg_ecal, fg_sources, bg_sources) + pb_args = (fg_i, fwhm, snr, dwell_time, fg_pmf, bg_pmf, + fg_sources, bg_sources) args.append(pb_args) # TODO: follow prevents periodic progress reports - passbys = [self._generate_single_passby(*a) for a in args] + passbys = [] + for a in args: + f, fwhm, snr, dwell_time, fg_pmf, bg_pmf, fg_sources, bg_sources = a + fg_passby_ss, gross_passby_ss = self._generate_single_passby( + fwhm, snr, dwell_time, fg_pmf, bg_pmf, fg_sources, bg_sources + ) + live_times = None + if fg_passby_ss is not None: + live_times = fg_passby_ss.info.live_time + elif gross_passby_ss is not None: + live_times = gross_passby_ss.info.live_time + else: + live_times = 1.0 + + fg_seed_ecal = fg_seeds_ss.ecal[f] + fg_seed_info = fg_seeds_ss.info.iloc[f] + batch_rt_targets = live_times * (1 - fg_seed_info.dead_time_prop) + fg_seed_distance_cm = fg_seed_info.distance_cm + fg_seed_dead_time_prop = fg_seed_info.dead_time_prop + fg_seed_ad = fg_seed_info.areal_density + fg_seed_an = fg_seed_info.atomic_number + fg_seed_neutron_counts = fg_seed_info.neutron_counts + + def _set_remaining_info(ss): + if ss is None: + return + ss: SampleSet = ss + ss.ecal = fg_seed_ecal + ss.info.real_time = batch_rt_targets + ss.info.distance_cm = fg_seed_distance_cm + ss.info.dead_time_prop = fg_seed_dead_time_prop + ss.info.areal_density = fg_seed_ad + ss.info.atomic_number = fg_seed_an + ss.info.neutron_counts = fg_seed_neutron_counts + ss.info.timestamp = self._synthesis_start_dt + + _set_remaining_info(fg_passby_ss) + _set_remaining_info(gross_passby_ss) + + passbys.append((fg_passby_ss, gross_passby_ss)) if verbose: delay = time() - tstart diff --git a/riid/data/synthetic/static.py b/riid/data/synthetic/static.py index cdf0aed0..df06b0ba 100644 --- a/riid/data/synthetic/static.py +++ b/riid/data/synthetic/static.py @@ -180,20 +180,41 @@ def _get_synthetic_samples(self, fg_seeds_ss: SampleSet, bg_seeds_ss: SampleSet, bg_sources = bg_seeds_ss.sources.iloc[b] fg_seed = fg_seeds_ss.spectra.iloc[f] fg_sources = fg_seeds_ss.sources.iloc[f] - fg_seed_rt = fg_seeds_ss.info.real_time.iloc[f] - fg_seed_lt = fg_seeds_ss.info.live_time.iloc[f] batch_lt_targets = lt_targets[batch_begin_idx:batch_end_idx] - batch_rt_targets = lt_targets[batch_begin_idx:batch_end_idx] * (fg_seed_rt / fg_seed_lt) batch_snr_targets = snr_targets[batch_begin_idx:batch_end_idx] - distance_cm = fg_seeds_ss.info.distance_cm.iloc[f] - ecal = fg_seeds_ss.ecal[f] fg_batch_ss, gross_batch_ss = self._get_batch( fg_seed, fg_sources, bg_seed, bg_sources, - ecal, batch_lt_targets, batch_snr_targets, batch_rt_targets, - distance_cm + batch_lt_targets, + batch_snr_targets ) + + fg_seed_ecal = fg_seeds_ss.ecal[f] + fg_seed_info = fg_seeds_ss.info.iloc[f] + batch_rt_targets = batch_lt_targets * (1 - fg_seed_info.dead_time_prop) + fg_seed_distance_cm = fg_seed_info.distance_cm + fg_seed_dead_time_prop = fg_seed_info.dead_time_prop + fg_seed_ad = 
fg_seed_info.areal_density + fg_seed_an = fg_seed_info.atomic_number + fg_seed_neutron_counts = fg_seed_info.neutron_counts + + def _set_remaining_info(ss): + if ss is None: + return + ss: SampleSet = ss + ss.ecal = fg_seed_ecal + ss.info.real_time = batch_rt_targets + ss.info.distance_cm = fg_seed_distance_cm + ss.info.dead_time_prop = fg_seed_dead_time_prop + ss.info.areal_density = fg_seed_ad + ss.info.atomic_number = fg_seed_an + ss.info.neutron_counts = fg_seed_neutron_counts + ss.info.timestamp = self._synthesis_start_dt + + _set_remaining_info(fg_batch_ss) + _set_remaining_info(gross_batch_ss) + fg_ss_batches.append(fg_batch_ss) gross_ss_batches.append(gross_batch_ss) diff --git a/tests/staticsynth_tests.py b/tests/staticsynth_tests.py index afbd297a..8ec6ddf8 100644 --- a/tests/staticsynth_tests.py +++ b/tests/staticsynth_tests.py @@ -217,20 +217,16 @@ def test_get_batch(self): [0.3, 0.4, 0.3], index=["X", "Y", "Z"] ) - ecal = (0, 3000, 100, 0, 0) lts = np.array([4.2]).astype(float) snrs = np.array([63.2]).astype(float) - distance_cm = 50 fg_ss, gross_ss = synth._get_batch( fg_seed=fg_seed, fg_sources=fg_sources, bg_seed=bg_seed, bg_sources=bg_sources, - ecal=ecal, lt_targets=lts, snr_targets=snrs, - distance_cm=distance_cm ) self.assertTrue(np.allclose( @@ -242,7 +238,6 @@ def test_get_batch(self): gross_ss.sources.loc[:, bg_sources.index].sum(axis=1) / synth.bg_cps, lts, )) - self.assertTrue(np.allclose( gross_ss.sources.loc[:, fg_sources.index], fg_ss.sources.loc[:, fg_sources.index],