CNN utility #1

Open
wants to merge 3 commits into base: main
19 changes: 11 additions & 8 deletions README.md
@@ -1,4 +1,7 @@
# Maximal Update Parametrization (μP) and Hyperparameter Transfer (μTransfer)
# NOTE:
- muP is incompatible with `weight_norm`; the changes made in 1.0.1 only prevent certain errors.

# Maximal Update Parametrization (μP) and Hyperparameter Transfer (μTransfer)

[Paper link](https://arxiv.org/abs/2203.03466)
|
@@ -11,7 +14,7 @@ This can be used to tune extremely large neural networks such as large pretrained
More generally, μP reduces the fragility and uncertainty when transitioning from exploration to scaling up, which are not often talked about explicitly in the deep learning literature.

![](figures/sp_vs_mup_dashed.png)
<font size="1"> *Figure above: Training loss against learning rate on Transformers of varying `d_model` trained with Adam.*</font>
<font size="1"> *Figure above: Training loss against learning rate on Transformers of varying `d_model` trained with Adam.*</font>


μP turns out to be the *unique* "natural" parametrization that has this hyperparameter stability property across width, as empirically verified in the gif below on MLPs trained with SGD. Here, across time, we interpolate between PyTorch default and μP's learning rate and initialization scalings (right), and we scale up the width-256 model (log2(width)=8) to width 2^13 = 8192 using this interpolated scaling rule (left).
@@ -88,7 +91,7 @@ base_model = MyModel(width=1)
delta_model = MyModel(width=2) # Optionally use `torchdistx` to avoid instantiating

### Instantiate the target model (the model you actually want to train).
### This should be the same as the base model except
### This should be the same as the base model except
### the widths could be potentially different.
### In particular, base_model and model should have the same depth.
model = MyModel(width=100)
@@ -107,7 +110,7 @@ set_base_shapes(model, base_model, delta=delta_model)
# make_base_shapes(base_model, delta_model, filename)
### and later set base shapes directly from the filename
# set_base_shapes(model, filename)
### This is useful when one cannot fit both
### This is useful when one cannot fit both
### base_model and model in memory at the same time

### Replace your custom init, if any
@@ -222,12 +225,12 @@ the following can shrink to 0 at initialization in μP (at a 1/sqrt(width) rate)

These are transient, and after a few steps their curves should be roughly flat.
Nevertheless, to remove the discrepancy at init, we recommend
- initializing the output layer
- initializing the output layer
(should be a `MuReadout` instance) weights to be 0 via
the `readout_zero_init=True` option and
- initializing the query matrix in a Transformer to 0
(this has to be done manually). If symmetry-breaking is desired in the attention logits at init, initialize the (relative) position biases with nonzero variance.

#### Tips for Coord Check

- Use a large learning rate (larger than you'd use for actual training). This would emphasize any potential exploding coordinates issue, which could be hidden by the initialization if the learning rate is too small.
@@ -298,8 +301,8 @@ contact [[email protected]](mailto:[email protected]) with any additional

## Trademarks

This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft
trademarks or logos is subject to and must follow
[Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general).
Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship.
Any use of third-party trademarks or logos is subject to those third parties' policies.
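
A minimal sketch (not part of this diff) of the zero-init recommendations in the README hunk above: zero the `MuReadout` weights via `readout_zero_init=True`, and zero the attention query projection manually. `TinyAttnModel` and its field names are hypothetical.

```python
import torch.nn as nn
from mup import MuReadout, set_base_shapes

class TinyAttnModel(nn.Module):
    def __init__(self, d_model=64, vocab=100):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model, bias=False)  # attention query matrix
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        # readout weights start at 0 via the readout_zero_init option
        self.readout = MuReadout(d_model, vocab, readout_zero_init=True)

model = TinyAttnModel(d_model=256)
set_base_shapes(model, TinyAttnModel(d_model=64), delta=TinyAttnModel(d_model=128))
nn.init.zeros_(model.q_proj.weight)  # zeroing the query matrix has to be done manually
```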
10 changes: 5 additions & 5 deletions examples/MLP/demo.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -107,7 +107,7 @@
" data, target = data.to(device), target.to(device)\n",
" optimizer.zero_grad()\n",
" output = model(data.view(data.size(0), -1))\n",
" \n",
"\n",
" loss = criterion(output, target)\n",
" loss.backward()\n",
" train_loss += loss.item() * data.shape[0] # sum up batch loss\n",
@@ -154,7 +154,7 @@
" self.fc_2 = nn.Linear(width, width, bias=False)\n",
" self.fc_3 = nn.Linear(width, num_classes, bias=False)\n",
" self.reset_parameters()\n",
" \n",
"\n",
" def reset_parameters(self):\n",
" nn.init.kaiming_normal_(self.fc_1.weight, a=1, mode='fan_in')\n",
" self.fc_1.weight.data /= self.input_mult**0.5\n",
@@ -176,7 +176,7 @@
" self.fc_2 = nn.Linear(width, width, bias=False)\n",
" self.fc_3 = MuReadout(width, num_classes, bias=False, output_mult=self.output_mult)\n",
" self.reset_parameters()\n",
" \n",
"\n",
" def reset_parameters(self):\n",
" nn.init.kaiming_normal_(self.fc_1.weight, a=1, mode='fan_in')\n",
" self.fc_1.weight.data /= self.input_mult**0.5\n",
@@ -233,7 +233,7 @@
"source": [
"# optimal HPs\n",
"output_mult = 32\n",
"input_mult = 0.00390625"
"input_mult = 0.00390625 # 1/256"
]
},
{
147 changes: 142 additions & 5 deletions mup/layer.py
@@ -1,5 +1,7 @@
# Copyright 2022 Microsoft Corporation.
from torch.nn import Linear
# Copyright 2022 Resemble AI.
import torch
from torch.nn import Linear, Conv1d, Conv2d


class MuReadout(Linear):
@@ -15,7 +17,7 @@ def __init__(self, *args, readout_zero_init=False, output_mult=1.0, **kwargs):
self.output_mult = output_mult
self.readout_zero_init = readout_zero_init
super().__init__(*args, **kwargs)

def reset_parameters(self) -> None:
if self.readout_zero_init:
self.weight.data[:] = 0
@@ -50,15 +52,15 @@ def _rescale_parameters(self):
self.bias.data *= self.width_mult()**0.5
self.weight.data *= self.width_mult()**0.5
self._has_rescaled_params = True

def forward(self, x):
return super().forward(
self.output_mult * x / self.width_mult())


class MuSharedReadout(MuReadout):
'''`MuReadout` with weights shared with an `nn.Embedding` layer.

Inputs:
weight: should be weight of an `nn.Embedding` layer
other inputs are fed to `MuReadout`
@@ -67,6 +69,136 @@ def __init__(self, weight, bias=True, **kwargs):
super().__init__(*weight.shape, bias=bias, **kwargs)
self.weight = weight


class MuConv1dOut(Conv1d):
'''Drop-in replacement for output `Conv1d` layers.

An "output" layer is one that maps from a width dimension (e.g., the hidden
channel count) to a non-width dimension (e.g., the number of output classes).

This layer implements the version of μP with a 1/width multiplier and a
constant variance initialization for both weights and biases.
'''
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size,
*args,
readout_zero_init=False,
output_mult=1.0,
**kwargs
):
self.output_mult = output_mult
self.readout_zero_init = readout_zero_init
super().__init__(in_channels, out_channels, kernel_size, *args, **kwargs)

def reset_parameters(self) -> None:
if self.readout_zero_init:
self.weight.data[:] = 0
if self.bias is not None:
self.bias.data[:] = 0
else:
super().reset_parameters()

def width_mult(self):
assert hasattr(self.weight, 'infshape'), (
'Please call set_base_shapes(...). If using torch.nn.DataParallel, '
'switch to distributed training with '
'torch.nn.parallel.DistributedDataParallel instead'
)
return self.weight.infshape.width_mult()


def _rescale_parameters(self):
'''Rescale parameters to convert SP initialization to μP initialization.

Warning: This method is NOT idempotent and should be called only once
unless you know what you are doing.
'''
if hasattr(self, '_has_rescaled_params') and self._has_rescaled_params:
raise RuntimeError(
"`_rescale_parameters` has been called once before already. "
"Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n"
"If you called `set_base_shapes` on a model loaded from a checkpoint, "
"or just want to re-set the base shapes of an existing model, "
"make sure to set the flag `rescale_params=False`.\n"
"To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call.")
if self.bias is not None:
self.bias.data *= self.width_mult()**0.5
self.weight.data *= self.width_mult()**0.5
self._has_rescaled_params = True

def forward(self, x):
return super().forward(
self.output_mult * x / self.width_mult())



class MuConv2dOut(Conv2d):
'''Drop-in replacement for output `Conv2d` layers.

An "output" layer is one that maps from a width dimension (e.g., the hidden
channel count) to a non-width dimension (e.g., the number of output classes).

This layer implements the version of μP with a 1/width multiplier and a
constant variance initialization for both weights and biases.
'''
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size,
*args,
readout_zero_init=False,
output_mult=1.0,
**kwargs
):
self.output_mult = output_mult
self.readout_zero_init = readout_zero_init
super().__init__(in_channels, out_channels, kernel_size, *args, **kwargs)

def reset_parameters(self) -> None:
if self.readout_zero_init:
self.weight.data[:] = 0
if self.bias is not None:
self.bias.data[:] = 0
else:
super().reset_parameters()

def width_mult(self):
assert hasattr(self.weight, 'infshape'), (
'Please call set_base_shapes(...). If using torch.nn.DataParallel, '
'switch to distributed training with '
'torch.nn.parallel.DistributedDataParallel instead'
)
return self.weight.infshape.width_mult()

def _rescale_parameters(self):
'''Rescale parameters to convert SP initialization to μP initialization.

Warning: This method is NOT idempotent and should be called only once
unless you know what you are doing.
'''
if hasattr(self, '_has_rescaled_params') and self._has_rescaled_params:
raise RuntimeError(
"`_rescale_parameters` has been called once before already. "
"Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n"
"If you called `set_base_shapes` on a model loaded from a checkpoint, "
"or just want to re-set the base shapes of an existing model, "
"make sure to set the flag `rescale_params=False`.\n"
"To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call.")
if self.bias is not None:
self.bias.data *= self.width_mult()**0.5
self.weight.data *= self.width_mult()**0.5
self._has_rescaled_params = True

def forward(self, x):
return super().forward(
self.output_mult * x / self.width_mult())



def rescale_linear_bias(linear):
'''Rescale bias in nn.Linear layers to convert SP initialization to μP initialization.

@@ -79,6 +211,11 @@ def rescale_linear_bias(linear):
"To bypass this error and *still rescale biases*, set `linear._has_rescaled_params=False` before this call.")
if linear.bias is None:
return
fanin_mult = linear.weight.infshape[1].width_mult()
if isinstance(linear.weight, torch.nn.parameter.Parameter):
fanin_mult = linear.weight.infshape[1].width_mult()
elif hasattr(linear, "weight_v"):
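# under weight_norm, `weight` is recomputed from weight_g/weight_v each forward;
# the trainable weight_v is the parameter that carries the infshape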
fanin_mult = linear.weight_v.infshape[1].width_mult()
else:
raise ValueError("[jrm] Unexpected type of parameter")
linear.bias.data *= fanin_mult**0.5
linear._has_rescaled_params = True
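
A minimal usage sketch (not part of this diff) of the new `MuConv2dOut` as the output layer of a width-parametrized CNN. `TinyCNN` is hypothetical; `MuConv2dOut` is assumed to be importable from `mup.layer` as added above.

```python
import torch
import torch.nn as nn
from mup import set_base_shapes
from mup.layer import MuConv2dOut

class TinyCNN(nn.Module):
    def __init__(self, width=16, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, width, 3, padding=1)      # input: finite fan-in, infinite fan-out
        self.conv2 = nn.Conv2d(width, width, 3, padding=1)  # hidden: both dimensions scale with width
        self.head = MuConv2dOut(width, num_classes, 1)      # output: infinite fan-in, finite fan-out

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        return self.head(x).mean(dim=(2, 3))  # global average pool to per-class logits

model = TinyCNN(width=64)
set_base_shapes(model, TinyCNN(width=16), delta=TinyCNN(width=32))
logits = model(torch.randn(2, 3, 32, 32))  # forward applies the output_mult / width_mult() factor
```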
29 changes: 25 additions & 4 deletions mup/shape.py
@@ -3,11 +3,17 @@

import yaml
from torch import nn
from torch.nn import Linear
from torch.nn import Linear, Conv1d, Conv2d
from torch.nn.modules.conv import _ConvNd

from mup.infshape import InfShape, zip_infshape
from mup.layer import MuReadout, rescale_linear_bias
from mup.layer import (
MuReadout,
MuConv1dOut,
MuConv2dOut,
rescale_linear_bias
)


__BSH_COMMENT__ = '''\
# This is a base shape file encoded in yaml
@@ -36,7 +42,7 @@ def save_base_shapes(model_or_shapes, file):
s = __BSH_COMMENT__ + s
with open(file, 'w') as f:
f.write(s)

def load_base_shapes(filename):
'''Get a dict of `InfShape` from a filename.'''
with open(filename, 'r') as f:
@@ -156,6 +162,7 @@ def apply_infshapes(model, infshapes):
for name, p in model.named_parameters():
p.infshape = infshapes[name]


def set_base_shapes(model, base, rescale_params=True, delta=None, savefile=None, do_assert=True):
'''Sets the `p.infshape` attribute for each parameter `p` of `model`.

@@ -169,7 +176,7 @@ def set_base_shapes(model, base, rescale_params=True, delta=None, savefile=None,
assuming the model is initialized using the default pytorch init (or
He initialization etc that scale the same way with fanin): If True
(default), rescales parameters to have the correct (μP) variances.
do_assert:
do_assert:
Output:
same object as `model`, after setting the `infshape` attribute of each parameter.
'''
@@ -194,6 +201,7 @@ def set_base_shapes(model, base, rescale_params=True, delta=None, savefile=None,
rescale_linear_bias(module)
return model


def assert_hidden_size_inf(model):
'''
This tests for any `nn.Linear` whose output dimension is finite but input
@@ -207,3 +215,16 @@
f'{name} has infinite fan-in and finite fan-out dimensions but is not type `MuReadout`. '
'To resolve this, either change the module to `MuReadout` or change the fan-out to an infinite dimension.'
)
elif isinstance(module, Conv1d) and not isinstance(module, MuConv1dOut):
if not module.weight.infshape[0].isinf() and module.weight.infshape[1].isinf():
assert False, (
f'{name} has infinite fan-in and finite fan-out dimensions but is not type `MuConv1dOut`. '
'To resolve this, either change the module to `MuConv1dOut` or change the fan-out to an infinite dimension.'
)
elif isinstance(module, Conv2d) and not isinstance(module, MuConv2dOut):
if not module.weight.infshape[0].isinf() and module.weight.infshape[1].isinf():
assert False, (
f'{name} has infinite fan-in and finite fan-out dimensions but is not type `MuConv2dOut`. '
'To resolve this, either change the module to `MuConv2dOut` or change the fan-out to an infinite dimension.'
)
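
A minimal sketch (not part of this diff) of what the new Conv branches in `assert_hidden_size_inf` are meant to catch: a plain `nn.Conv2d` head mapping an infinite (width) fan-in to a finite fan-out. The model is hypothetical; `assert_hidden_size_inf` is assumed importable from `mup.shape`, and `do_assert=False` is assumed to skip the check inside `set_base_shapes` so it can be shown explicitly.

```python
import torch.nn as nn
from mup import set_base_shapes
from mup.shape import assert_hidden_size_inf

def make(width):
    # last layer maps width channels -> 10 classes; under muP it should be MuConv2dOut
    return nn.Sequential(
        nn.Conv2d(3, width, 3, padding=1),
        nn.ReLU(),
        nn.Conv2d(width, 10, 1),
    )

model = make(64)
set_base_shapes(model, make(16), delta=make(32), do_assert=False)
try:
    assert_hidden_size_inf(model)
except AssertionError as e:
    print(e)  # flags the plain Conv2d head: infinite fan-in, finite fan-out, not MuConv2dOut
```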

2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setuptools.setup(
name="mup",
version="1.0.0",
version="1.0.3",
author="Edward J Hu, Greg Yang",
author_email="[email protected], [email protected]",
description="Maximal Update Parametrization",