Merge pull request #309 from f-dangel/development
Update master with code for `1.6.0` release
f-dangel committed Jul 12, 2023
2 parents ffa6068 + ee1dc4f commit 1ebfb40
Showing 68 changed files with 2,569 additions and 382 deletions.
9 changes: 3 additions & 6 deletions .conda_env.yml
@@ -3,10 +3,7 @@ channels:
   - pytorch
   - defaults
 dependencies:
-  - pip=19.3.1
-  - python=3.7.6
+  - pip=21.2.4
+  - python=3.8.5
   - pip:
-      - -e .
-      - -e .[lint]
-      - -e .[test]
-      - -e .[docs]
+      - -e .[lint,test,doc]
28 changes: 14 additions & 14 deletions .github/workflows/lint.yaml
@@ -15,10 +15,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v1
-      - name: Set up Python 3.7
+      - name: Set up Python 3.8
         uses: actions/setup-python@v1
         with:
-          python-version: 3.7
+          python-version: 3.8
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
@@ -30,10 +30,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v1
-      - name: Set up Python 3.7
+      - name: Set up Python 3.8
         uses: actions/setup-python@v1
         with:
-          python-version: 3.7
+          python-version: 3.8
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
@@ -45,10 +45,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v1
-      - name: Set up Python 3.7
+      - name: Set up Python 3.8
         uses: actions/setup-python@v1
         with:
-          python-version: 3.7
+          python-version: 3.8
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
@@ -61,10 +61,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v1
-      - name: Set up Python 3.7
+      - name: Set up Python 3.8
         uses: actions/setup-python@v1
         with:
-          python-version: 3.7
+          python-version: 3.8
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
@@ -77,10 +77,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v1
-      - name: Set up Python 3.7
+      - name: Set up Python 3.8
         uses: actions/setup-python@v1
         with:
-          python-version: 3.7
+          python-version: 3.8
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
@@ -92,10 +92,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v1
-      - name: Set up Python 3.7
+      - name: Set up Python 3.8
        uses: actions/setup-python@v1
        with:
-          python-version: 3.7
+          python-version: 3.8
       - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
@@ -107,10 +107,10 @@ jobs:
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v1
-      - name: Set up Python 3.7
+      - name: Set up Python 3.8
         uses: actions/setup-python@v1
         with:
-          python-version: 3.7
+          python-version: 3.8
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
5 changes: 3 additions & 2 deletions .github/workflows/test.yaml
@@ -15,17 +15,18 @@ jobs:
name: "py${{ matrix.python-version }} torch${{ matrix.pytorch-version}}"
runs-on: ubuntu-latest
env:
USING_COVERAGE: '3.7,3.9'
USING_COVERAGE: '3.9'

strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: [3.8, 3.9]
pytorch-version:
- "==1.9.1"
- "==1.10.1"
- "==1.11.0"
- "==1.12.1"
- "==1.13.1"
- "==2.0.1"
- "" # latest
steps:
- uses: actions/checkout@v1
Expand Down
2 changes: 1 addition & 1 deletion .readthedocs.yml
@@ -7,7 +7,7 @@ sphinx:
   configuration: docs_src/rtd/conf.py
 
 python:
-  version: 3.7
+  version: 3.8
   install:
     - method: pip
       path: .
2 changes: 1 addition & 1 deletion README-dev.md
@@ -1,7 +1,7 @@
 # <img alt="BackPACK" src="./logo/backpack_logo_torch.svg" height="90"> BackPACK developer manual
 
 ## General standards
-- Python version: support 3.7+, use 3.7 for development
+- Python version: support 3.8+, use 3.8 for development
 - `git` [branching model](https://nvie.com/posts/a-successful-git-branching-model/)
 - Docstring style: [Google](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html)
 - Test runner: [`pytest`](https://docs.pytest.org/en/latest/)
2 changes: 1 addition & 1 deletion README.md
@@ -2,7 +2,7 @@

 [![Travis](https://travis-ci.org/f-dangel/backpack.svg?branch=master)](https://travis-ci.org/f-dangel/backpack)
 [![Coveralls](https://coveralls.io/repos/github/f-dangel/backpack/badge.svg?branch=master)](https://coveralls.io/github/f-dangel/backpack)
-[![Python 3.7+](https://img.shields.io/badge/python-3.7+-blue.svg)](https://www.python.org/downloads/release/python-370/)
+[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-370/)
 
 BackPACK is built on top of [PyTorch](https://github.com/pytorch/pytorch). It efficiently computes quantities other than the gradient.
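
As a quick illustration of "quantities other than the gradient" (a minimal sketch following the library's public `extend`/`backpack` API, not part of this diff): with the `BatchGrad` extension, a backward pass additionally populates per-sample gradients.

from torch import rand
from torch.nn import Linear, MSELoss

from backpack import backpack, extend
from backpack.extensions import BatchGrad

model = extend(Linear(10, 1))  # make the layer BackPACK-aware
lossfunc = extend(MSELoss())

loss = lossfunc(model(rand(8, 10)), rand(8, 1))
with backpack(BatchGrad()):
    loss.backward()  # computes .grad and, additionally, .grad_batch

for param in model.parameters():
    print(param.grad_batch.shape)  # one gradient per sample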
5 changes: 3 additions & 2 deletions backpack/core/derivatives/basederivatives.py
@@ -193,6 +193,7 @@ def hessian_is_diagonal(self, module: Module) -> bool:

     # FIXME Currently returns `∂²output[i] / ∂input[i]² * g_out[0][i]`,
     # which is the residual matrix diagonal, rather than the Hessian diagonal
+
     def hessian_diagonal(
         self, module: Module, g_in: Tuple[Tensor], g_out: Tuple[Tensor]
     ) -> Tensor:
@@ -306,7 +307,7 @@ def reshape_like_output(cls, mat: Tensor, module: Module) -> Tensor:
         return cls._reshape_like(mat, module.output.shape)
 
 
-class BaseParameterDerivatives(BaseDerivatives, ABC):
+class BaseParameterDerivatives(BaseDerivatives, ABC):  # noqa: B024
     """First- and second-order partial derivatives of a module with parameters.
 
     Assumptions (true for `nn.Linear`, `nn.Conv(Transpose)Nd`, `nn.BatchNormNd`):
@@ -435,7 +436,7 @@ def _weight_jac_mat_prod(
         raise NotImplementedError
 
 
-class BaseLossDerivatives(BaseDerivatives, ABC):
+class BaseLossDerivatives(BaseDerivatives, ABC):  # noqa: B024
     """Second-order partial derivatives of loss functions."""
 
     # TODO Add shape check
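
For context on the `# noqa: B024` markers added above: flake8-bugbear's B024 rule flags a class that inherits from `ABC` but declares no abstract methods, which is intentional for these base classes. A minimal illustration (hypothetical class names):

from abc import ABC, abstractmethod


class OnlyConcrete(ABC):  # triggers B024: ABC defines no abstract methods
    def helper(self) -> int:
        return 42


class HasAbstract(ABC):  # no warning: declares an abstract method
    @abstractmethod
    def compute(self) -> int:
        ...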
1 change: 0 additions & 1 deletion backpack/core/derivatives/batchnorm_nd.py
@@ -93,7 +93,6 @@ def _jac_t_mat_prod(
         self._check_parameters(module)
         N: int = self._get_n_axis(module)
         if module.training:
-
             if subsampling is not None:
                 raise NotImplementedError(
                     "BatchNorm VJP sub-sampling is not defined in train mode."
171 changes: 171 additions & 0 deletions backpack/core/derivatives/bcewithlogitsloss.py
@@ -0,0 +1,171 @@
"""NLL extention for BCEWithLogits Loss."""

from math import sqrt
from typing import List, Tuple

from torch import Size, Tensor, sigmoid
from torch.distributions import Binomial
from torch.nn import BCEWithLogitsLoss

from backpack.core.derivatives.nll_base import NLLLossDerivatives
from backpack.utils.subsampling import subsample


class BCELossWithLogitsDerivatives(NLLLossDerivatives):
"""Derivatives of the BCEWithLogits Loss."""

def __init__(self, use_autograd: bool = False):
"""Initialization for BCEWithLogitsLoss derivative.
Args:
use_autograd: Compute gradients with autograd (rather than manual)
Defaults to ``False`` (manual computation).
"""
super().__init__(use_autograd=use_autograd)

def _verify_support(self, module: BCEWithLogitsLoss):
"""Verification of module support for BCEWithLogitsLoss.
Currently BCEWithLogitsLoss only supports binary target tensors,
2D inputs, and default parameters.
Args:
module: BCEWithLogitsLoss module
"""
self._check_binary(module)
self._check_is_default(module)
self._check_input_dims(module)

def _check_binary(self, module: BCEWithLogitsLoss):
"""Raises exception if labels are not binary.
Args:
module: BCEWithLogitsLoss module
Raises:
NotImplementedError: if labels are non-binary.
"""
if any(x not in [0, 1] for x in module.input1.flatten()):
raise NotImplementedError(
"Only binary targets (0 and 1) are currently supported."
)

def _check_is_default(self, module: BCEWithLogitsLoss):
"""Raises exception if module parameters are not default.
Args:
module: BCEWithLogitsLoss module
Raises:
NotImplementedError: if module parameters non-default.
"""
if module.weight is not None:
raise NotImplementedError("Only None weight is currently supported.")
if module.pos_weight is not None:
raise NotImplementedError("Only None pos_weight is currently supported.")

def _check_input_dims(self, module: BCEWithLogitsLoss):
"""Raises an exception if the shapes of the input are not supported.
Args:
module: BCEWithLogitsLoss module
Raises:
NotImplementedError: if input is not a batch of scalars.
"""
if module.input0.dim() != 2:
raise NotImplementedError("Only 2D inputs are currently supported.")
if module.input0.shape[1] != 1:
raise NotImplementedError(
"Only scalar-valued predictions are currently supported."
)

def _make_distribution(self, subsampled_input: Tensor) -> Binomial:
"""Make the sampling distribution for the NLL loss form of BCEWithLogits.
The BCEWithLogitsLoss ∝ ∑ᵢ₌₁ⁿ Yᵢ log 𝜎(xᵢ) + (1 − Yᵢ) log(1− 𝜎(xᵢ)).
The log likelihood of the Binomial distribution is
Yᵢ log p(xᵢ) + (1 − Yᵢ) log(1 − p(xᵢ)), so these are equivalent if
p(xᵢ) = 𝜎(xᵢ).
Args:
subsampled_input: input after subsampling
Returns:
Binomial distribution with sigmoid probabilities from the subsampled_input.
"""
return Binomial(probs=subsampled_input.sigmoid())

def _compute_sampled_grads_manual(
self, subsampled_input: Tensor, mc_samples: int
) -> Tensor:
"""Manually compute gradients from sampled targets.
Let fₙ ∈ ℝ and yₙ ∈ {0, 1} ∼ p(y | fₙ) and σ(fₙ) the softmax probability.
Then the gradient is ∇ℓ(fₙ, yₙ) = σ(fₙ) - fₙ.
Args:
subsampled_input: input after subsampling
mc_samples: number of samples
Returns:
Gradient samples
"""
probs = subsampled_input.sigmoid()
expand_dims = [mc_samples] + probs.dim() * [-1]
probs_unsqeezed = probs.unsqueeze(0).expand(*expand_dims) # [V N 1]

distribution = self._make_distribution(subsampled_input)
samples = distribution.sample(Size([mc_samples])) # [V N 1]

return probs_unsqeezed - samples

@staticmethod
def _get_mean_normalization(input: Tensor) -> int:
return input.shape[0]

def _sqrt_hessian(
self,
module: BCEWithLogitsLoss,
g_inp: Tuple[Tensor],
g_out: Tuple[Tensor],
subsampling: List[int],
) -> Tensor: # noqa: D102
"""Return a symmetric factorization of the loss Hessian.
# noqa: DAR101
Let fₙ ∈ ℝ be the input and yₙ ∈ [0; 1] be the label, and σ(fₙ) ∈ (0;
1) be the sigmoid probability. Then, the gradient ∇ℓ(fₙ, yₙ) w.r.t. fₙ
is ∇ℓ(fₙ, yₙ) = σ(fₙ) - yₙ, and the Hessian ∇²ℓ(fₙ, yₙ) w.r.t. fₙ is
∇²ℓ(fₙ, yₙ) = σ'(fₙ) = σ(fₙ) (1 - σ(fₙ)). Consequently, the (scalar)
Hessian square root is √(σ(fₙ) (1 - σ(fₙ))).
Returns:
Hessian square root factorization of shape ``[1, N, 1]`` where ``N``
corresponds to the (subsampled) batch size.
"""
self._check_is_default(module)
self._check_input_dims(module)

input0 = subsample(module.input0, subsampling=subsampling)
sigma = sigmoid(input0).unsqueeze(0)

sqrt_H = (sigma * (1 - sigma)).sqrt()

if module.reduction == "mean":
sqrt_H /= sqrt(self._get_mean_normalization(module.input0))

return sqrt_H

def hessian_is_psd(self) -> bool:
"""Return whether the Hessian is PSD.
Let fₙ ∈ ℝ be the input and yₙ ∈ [0; 1] be the label, and σ(fₙ) ∈ (0;
1) be the sigmoid probability. The Hessian ∇²ℓ(fₙ, yₙ) w.r.t. fₙ is
∇²ℓ(fₙ, yₙ) = σ'(fₙ) = σ(fₙ) (1 - σ(fₙ)) > 0. Hence, the Hessian is PSD.
Returns:
True
"""
return True
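
The Binomial equivalence claimed in `_make_distribution`'s docstring is easy to check numerically; a small sketch, independent of this commit's internals:

import torch
from torch.distributions import Binomial
from torch.nn import BCEWithLogitsLoss

logits = torch.randn(8, 1)
targets = torch.randint(0, 2, (8, 1)).float()

bce = BCEWithLogitsLoss(reduction="sum")(logits, targets)
# Binomial with total_count=1 and p = sigmoid(logits): its negative
# log-likelihood of the binary targets equals the summed BCE loss.
nll = -Binomial(probs=logits.sigmoid()).log_prob(targets).sum()

print(torch.allclose(bce, nll))  # True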
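
Likewise, the derivation behind `_sqrt_hessian` (per-sample Hessian diagonal σ(fₙ)(1 − σ(fₙ)), scaled by 1/N under mean reduction) can be verified against autograd; a sketch under the same scalar-prediction assumption:

import torch
from torch.nn import BCEWithLogitsLoss

N = 5
f = torch.randn(N, 1, requires_grad=True)
y = torch.randint(0, 2, (N, 1)).float()

loss = BCEWithLogitsLoss(reduction="mean")(f, y)
(g,) = torch.autograd.grad(loss, f, create_graph=True)

# Hessian diagonal of the mean loss w.r.t. the N scalar inputs, via autograd.
hess_diag = torch.stack(
    [torch.autograd.grad(g[n, 0], f, retain_graph=True)[0][n, 0] for n in range(N)]
)

sigma = f.detach().sigmoid().squeeze(1)
manual = sigma * (1 - sigma) / N  # square of the factorization, incl. 1/N
print(torch.allclose(hess_diag, manual))  # True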