From 47818b8240074222ce771cce4eb2af097db3d5ba Mon Sep 17 00:00:00 2001
From: jabader97 <51213988+jabader97@users.noreply.github.com>
Date: Tue, 10 May 2022 16:33:01 +0200
Subject: [PATCH 01/29] [ADD] Base class for `NLLLoss` & implement `MSELoss` (#252)

The base class automates the Hessian square root approximation via MC-sampling.
It requires specifying the likelihood distribution. The Hessian square root is
then approximated by computing gradients with targets drawn from the likelihood
with `autograd`.

---

* Updated cross entropy and MSE to use new NLL base
* Changed _post_process and _checks to raise NotImplementedError
* Refactored NLL base for general log-prob derivative from torch.distributions, as well as overwriting for MSE and CE
* Fixed some spacing errors
* Updated doc strings
* Added test for NLL version of compute_sampled_grads
* Fixed spacing issue introduced in last commit
* Implemented (some of) commented fixes
* Added NotImplementedError to _verify_support to fix coveralls problem
* Fixed some commenting, removed mean from mse make distribution, moved sampling into loop for nll base
* [REF] Use distribution for sampled gradients in manual approach
* Changed use_dist to use_autograd, fixed return statement error in MSE, fixed device error in MSE
* Black
* Reverted changes to CEL
* Reverted cross entropy changes
* Removed unneeded changes to clean the diff
* Some docstring updates
* A few missed changes for diff
* [REF] Move `use_autograd` inside `NLLLossDerivatives`
* [REF] Change default of `use_autograd` to `False` for `MSELoss`
* [FMT] Remove space
* [DEL] Remove `use_autograd` from `CrossEntropyLoss`
* [DOC] Clarify `use_autograd` in test function
* [FIX] Syntax error
* [CI] Add NLLLossDerivatives to fully documented
* Added missing type annotations to nll_base.py, removed redundant autograd_res call in test_sqrt_hessian_sampled_squared_approximates_hessian_nll
* Removed unnecessary ABC in MSE loss, fixed _compute_sampled_grad_manual name
* Fixed documentation to match standards and reflect current version
* Darglint fixes
* Pydocstyle fix
* Darglint formatting fix
* Removed retain_graph=True
* Created MSE_LOSS_PROBLEMS for test_sqrt_hessian_sampled_squared_approximates_hessian_nll to run on
* Added autograd test to check that sample has same shape as subsampled_input
* Reformatted some too-long lines
* [REF] Remove `enable_grad` and `Variable`
* [REF] Shorten import
* [REF] Rewrite NLL test with recursion
* [FIX] Remove unused import
* [FIX] darglint
* [DOC] Polish MSELoss
* [DOC] Polish NLL base
* [DOC] Polish derivatives test
* [FIX] Type annotation
* [FIX] Darglint
* [DOC] Polish NLLbase
* [DOC] One more pass through docstrings

Co-authored-by: Felix Dangel
Co-authored-by: Felix Dangel
---
 backpack/core/derivatives/mseloss.py       | 109 ++++++----
 backpack/core/derivatives/nll_base.py      | 188 ++++++++++++++++++
 backpack/utils/module_classification.py    |  13 ++
 fully_documented.txt                       |   1 +
 test/core/derivatives/derivatives_test.py  |  61 ++++++
 .../derivatives/implementation/backpack.py  |  12 +-
 test/core/derivatives/problem.py           |   5 +-
 7 files changed, 352 insertions(+), 37 deletions(-)
 create mode 100644 backpack/core/derivatives/nll_base.py

diff --git a/backpack/core/derivatives/mseloss.py b/backpack/core/derivatives/mseloss.py
index f09750052..b5c4caa97 100644
--- a/backpack/core/derivatives/mseloss.py
+++ b/backpack/core/derivatives/mseloss.py
@@ -3,21 +3,34 @@
 from math import sqrt
 from typing import List, Tuple
 
-from torch import Tensor, eye, normal, ones
+from torch import 
Size, Tensor, eye, ones, tensor +from torch.distributions import Normal from torch.nn import MSELoss -from backpack.core.derivatives.basederivatives import BaseLossDerivatives +from backpack.core.derivatives.nll_base import NLLLossDerivatives -class MSELossDerivatives(BaseLossDerivatives): - """Derivatives of the MSE Loss. +class MSELossDerivatives(NLLLossDerivatives): + """Derivatives of ``MSELoss``. We only support 2D tensors. For `X : [n, d]` and `Y : [n, d]`, if `reduce=sum`, the MSE computes `∑ᵢ₌₁ⁿ ‖X[i,∶] − Y[i,∶]‖²`. If `reduce=mean`, the result is divided by `nd`. + + ``MSELoss`` is a negative log-likelihood of a Gaussian with mean corresponding + to the module input and constant standard deviation √0.5. """ + def __init__(self, use_autograd: bool = False): + """Initialization for MSE loss derivative. + + Args: + use_autograd: Compute gradients with autograd (rather than manual) + Defaults to ``False`` (manual computation). + """ + super().__init__(use_autograd=use_autograd) + def _sqrt_hessian( self, module: MSELoss, @@ -25,7 +38,7 @@ def _sqrt_hessian( g_out: Tuple[Tensor], subsampling: List[int] = None, ) -> Tensor: # noqa: D102 - self.check_input_dims(module) + self._check_input_dims(module) input0: Tensor = module.input0 N, D = input0.shape @@ -40,33 +53,6 @@ def _sqrt_hessian( return sqrt_H - def _sqrt_hessian_sampled( - self, - module: MSELoss, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - mc_samples: int = 1, - subsampling: List[int] = None, - ) -> Tensor: - self.check_input_dims(module) - - input0: Tensor = module.input0 - N, D = input0.shape - N_active = N if subsampling is None else len(subsampling) - samples = normal( - 0, - 1, - size=[mc_samples, N_active, D], - device=input0.device, - dtype=input0.dtype, - ) - samples *= sqrt(2) / sqrt(mc_samples) - - if module.reduction == "mean": - samples /= sqrt(input0.numel()) - - return samples - def _sum_hessian(self, module, g_inp, g_out): """The Hessian, summed across the batch dimension. @@ -78,7 +64,7 @@ def _sum_hessian(self, module, g_inp, g_out): Returns: a `[D, D]` tensor of the Hessian, summed across batch """ - self.check_input_dims(module) + self._check_input_dims(module) N, D = module.input0.shape H = 2 * eye(D, device=module.input0.device) @@ -101,10 +87,63 @@ def hessian_mat_prod(mat): return hessian_mat_prod - def check_input_dims(self, module): + def _verify_support(self, module: MSELoss): + """We only support 2D tensors.""" + self._check_input_dims(module) + + def _make_distribution(self, subsampled_input: Tensor) -> Normal: + """Create the likelihood distribution whose NLL is the MSE. + + The log probability of the Gaussian distribution is proportional to + ¹/₍₂𝜎²₎∑ᵢ₌₁ⁿ (xᵢ−𝜇)². Because MSE = ∑ᵢ₌₁ⁿ(Yᵢ−Ŷᵢ)², this is + equivalent for samples drawn from a Gaussian distribution with + mean of the subsampled input and standard deviation √0.5. + + Args: + subsampled_input: input after subsampling + + Returns: + Normal distribution for targets | inputs + """ + return Normal( + subsampled_input, tensor(sqrt(0.5), device=subsampled_input.device) + ) + + def _check_input_dims(self, module: MSELoss): """Raises an exception if the shapes of the input are not supported.""" if not len(module.input0.shape) == 2: raise ValueError("Only 2D inputs are currently supported for MSELoss.") - def hessian_is_psd(self): + def hessian_is_psd(self) -> bool: + """Return whether cross-entropy loss Hessian is positive semi-definite. 
+ + Returns: + True + """ return True + + @staticmethod + def _get_mean_normalization(input: Tensor) -> int: + return input.numel() + + def _compute_sampled_grads_manual( + self, subsampled_input: Tensor, mc_samples: int + ) -> Tensor: + """Manually compute gradients from sampled targets. + + Because MSE = ∑ᵢ₌₁ⁿ(Yᵢ−Ŷᵢ)², the gradient is 2∑ᵢ₋₁ⁿ(Yᵢ−Ŷᵢ). + + Args: + subsampled_input: input after subsampling + mc_samples: number of samples + + Returns: + Gradient samples + """ + dist = self._make_distribution(subsampled_input) + samples = dist.sample(sample_shape=Size([mc_samples])) + subsampled_input_expanded = subsampled_input.unsqueeze(0).expand( + mc_samples, -1, -1 + ) + + return 2 * (samples - subsampled_input_expanded) diff --git a/backpack/core/derivatives/nll_base.py b/backpack/core/derivatives/nll_base.py new file mode 100644 index 000000000..1f8b14795 --- /dev/null +++ b/backpack/core/derivatives/nll_base.py @@ -0,0 +1,188 @@ +"""Partial derivative bases for NLL losses.""" +from math import sqrt +from typing import List, Tuple + +from torch import Tensor, stack +from torch.autograd import grad +from torch.distributions import Distribution +from torch.nn import Module + +from backpack.core.derivatives.basederivatives import BaseLossDerivatives +from backpack.utils.subsampling import subsample + + +class NLLLossDerivatives(BaseLossDerivatives): + """Base class for partial derivatives of negative log-likelihood losses. + + These loss functions can be expressed as a negative log-likelihood (NLL) + of targets given the input, 𝑙(fₙ,yₙ)= −log p(yₙ | fₙ) with a likelihood + distribution p(· | f). + """ + + def __init__(self, use_autograd: bool = True): + """Initialization. + + Args: + use_autograd: Compute gradient samples with autograd (rather than manually). + Default: ``True``. This argument is used to test the non-default + computation. + """ + self.use_autograd = use_autograd + + def _sqrt_hessian_sampled( + self, + module: Module, + g_inp: Tuple[Tensor], + g_out: Tuple[Tensor], + mc_samples: int = 1, + subsampling: List[int] = None, + ) -> Tensor: + """Approximate the Hessian square root through Monte-Carlo sampling. + + If use_autograd is True, _make_distribution must be implemented. + Otherwise, _compute_sampled_grads_manual must be implemented. + + In mean reduction mode, _get_mean_normalization must be implemented. + + Args: + module: loss module. + g_inp: Gradient of loss w.r.t. input + g_out: Gradient of loss w.r.t. output + mc_samples: number of Monte Carlo samples to take + subsampling: Indices of samples that are sliced along the dimension + + Returns: + Approximate Hessian square root. Has shape [mc_samples, + subsampled_input.shape]. + """ + self._verify_support(module) + subsampled_input = subsample(module.input0, subsampling=subsampling) + sqrt_hessian = self.compute_sampled_grads(subsampled_input, mc_samples) / sqrt( + mc_samples + ) + if module.reduction == "mean": + sqrt_hessian /= sqrt(self._get_mean_normalization(module.input0)) + return sqrt_hessian + + def _verify_support(self, module: Module): + """Verify that the module hyperparameters are supported. + + Args: + module: loss module + + Raises: + NotImplementedError: If the module has unsupported hyperparameters. + """ + raise NotImplementedError + + def compute_sampled_grads( + self, subsampled_input: Tensor, mc_samples: int + ) -> Tensor: + """Compute gradients with targets drawn from the likelihood p(· | f). + + If use_autograd is True, use _compute_sampled_grads_autograd. 
+ Otherwise, use _compute_sampled_grads_manual. + + Args: + subsampled_input: input after subsampling + mc_samples: number of gradient samples + + Returns: + Sampled gradients of shape [mc_samples, *subsampled_input.shape] + """ + grad_func = ( + self._compute_sampled_grads_autograd + if self.use_autograd + else self._compute_sampled_grads_manual + ) + return grad_func(subsampled_input, mc_samples) + + def _compute_sampled_grads_autograd( + self, subsampled_input: Tensor, mc_samples: int + ) -> Tensor: + """Compute gradients for samples of the likelihood distribution with autograd. + + _make_distribution must be implemented for this function to work. + + Args: + subsampled_input: input after subsampling + mc_samples: number of samples + + Returns: + Sampled gradients of shape [mc_samples, *subsampled_input.shape] + """ + subsampled_input.requires_grad = True + gradients = [] + + dist = self._make_distribution(subsampled_input) + self._check_distribution_shape(dist, subsampled_input) + + for _ in range(mc_samples): + y_tilde = dist.sample() + loss_tilde = -dist.log_prob(y_tilde).sum() + gradients.append(grad(loss_tilde, subsampled_input)[0]) + + return stack(gradients) + + def _compute_sampled_grads_manual( + self, subsampled_input: Tensor, mc_samples: int + ) -> Tensor: + """Compute gradients for samples of the likelihood distribution manually. + + This function can be used instead of _compute_sampled_grads_autograd if + the gradient is known analytically. + + Args: + subsampled_input: input after subsampling + mc_samples: number of samples + + Raises: + NotImplementedError: if manual sampled gradients not implemented + """ + raise NotImplementedError("Manual sampled gradients not implemented.") + + def _make_distribution(self, subsampled_input: Tensor) -> Distribution: + """Create the likelihood distribution p(· | f). + + This should be in the form of a torch.Distributions object for p, such that + the desired loss 𝑙(f, y) α ∑ₙ − log p(yₙ | fₙ). + + Otherwise, the returned object must offer functions to draw samples and to + evaluate the log-probability. + + Args: + subsampled_input: input after subsampling + + Raises: + NotImplementedError: If not implemented. + """ + raise NotImplementedError + + @staticmethod + def _get_mean_normalization(input: Tensor) -> int: + """Return the normalization factor in mean mode. + + The number C in loss = 1 / C * ∑ᵢ lossᵢ. + + Args: + input: input to the layer + + Raises: + NotImplementedError: If not implemented + """ + raise NotImplementedError + + @staticmethod + def _check_distribution_shape(dist: Distribution, subsampled_input: Tensor): + """Verify shape of sampled targets y ∼ p(· | f). + + Args: + dist: Distribution of the likelihood p(y | f), e.g. created by + _make_distribution. + subsampled_input: Input after subsampling. + + Raises: + ValueError: If the target samples have incorrect shape. 
+ """ + if dist.sample().shape != subsampled_input.shape: + raise ValueError("Sample does not have same shape as subsampled_input.") diff --git a/backpack/utils/module_classification.py b/backpack/utils/module_classification.py index b8d9b5b5f..ff3470990 100644 --- a/backpack/utils/module_classification.py +++ b/backpack/utils/module_classification.py @@ -3,6 +3,7 @@ from torch.nn import Module, Sequential from torch.nn.modules.loss import _Loss +from backpack.core.derivatives.mseloss import MSELoss from backpack.custom_module.branching import Parallel, _Branch from backpack.custom_module.reduce_tuple import ReduceTuple @@ -19,6 +20,18 @@ def is_loss(module: Module) -> bool: return isinstance(module, _Loss) +def is_mse(module: Module) -> bool: + """Return whether 'module' is a MSELoss function. + + Args: + module: A PyTorch module. + + Returns: + Whether 'module' is an MSE loss function + """ + return isinstance(module, MSELoss) + + def is_no_op(module: Module) -> bool: """Return whether the module does no operation in graph. diff --git a/fully_documented.txt b/fully_documented.txt index 275f0b92d..e56a28e68 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -5,6 +5,7 @@ backpack/context.py backpack/custom_module/ backpack/core/derivatives/basederivatives.py +backpack/core/derivatives/nll_base.py backpack/core/derivatives/rnn.py backpack/core/derivatives/shape_check.py backpack/core/derivatives/__init__.py diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 5eae40c54..8c2c0c2d6 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -73,6 +73,9 @@ problem.make_id() for problem in CUSTOM_SLICING_MODULE_PROBLEMS ] +MSE_PROBLEMS = [problem for problem in PROBLEMS if problem.is_mse()] +MSE_IDS = [problem.make_id() for problem in MSE_PROBLEMS] + SUBSAMPLINGS = [None, [0, 0], [2, 0]] SUBSAMPLING_IDS = [f"subsampling={s}".replace(" ", "") for s in SUBSAMPLINGS] @@ -344,6 +347,64 @@ def test_sqrt_hessian_sampled_squared_approximates_hessian( problem.tear_down() +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +@mark.parametrize("problem", MSE_PROBLEMS, ids=MSE_IDS) +def test_sqrt_hessian_sampled_squared_approximates_hessian_nll( + problem: DerivativesTestProblem, + subsampling: Union[List[int], None], + mc_samples: int = 50000, + chunks: int = 10, + rerun_on_crash: bool = True, +) -> None: + """Test the MC-sampled sqrt decomposition of the input Hessian for NLL loss base. + + Compares the Hessian to reconstruction from individual Hessian MC-sampled + sqrt. This test runs specifically on the autograd version of + compute_sampled_grads, rather than manual versions which are used by default + and tested elsewhere. + + Args: + problem: Test case. + subsampling: Indices of active samples. + mc_samples: number of samples. Defaults to 50000. + chunks: Number of passes the MC samples will be processed sequentially. + rerun_on_crash: Run the test again with more samples, then crash if it + still fails. Default: ``True``. + + Raises: + AssertionError: If the MC-sampled Hessian square root does not square to the + exact Hessian. 
+ """ + problem.set_up() + skip_subsampling_conflict(problem, subsampling) + RTOL, ATOL = 1e-2, 8e-3 + + autograd_res = AutogradDerivatives(problem).input_hessian(subsampling=subsampling) + + try: + backpack_res = BackpackDerivatives(problem).input_hessian_via_sqrt_hessian( + mc_samples=mc_samples, + chunks=chunks, + subsampling=subsampling, + use_autograd=True, + ) + problem.tear_down() + check_sizes_and_values(autograd_res, backpack_res, rtol=RTOL, atol=ATOL) + + except AssertionError as e: + if rerun_on_crash: + more = 10 + test_sqrt_hessian_sampled_squared_approximates_hessian_nll( + problem, + subsampling, + mc_samples=mc_samples * more, + chunks=chunks * more, + rerun_on_crash=False, + ) + else: + raise e + + @mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) def test_sqrt_hessian_sampled_should_fail( diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index 092d368c1..f6e1a1693 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -5,6 +5,7 @@ from torch import Tensor, einsum, zeros +from backpack.core.derivatives.nll_base import NLLLossDerivatives from backpack.utils.subsampling import subsample @@ -82,7 +83,11 @@ def sum_hessian(self): # noqa: D102 return self.problem.derivative.sum_hessian(self.problem.module, None, None) def input_hessian_via_sqrt_hessian( - self, mc_samples: int = None, chunks: int = 1, subsampling: List[int] = None + self, + mc_samples: int = None, + chunks: int = 1, + subsampling: List[int] = None, + use_autograd: bool = False, ) -> Tensor: """Computes the Hessian w.r.t. to the input from its matrix square root. @@ -92,6 +97,8 @@ def input_hessian_via_sqrt_hessian( chunks: Maximum sequential split of the computation. Default: ``1``. Only used if mc_samples is specified. subsampling: Indices of active samples. ``None`` uses all samples. + use_autograd: Compute sampled gradients with ``autograd``. Only relevant + for ``NLLLossDerivatives``. Default: ``False``. Returns: Hessian with respect to the input. 
Has shape @@ -105,6 +112,9 @@ def input_hessian_via_sqrt_hessian( chunk_samples = chunk_sizes(mc_samples, chunks) chunk_weights = [samples / mc_samples for samples in chunk_samples] + if isinstance(self.problem.derivative, NLLLossDerivatives): + self.problem.derivative.use_autograd = use_autograd + individual_hessians: Tensor = sum( weight * self._sample_hessians_from_sqrt( diff --git a/test/core/derivatives/problem.py b/test/core/derivatives/problem.py index af3524159..4d98ed533 100644 --- a/test/core/derivatives/problem.py +++ b/test/core/derivatives/problem.py @@ -8,7 +8,7 @@ from torch import Tensor, long from backpack import extend -from backpack.utils.module_classification import is_loss +from backpack.utils.module_classification import is_loss, is_mse from backpack.utils.subsampling import subsample @@ -141,6 +141,9 @@ def make_output_shape(self): def is_loss(self): return is_loss(self.make_module()) + def is_mse(self): + return is_mse(self.make_module()) + def forward_pass( self, input_requires_grad: bool = False, subsampling: List[int] = None ) -> Tuple[Tensor, Tensor, Dict[str, Tensor]]: From 0a488d7cd6658a7a778757f5ce3dafb8746a74cd Mon Sep 17 00:00:00 2001 From: jabader97 <51213988+jabader97@users.noreply.github.com> Date: Fri, 3 Jun 2022 14:14:49 +0200 Subject: [PATCH 02/29] [REF] Implement `CrossEntropyLoss` as `NLLLoss` (#256) - Replace shape check of samples in main library with test - Add `retain_grad=True` for autograd computation of sampled gradients (for MSELoss, it worked without `retain_graph`) --- * [REF] Changed cross entropy loss to NLL base * [REF] Removed arrange and rearrange, made CE work for autograd * [REF] Changed compute_grad_manual for CE to use _make_distribution * [REF] some cleaning * [REF] Moved nll distribution shape check * [FIX] darglint, isort * [FIX] removed some unused import statements * [REF] Remove redundant import, improve names * [REF] Improve readability by linebreaks * [REF] Import loss modules from `torch.nn` * [TEST] Apply sub-sampling to input and target for shape check * [FIX] Add tear_down call * [DEL] Remove clone+detach Co-authored-by: Felix Dangel --- backpack/core/derivatives/crossentropyloss.py | 100 ++++++++++++------ backpack/core/derivatives/nll_base.py | 18 +--- backpack/utils/module_classification.py | 13 +-- test/core/derivatives/derivatives_test.py | 36 ++++++- test/core/derivatives/problem.py | 6 +- 5 files changed, 109 insertions(+), 64 deletions(-) diff --git a/backpack/core/derivatives/crossentropyloss.py b/backpack/core/derivatives/crossentropyloss.py index 68690df9a..5c8bc2768 100644 --- a/backpack/core/derivatives/crossentropyloss.py +++ b/backpack/core/derivatives/crossentropyloss.py @@ -3,21 +3,31 @@ from typing import Callable, Dict, List, Tuple from einops import rearrange -from torch import Tensor, diag, diag_embed, einsum, eye, multinomial, ones_like, softmax +from torch import Size, Tensor, diag, diag_embed, einsum, eye, ones_like, softmax +from torch.distributions import Categorical from torch.nn import CrossEntropyLoss from torch.nn.functional import one_hot -from backpack.core.derivatives.basederivatives import BaseLossDerivatives +from backpack.core.derivatives.nll_base import NLLLossDerivatives from backpack.utils.subsampling import subsample -class CrossEntropyLossDerivatives(BaseLossDerivatives): +class CrossEntropyLossDerivatives(NLLLossDerivatives): """Partial derivatives for cross-entropy loss. The `torch.nn.CrossEntropyLoss` operation is a composition of softmax and negative log-likelihood. 
""" + def __init__(self, use_autograd: bool = False): + """Initialization for CE loss derivative. + + Args: + use_autograd: Compute gradients with autograd (rather than manual) + Defaults to ``False`` (manual computation). + """ + super().__init__(use_autograd=use_autograd) + def _sqrt_hessian( self, module: CrossEntropyLoss, @@ -43,37 +53,6 @@ def _sqrt_hessian( sqrt_H = self._expand_sqrt_h(sqrt_H) return sqrt_H - def _sqrt_hessian_sampled( - self, - module: CrossEntropyLoss, - g_inp: Tuple[Tensor], - g_out: Tuple[Tensor], - mc_samples: int = 1, - subsampling: List[int] = None, - ) -> Tensor: - self._check_2nd_order_parameters(module) - - M = mc_samples - C = module.input0.shape[1] - - probs = self._get_probs(module, subsampling=subsampling) - probs, *rearrange_info = self._merge_batch_and_additional(probs) - - V_dim = 0 - probs_unsqueezed = probs.unsqueeze(V_dim).repeat(M, 1, 1) - - multi = multinomial(probs, M, replacement=True) - classes = one_hot(multi, num_classes=C) - classes = einsum("nvc->vnc", classes).float() - - sqrt_mc_h = (probs_unsqueezed - classes) / sqrt(M) - - if module.reduction == "mean": - sqrt_mc_h /= sqrt(self._get_mean_normalization(module.input0)) - - sqrt_mc_h = self._ungroup_batch_and_additional(sqrt_mc_h, *rearrange_info) - return sqrt_mc_h - def _sum_hessian( self, module: CrossEntropyLoss, g_inp: Tuple[Tensor], g_out: Tuple[Tensor] ) -> Tensor: @@ -265,3 +244,56 @@ def _get_mean_normalization(input: Tensor) -> int: Divisor for mean reduction. """ return input.numel() // input.shape[1] + + def _verify_support(self, module: CrossEntropyLoss): + """We only support default weight and ignore_index. + + Args: + module: CrossEntropyLoss module + """ + self._check_2nd_order_parameters(module) + + def _make_distribution(self, subsampled_input: Tensor) -> Categorical: + """Create the likelihood distribution whose NLL is the CE. + + The log probability of the Categorical distribution for a single sample + with k classes is ∑ᵢ₌₁ᵏ Ŷᵢ log pᵢ, where Ŷ is one-hot encoded. If p is + chosen as the softmax, this is equivalent to CrossEntropyLoss + + Args: + subsampled_input: input after subsampling + + Returns: + Normal distribution for targets | inputs + """ + probs = softmax(subsampled_input, dim=1) + probs_rearranged = einsum("nc...->n...c", probs) + return Categorical(probs_rearranged) + + def _compute_sampled_grads_manual( + self, subsampled_input: Tensor, mc_samples: int + ) -> Tensor: + """Manually compute gradients from sampled targets. + + Cross Entropy loss is ∑ᵢ₌₁ᵏ Ŷᵢ log 𝜎(xᵢ), where 𝜎(xᵢ) is the softmax of + the input and Ŷᵢ is one-hot encoded. The gradient is 𝜎(xᵢ) - Ŷᵢ. 
+ + Args: + subsampled_input: input after subsampling + mc_samples: number of samples + + Returns: + Gradient samples + """ + probs = softmax(subsampled_input, dim=1) + expand_dims = [mc_samples] + probs.dim() * [-1] + probs_unsqeezed = probs.unsqueeze(0).expand(*expand_dims) # [V N C D1 D2] + + distribution = self._make_distribution(subsampled_input) + samples = distribution.sample(Size([mc_samples])) # [V N D1 D2] + samples_onehot = one_hot(samples, num_classes=probs.shape[1]) # [V N D1 D2 C] + samples_onehot_rearranged = einsum("vn...c->vnc...", samples_onehot).to( + probs.dtype + ) # [V N C D1 D2] + + return probs_unsqeezed - samples_onehot_rearranged diff --git a/backpack/core/derivatives/nll_base.py b/backpack/core/derivatives/nll_base.py index 1f8b14795..f2c5dd607 100644 --- a/backpack/core/derivatives/nll_base.py +++ b/backpack/core/derivatives/nll_base.py @@ -115,12 +115,11 @@ def _compute_sampled_grads_autograd( gradients = [] dist = self._make_distribution(subsampled_input) - self._check_distribution_shape(dist, subsampled_input) for _ in range(mc_samples): y_tilde = dist.sample() loss_tilde = -dist.log_prob(y_tilde).sum() - gradients.append(grad(loss_tilde, subsampled_input)[0]) + gradients.append(grad(loss_tilde, subsampled_input, retain_graph=True)[0]) return stack(gradients) @@ -171,18 +170,3 @@ def _get_mean_normalization(input: Tensor) -> int: NotImplementedError: If not implemented """ raise NotImplementedError - - @staticmethod - def _check_distribution_shape(dist: Distribution, subsampled_input: Tensor): - """Verify shape of sampled targets y ∼ p(· | f). - - Args: - dist: Distribution of the likelihood p(y | f), e.g. created by - _make_distribution. - subsampled_input: Input after subsampling. - - Raises: - ValueError: If the target samples have incorrect shape. - """ - if dist.sample().shape != subsampled_input.shape: - raise ValueError("Sample does not have same shape as subsampled_input.") diff --git a/backpack/utils/module_classification.py b/backpack/utils/module_classification.py index ff3470990..e3321c9ff 100644 --- a/backpack/utils/module_classification.py +++ b/backpack/utils/module_classification.py @@ -1,9 +1,8 @@ """Contains util function for classification of modules.""" from torch.fx import GraphModule -from torch.nn import Module, Sequential +from torch.nn import CrossEntropyLoss, Module, MSELoss, Sequential from torch.nn.modules.loss import _Loss -from backpack.core.derivatives.mseloss import MSELoss from backpack.custom_module.branching import Parallel, _Branch from backpack.custom_module.reduce_tuple import ReduceTuple @@ -20,16 +19,18 @@ def is_loss(module: Module) -> bool: return isinstance(module, _Loss) -def is_mse(module: Module) -> bool: - """Return whether 'module' is a MSELoss function. +def is_nll(module: Module) -> bool: + """Return whether 'module' is an NLL loss function. + + Current NLL loss functions include MSE and CE. Args: module: A PyTorch module. 
Returns: - Whether 'module' is an MSE loss function + Whether 'module' is an NLL loss function """ - return isinstance(module, MSELoss) + return isinstance(module, (MSELoss, CrossEntropyLoss)) def is_no_op(module: Module) -> bool: diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index 8c2c0c2d6..fb39bc96c 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -7,7 +7,7 @@ - Transposed Jacobian-matrix products with respect to layer parameters """ from contextlib import nullcontext -from test.automated_test import check_sizes_and_values +from test.automated_test import check_sizes, check_sizes_and_values from test.core.derivatives.batch_norm_settings import BATCH_NORM_SETTINGS from test.core.derivatives.embedding_settings import EMBEDDING_SETTINGS from test.core.derivatives.implementation.autograd import AutogradDerivatives @@ -33,6 +33,7 @@ from torch import Tensor, rand from backpack.core.derivatives.convnd import weight_jac_t_save_memory +from backpack.utils.subsampling import subsample PROBLEMS = make_test_problems(SETTINGS) IDS = [problem.make_id() for problem in PROBLEMS] @@ -73,8 +74,8 @@ problem.make_id() for problem in CUSTOM_SLICING_MODULE_PROBLEMS ] -MSE_PROBLEMS = [problem for problem in PROBLEMS if problem.is_mse()] -MSE_IDS = [problem.make_id() for problem in MSE_PROBLEMS] +NLL_PROBLEMS = [problem for problem in PROBLEMS if problem.is_nll()] +NLL_IDS = [problem.make_id() for problem in NLL_PROBLEMS] SUBSAMPLINGS = [None, [0, 0], [2, 0]] SUBSAMPLING_IDS = [f"subsampling={s}".replace(" ", "") for s in SUBSAMPLINGS] @@ -348,7 +349,7 @@ def test_sqrt_hessian_sampled_squared_approximates_hessian( @mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) -@mark.parametrize("problem", MSE_PROBLEMS, ids=MSE_IDS) +@mark.parametrize("problem", NLL_PROBLEMS, ids=NLL_IDS) def test_sqrt_hessian_sampled_squared_approximates_hessian_nll( problem: DerivativesTestProblem, subsampling: Union[List[int], None], @@ -405,6 +406,33 @@ def test_sqrt_hessian_sampled_squared_approximates_hessian_nll( raise e +@mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) +@mark.parametrize("problem", NLL_PROBLEMS, ids=NLL_IDS) +def test_dist_sample_shape_nll( + problem: DerivativesTestProblem, + subsampling: Union[List[int], None], +) -> None: + """Test distribution sample shape for NLL derivatives. + + Compares the shape sampled from the distribution to the output to + verify the shapes match. + + Args: + problem: Test case. + subsampling: Indices of active samples. 
+ """ + problem.set_up() + skip_subsampling_conflict(problem, subsampling) + BackpackDerivatives(problem).store_forward_io() + + subsampled_input = subsample(problem.module.input0, subsampling=subsampling) + subsampled_target = subsample(problem.module.input1, subsampling=subsampling) + samples = problem.derivative._make_distribution(subsampled_input).sample() + + check_sizes(samples, subsampled_target) + problem.tear_down() + + @mark.parametrize("subsampling", SUBSAMPLINGS, ids=SUBSAMPLING_IDS) @mark.parametrize("problem", LOSS_FAIL_PROBLEMS, ids=LOSS_FAIL_IDS) def test_sqrt_hessian_sampled_should_fail( diff --git a/test/core/derivatives/problem.py b/test/core/derivatives/problem.py index 4d98ed533..cd46704bc 100644 --- a/test/core/derivatives/problem.py +++ b/test/core/derivatives/problem.py @@ -8,7 +8,7 @@ from torch import Tensor, long from backpack import extend -from backpack.utils.module_classification import is_loss, is_mse +from backpack.utils.module_classification import is_loss, is_nll from backpack.utils.subsampling import subsample @@ -141,8 +141,8 @@ def make_output_shape(self): def is_loss(self): return is_loss(self.make_module()) - def is_mse(self): - return is_mse(self.make_module()) + def is_nll(self): + return is_nll(self.make_module()) def forward_pass( self, input_requires_grad: bool = False, subsampling: List[int] = None From cd864721c5ecb10f0a774397154eadadddbc0d81 Mon Sep 17 00:00:00 2001 From: jabader97 <51213988+jabader97@users.noreply.github.com> Date: Fri, 3 Jun 2022 15:22:54 +0200 Subject: [PATCH 03/29] [ADD] Implement sampled gradients for `BCEWithLogitsLoss` (#257) * [ADD] Added BCEWithLogits loss to NLL base * [TEST] Skip BCEWithLogitsLoss _sqrt_hessian and _compute_sampled_grads_manual (not implemented) * [DOC] Fix darglint * [DEL] Remove f-string * [TEST] Skip unimplemented methods for BCEWithLogitsLoss * [REF] Raise NotImplementedErrors, rename output -> target * [REF] Rename bceloss -> bcewithlogitsloss * [REF] Less imports, type annotation, docstring polish * [ADD] Support `reduction='sum'` * [DEL] Remove redundant constructor Co-authored-by: Felix Dangel --- .../core/derivatives/bcewithlogitsloss.py | 84 +++++++++++++++++++ backpack/utils/module_classification.py | 6 +- test/core/derivatives/__init__.py | 3 + test/core/derivatives/derivatives_test.py | 10 +++ test/core/derivatives/loss_settings.py | 10 +++ test/utils/skip_test.py | 12 ++- 6 files changed, 121 insertions(+), 4 deletions(-) create mode 100644 backpack/core/derivatives/bcewithlogitsloss.py diff --git a/backpack/core/derivatives/bcewithlogitsloss.py b/backpack/core/derivatives/bcewithlogitsloss.py new file mode 100644 index 000000000..15c4157d2 --- /dev/null +++ b/backpack/core/derivatives/bcewithlogitsloss.py @@ -0,0 +1,84 @@ +"""NLL extention for BCEWithLogits Loss.""" + +from torch import Tensor +from torch.distributions import Binomial +from torch.nn import BCEWithLogitsLoss + +from backpack.core.derivatives.nll_base import NLLLossDerivatives + + +class BCELossWithLogitsDerivatives(NLLLossDerivatives): + """Derivatives of the BCEWithLogits Loss.""" + + def _verify_support(self, module: BCEWithLogitsLoss): + """Verification of module support for BCEWithLogitsLoss. + + Currently BCEWithLogitsLoss only supports binary target tensors, + 2D inputs, and default parameters. 
+ + Args: + module: BCEWithLogitsLoss module + """ + self._check_binary(module) + self._check_is_default(module) + self._check_input_dims(module) + + def _check_binary(self, module: BCEWithLogitsLoss): + """Raises exception if labels are not binary. + + Args: + module: BCEWithLogitsLoss module + + Raises: + NotImplementedError: if labels are non-binary. + """ + if any(x != 0 and x != 1 for x in module.input1.flatten()): + raise NotImplementedError( + "Only binary targets (0 and 1) are currently supported." + ) + + def _check_is_default(self, module: BCEWithLogitsLoss): + """Raises exception if module parameters are not default. + + Args: + module: BCEWithLogitsLoss module + + Raises: + NotImplementedError: if module parameters non-default. + """ + if module.weight is not None: + raise NotImplementedError("Only None weight is currently supported.") + if module.pos_weight is not None: + raise NotImplementedError("Only None pos_weight is currently supported.") + + def _check_input_dims(self, module: BCEWithLogitsLoss): + """Raises an exception if the shapes of the input are not supported. + + Args: + module: BCEWithLogitsLoss module + + Raises: + NotImplementedError: if input is not 2D. + """ + if module.input0.dim() != 2: + raise NotImplementedError("Only 2D inputs are currently supported.") + + def _make_distribution(self, subsampled_input: Tensor) -> Binomial: + """Make the sampling distribution for the NLL loss form of BCEWithLogits. + + The BCEWithLogitsLoss ∝ ∑ᵢ₌₁ⁿ Yᵢ log 𝜎(xᵢ) + (1 − Yᵢ) log(1− 𝜎(xᵢ)). + The log likelihood of the Binomial distribution is + Yᵢ log p(xᵢ) + (1 − Yᵢ) log(1 − p(xᵢ)), so these are equivalent if + p(xᵢ) = 𝜎(xᵢ). + + Args: + subsampled_input: input after subsampling + + Returns: + Binomial distribution with sigmoid probabilities from the subsampled_input. + """ + return Binomial(probs=subsampled_input.sigmoid()) + + @staticmethod + def _get_mean_normalization(input: Tensor) -> int: + return input.shape[0] diff --git a/backpack/utils/module_classification.py b/backpack/utils/module_classification.py index e3321c9ff..c70e5247d 100644 --- a/backpack/utils/module_classification.py +++ b/backpack/utils/module_classification.py @@ -1,6 +1,6 @@ """Contains util function for classification of modules.""" from torch.fx import GraphModule -from torch.nn import CrossEntropyLoss, Module, MSELoss, Sequential +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss, Sequential from torch.nn.modules.loss import _Loss from backpack.custom_module.branching import Parallel, _Branch @@ -22,7 +22,7 @@ def is_loss(module: Module) -> bool: def is_nll(module: Module) -> bool: """Return whether 'module' is an NLL loss function. - Current NLL loss functions include MSE and CE. + Current NLL loss functions include MSE, CE and BCEWithLogits. Args: module: A PyTorch module. 
@@ -30,7 +30,7 @@ def is_nll(module: Module) -> bool: Returns: Whether 'module' is an NLL loss function """ - return isinstance(module, (MSELoss, CrossEntropyLoss)) + return isinstance(module, (MSELoss, CrossEntropyLoss, BCEWithLogitsLoss)) def is_no_op(module: Module) -> bool: diff --git a/test/core/derivatives/__init__.py b/test/core/derivatives/__init__.py index b76a1e633..36bc03bfc 100644 --- a/test/core/derivatives/__init__.py +++ b/test/core/derivatives/__init__.py @@ -13,6 +13,7 @@ BatchNorm1d, BatchNorm2d, BatchNorm3d, + BCEWithLogitsLoss, Conv1d, Conv2d, Conv3d, @@ -45,6 +46,7 @@ from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives from backpack.core.derivatives.avgpool3d import AvgPool3DDerivatives from backpack.core.derivatives.batchnorm_nd import BatchNormNdDerivatives +from backpack.core.derivatives.bcewithlogitsloss import BCELossWithLogitsDerivatives from backpack.core.derivatives.conv1d import Conv1DDerivatives from backpack.core.derivatives.conv2d import Conv2DDerivatives from backpack.core.derivatives.conv3d import Conv3DDerivatives @@ -120,4 +122,5 @@ SumModule: SumModuleDerivatives, Pad: PadDerivatives, Slicing: SlicingDerivatives, + BCEWithLogitsLoss: BCELossWithLogitsDerivatives, } diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index fb39bc96c..79664b173 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -24,6 +24,7 @@ from test.utils.skip_test import ( skip_adaptive_avg_pool3d_cuda, skip_batch_norm_train_mode_with_subsampling, + skip_BCEWithLogitsLoss, skip_subsampling_conflict, ) from typing import List, Union @@ -292,6 +293,7 @@ def test_sqrt_hessian_squared_equals_hessian( """ problem.set_up() skip_subsampling_conflict(problem, subsampling) + skip_BCEWithLogitsLoss(problem) # TODO Implement sqrt_hessian for BCEWithLogitsLoss backpack_res = BackpackDerivatives(problem).input_hessian_via_sqrt_hessian( subsampling=subsampling @@ -337,6 +339,9 @@ def test_sqrt_hessian_sampled_squared_approximates_hessian( """ problem.set_up() skip_subsampling_conflict(problem, subsampling) + skip_BCEWithLogitsLoss( + problem + ) # TODO Implement _compute_sampled_grads_manual for BCEWithLogitsLoss backpack_res = BackpackDerivatives(problem).input_hessian_via_sqrt_hessian( mc_samples=mc_samples, chunks=chunks, subsampling=subsampling @@ -456,6 +461,7 @@ def test_sum_hessian(problem): problem (DerivativesProblem): Problem for derivative test. 
""" problem.set_up() + skip_BCEWithLogitsLoss(problem) # TODO Implement _sum_hessian for BCEWithLogitsLoss backpack_res = BackpackDerivatives(problem).sum_hessian() autograd_res = AutogradDerivatives(problem).sum_hessian() @@ -596,6 +602,10 @@ def test_make_hessian_mat_prod(problem: DerivativesTestProblem) -> None: problem: test problem """ problem.set_up() + skip_BCEWithLogitsLoss( + problem + ) # TODO Implement _make_hessian_mat_prod for BCEWithLogitsLoss + mat = rand(4, *problem.input_shape, device=problem.device) autograd_res = AutogradDerivatives(problem).hessian_mat_prod(mat) diff --git a/test/core/derivatives/loss_settings.py b/test/core/derivatives/loss_settings.py index 391420cae..ad292e48c 100644 --- a/test/core/derivatives/loss_settings.py +++ b/test/core/derivatives/loss_settings.py @@ -77,6 +77,16 @@ "input_fn": lambda: torch.rand(size=(1, 1)), "target_fn": lambda: regression_targets(size=(1, 1)), }, + { + "module_fn": lambda: torch.nn.BCEWithLogitsLoss(reduction="mean"), + "input_fn": lambda: torch.rand(size=(2, 1)), + "target_fn": lambda: classification_targets(size=(2, 1), num_classes=2).float(), + }, + { + "module_fn": lambda: torch.nn.BCEWithLogitsLoss(reduction="sum"), + "input_fn": lambda: torch.rand(size=(4, 1)), + "target_fn": lambda: classification_targets(size=(4, 1), num_classes=2).float(), + }, ] diff --git a/test/utils/skip_test.py b/test/utils/skip_test.py index 4f282662f..0a77a5c1a 100644 --- a/test/utils/skip_test.py +++ b/test/utils/skip_test.py @@ -5,7 +5,7 @@ from typing import List, Union from pytest import skip -from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d, BCEWithLogitsLoss from backpack.utils import ADAPTIVE_AVG_POOL_BUG @@ -69,3 +69,13 @@ def skip_large_parameters( num_params = sum(p.numel() for p in problem.trainable_parameters()) if num_params > max_num_params: skip(f"Model has too many parameters: {num_params} > {max_num_params}") + + +def skip_BCEWithLogitsLoss(problem: ExtensionsTestProblem) -> None: + """Skip if the test problem uses BCEWithLogitsLoss. + + Args: + problem: Test case. + """ + if isinstance(problem.module, BCEWithLogitsLoss): + skip("Skipping BCEWithLogitsLoss") From 94f2bdef899188900fe03faa05dbc8324c00230d Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Sat, 24 Sep 2022 17:52:31 +0200 Subject: [PATCH 04/29] [REF] Ignore warning to explicitly declare `abstractmethod`s (#264) * [REF] Declare `abstractmethod`s * [REF] Ignore warning: class inheriting from `ABC` has no abstract methods Co-authored-by: Felix Dangel --- backpack/core/derivatives/basederivatives.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 94c152884..2a55ac88d 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -9,7 +9,7 @@ from backpack.core.derivatives import shape_check -class BaseDerivatives(ABC): +class BaseDerivatives(ABC): # noqa: B024 """First- and second-order partial derivatives of unparameterized module. 
Note: @@ -193,6 +193,7 @@ def hessian_is_diagonal(self, module: Module) -> bool: # FIXME Currently returns `∂²output[i] / ∂input[i]² * g_out[0][i]`, # which s the residual matrix diagonal, rather than the Hessian diagonal + def hessian_diagonal( self, module: Module, g_in: Tuple[Tensor], g_out: Tuple[Tensor] ) -> Tensor: @@ -306,7 +307,7 @@ def reshape_like_output(cls, mat: Tensor, module: Module) -> Tensor: return cls._reshape_like(mat, module.output.shape) -class BaseParameterDerivatives(BaseDerivatives, ABC): +class BaseParameterDerivatives(BaseDerivatives, ABC): # noqa: B024 """First- and second order partial derivatives of a module with parameters. Assumptions (true for `nn.Linear`, `nn.Conv(Transpose)Nd`, `nn.BatchNormNd`): @@ -435,7 +436,7 @@ def _weight_jac_mat_prod( raise NotImplementedError -class BaseLossDerivatives(BaseDerivatives, ABC): +class BaseLossDerivatives(BaseDerivatives, ABC): # noqa: B024 """Second- order partial derivatives of loss functions.""" # TODO Add shape check From 53ddd861da47df30c664e3b58dbdf19a598a3f05 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Sat, 24 Sep 2022 18:06:08 +0200 Subject: [PATCH 05/29] [FIX] Use correct imports from backpack in website examples (#262) Replaces `from backpack import ..., X` <-> `from backpack import ..., extensions` and `with backpack(X())` <-> `with backpack(extensions.X())` Also applies white space cleanup. Co-authored-by: Felix Dangel --- docs_src/splash/_includes/code-samples.html | 34 ++++++++++----------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/docs_src/splash/_includes/code-samples.html b/docs_src/splash/_includes/code-samples.html index e5a98ba50..7a701bce8 100644 --- a/docs_src/splash/_includes/code-samples.html +++ b/docs_src/splash/_includes/code-samples.html @@ -71,14 +71,14 @@ """ from torch.nn import CrossEntropyLoss, Linear from utils import load_mnist_data -from backpack import extend, backpack, Variance +from backpack import extend, backpack, extensions X, y = load_mnist_data() model = extend(Linear(784, 10)) lossfunc = extend(CrossEntropyLoss()) loss = lossfunc(model(X), y) -with backpack(Variance()): +with backpack(extensions.Variance()): loss.backward() for param in model.parameters(): @@ -118,14 +118,14 @@ """ from torch.nn import CrossEntropyLoss, Linear from utils import load_mnist_data -from backpack import extend, backpack, DiagGGNExact +from backpack import extend, backpack, extensions X, y = load_mnist_data() model = extend(Linear(784, 10)) lossfunc = extend(CrossEntropyLoss()) loss = lossfunc(model(X), y) -with backpack(DiagGGNExact()): +with backpack(extensions.DiagGGNExact()): loss.backward() for param in model.parameters(): @@ -141,14 +141,14 @@ """ from torch.nn import CrossEntropyLoss, Linear from utils import load_mnist_data -from backpack import extend, backpack, KFAC +from backpack import extend, backpack, extensions X, y = load_mnist_data() model = extend(Linear(784, 10)) lossfunc = extend(CrossEntropyLoss()) loss = lossfunc(model(X), y) -with backpack(KFAC()): +with backpack(extensions.KFAC()): loss.backward() for param in model.parameters(): @@ -159,20 +159,20 @@
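
For reference, the MC approximation that `NLLLossDerivatives` introduces in this
patch series can be reproduced with a short, self-contained PyTorch sketch. The
snippet below mirrors the autograd path (`_compute_sampled_grads_autograd`) for the
`MSELoss` case with `reduction="sum"`; the variable names (`f`, `y_tilde`, `N`, `D`,
`mc_samples`) and the sample count are illustrative placeholders. Targets are drawn
from the Gaussian likelihood with mean `f` and standard deviation √0.5, gradients of
the negative log-likelihood are stacked, and dividing by √mc_samples makes the outer
products of the stacked gradients approximate the loss Hessian.

# Minimal sketch of the MC-sampled Hessian square root for the MSELoss case
# (reduction="sum"): draw targets from the likelihood N(f, 0.5), differentiate
# the negative log-likelihood w.r.t. f, and stack the gradients.
from math import sqrt

from torch import einsum, rand, stack
from torch.autograd import grad
from torch.distributions import Normal

N, D, mc_samples = 8, 3, 1000           # batch size, feature dimension, MC samples
f = rand(N, D, requires_grad=True)      # module input (the network output)

dist = Normal(f, sqrt(0.5))             # likelihood whose NLL is the sum-reduced MSE
grads = []
for _ in range(mc_samples):
    y_tilde = dist.sample()                          # sampled targets
    nll = -dist.log_prob(y_tilde).sum()              # proportional to MSE(f, y_tilde)
    grads.append(grad(nll, f, retain_graph=True)[0])

sqrt_hessian = stack(grads) / sqrt(mc_samples)       # shape [mc_samples, N, D]
hessian_mc = einsum("vnd,vne->nde", sqrt_hessian, sqrt_hessian)

# For reduction="sum", the exact per-sample Hessian of the MSE is 2 * identity,
# so the diagonal of hessian_mc should be close to 2 and off-diagonals close to 0.
print(hessian_mc.mean(dim=0))

For `reduction="mean"`, the same stacked gradients would additionally be divided by
the square root of the normalization factor, matching `_get_mean_normalization`
(`input.numel()` in the `MSELoss` case) used by `_sqrt_hessian_sampled` above.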