Make tp.LayerNorm 1:1 with torch (#119)
- Make the `normalized_shape` argument 1:1 with Torch. Previously, our
API accepted only a single integer for the normalization dimension; now
a sequence of integers can be provided as well (see the usage sketch below).
- Add integration tests for `tp.LayerNorm` and `tp.GroupNorm`
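
In practice, the change enables usage like the following (a minimal sketch based on the API shown in this diff; the shapes and values are illustrative):

import tripy as tp

# Previously the only accepted form: normalize over the last dimension of size 8.
ln_last = tp.LayerNorm(normalized_shape=8)

# Now a sequence is accepted too, matching torch.nn.LayerNorm: normalization
# is applied over the last len(normalized_shape) dimensions.
ln_last_two = tp.LayerNorm(normalized_shape=(4, 8), eps=1e-5)

x = tp.ones((2, 4, 8))
y = ln_last_two(x)  # normalized over the trailing (4, 8) dimensions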
markkraay committed Aug 23, 2024
1 parent f15fb83 commit 0a6a4dc
Showing 4 changed files with 160 additions and 8 deletions.
61 changes: 61 additions & 0 deletions tripy/tests/integration/test_groupnorm.py
@@ -0,0 +1,61 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
import pytest

import tripy as tp

DTYPES = [
    (torch.float16, tp.float16),
    (torch.float32, tp.float32),
]

class TestGroupNorm:
    @pytest.mark.parametrize("torch_dtype, tp_dtype", DTYPES)
    @pytest.mark.parametrize("input_shape", [(1, 10, 2)])
    @pytest.mark.parametrize("num_groups", [2, 5])
    @pytest.mark.parametrize("num_channels", [10])
    @pytest.mark.parametrize("eps", [1e-5, 1e-3])
    def test_groupnorm_accuracy(self, torch_dtype, tp_dtype, input_shape, num_groups, num_channels, eps):
        groupnorm = torch.nn.GroupNorm(
            num_groups=num_groups,
            num_channels=num_channels,
            eps=eps,
            dtype=torch_dtype,
        )
        tp_groupnorm = tp.GroupNorm(
            num_groups=num_groups,
            num_channels=num_channels,
            eps=eps,
            dtype=tp_dtype,
        )

        # Copy the Torch module's weights so both modules compute the same function.
        tp_groupnorm.weight = tp.Parameter(groupnorm.weight)
        tp_groupnorm.bias = tp.Parameter(groupnorm.bias)

        input = torch.arange(torch.prod(torch.Tensor(input_shape))).reshape(input_shape).to(torch_dtype)
        tp_input = tp.Tensor(input, dtype=tp_dtype)

        output = tp_groupnorm(tp_input)
        expected = tp.Tensor(groupnorm(input), device=tp.device("cpu"))

        # float16 accumulates more rounding error, so it gets a looser tolerance.
        rtol_ = 2e-6 if tp_dtype == tp.float32 else 1e-3
        assert output.shape == expected.shape
        assert tp.allclose(output, expected, rtol=rtol_)
68 changes: 68 additions & 0 deletions tripy/tests/integration/test_layernorm.py
@@ -0,0 +1,68 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import re

import torch
import pytest

import tripy as tp
from tripy.common.exception import TripyException

DTYPES = [
    (torch.float16, tp.float16),
    (torch.float32, tp.float32),
]

class TestLayerNorm:
    @pytest.mark.parametrize("torch_dtype, tp_dtype", DTYPES)
    @pytest.mark.parametrize("input_shape", [(2, 2, 2)])
    @pytest.mark.parametrize("normalized_shape", [(2, 2), (2,)])
    @pytest.mark.parametrize("eps", [1e-5, 1e-3])
    def test_layernorm_accuracy(self, torch_dtype, tp_dtype, input_shape, normalized_shape, eps):
        layernorm = torch.nn.LayerNorm(
            normalized_shape=normalized_shape,
            eps=eps,
            dtype=torch_dtype,
        )
        tp_layernorm = tp.LayerNorm(
            normalized_shape=normalized_shape,
            eps=eps,
            dtype=tp_dtype,
        )

        # Copy the Torch module's weights so both modules compute the same function.
        tp_layernorm.weight = tp.Parameter(layernorm.weight)
        tp_layernorm.bias = tp.Parameter(layernorm.bias)

        input = torch.arange(torch.prod(torch.Tensor(input_shape))).reshape(input_shape).to(torch_dtype)
        tp_input = tp.Tensor(input, dtype=tp_dtype)

        output = tp_layernorm(tp_input)
        expected = tp.Tensor(layernorm(input), device=tp.device("cpu"))

        # float16 accumulates more rounding error, so it gets a looser tolerance.
        rtol_ = 2e-7 if tp_dtype == tp.float32 else 1e-3
        assert output.shape == expected.shape
        assert tp.allclose(output, expected, rtol=rtol_)

    def test_layernorm_improper_dimensions(self):
        tp_layernorm = tp.LayerNorm(
            normalized_shape=[2, 2],
        )
        x = tp.ones((5, 5, 5))
        with pytest.raises(
            TripyException,
            match=re.escape("The input's last 2 dimensions must have a shape of [2, 2] but received [5, 5]"),
        ):
            tp_layernorm(x)
4 changes: 2 additions & 2 deletions tripy/tripy/frontend/module/groupnorm.py
@@ -55,7 +55,7 @@ class GroupNorm(Module):
     eps: float
     """A value added to the denominator to prevent division by zero. Defaults to 1e-5."""
 
-    def __init__(self, num_groups: int, num_channels: int, dtype: datatype.dtype = datatype.float32) -> None:
+    def __init__(self, num_groups: int, num_channels: int, dtype: datatype.dtype = datatype.float32, eps: float = 1e-5) -> None:
         """
         Args:
             num_groups: The number of groups to split the channels into.
@@ -97,7 +97,7 @@ def __init__(self, num_groups: int, num_channels: int, dtype: datatype.dtype = datatype.float32) -> None:
         # Replace with random weights when #74 is completed.
         self.weight = DefaultParameter((num_channels,), dtype=dtype)
         self.bias = DefaultParameter((num_channels,), dtype=dtype)
-        self.eps = 1e-5
+        self.eps = eps
 
     def __call__(self, x: "tripy.Tensor") -> "tripy.Tensor":
         r"""
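GroupNorm's `eps` was previously hard-coded to 1e-5 inside `__init__`; it is now a constructor argument. A minimal usage sketch (reusing the shapes from the integration test above):

import tripy as tp

# num_channels must be divisible by num_groups; eps is now configurable.
norm = tp.GroupNorm(num_groups=2, num_channels=10, eps=1e-3)
x = tp.ones((1, 10, 2))  # (batch, channels, features)
y = norm(x)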
35 changes: 29 additions & 6 deletions tripy/tripy/frontend/module/layernorm.py
@@ -16,6 +16,7 @@
 #
 
 from dataclasses import dataclass
+from typing import Union, Tuple
 
 from tripy import export, utils
 from tripy.common import datatype
@@ -25,19 +26,25 @@

 @export.public_api(document_under="operations/modules")
 @dataclass
-@utils.constant_fields(["dtype"])
+@utils.constant_fields(["dtype", "normalized_shape"])
 class LayerNorm(Module):
     r"""
     Applies layer normalization over the input tensor:
     :math:`\text{LayerNorm}(x) = \Large \frac{x - \bar{x}}{ \sqrt{\sigma^2 + \epsilon}} \normalsize * \gamma + \beta`
     where :math:`\bar{x}` is the mean and :math:`\sigma^2` is the variance.
+    The mean and variance are calculated over the last :math:`D`
+    dimensions, where :math:`D` is the number of dimensions in
+    `normalized_shape`.
     """
 
     dtype: datatype.dtype
     r"""The data type used to perform the operation."""
 
+    normalized_shape: Tuple[int, ...]
+    r"""The shape of the input's trailing dimensions over which normalization is applied."""
+
     weight: Parameter
     r"""The :math:`\gamma` parameter of shape :math:`[\text{normalized_shape}]`."""

@@ -47,7 +54,7 @@ class LayerNorm(Module):
     eps: float
     """A value added to the denominator to prevent division by zero."""
 
-    def __init__(self, normalized_shape: int, dtype: datatype.dtype = datatype.float32, eps: float = 1e-5) -> None:
+    def __init__(self, normalized_shape: Union[int, Tuple[int, ...]], dtype: datatype.dtype = datatype.float32, eps: float = 1e-5) -> None:
         """
         Args:
             normalized_shape: The shape of the input's trailing dimensions over which normalization is performed; an int is treated as a one-element tuple.
@@ -77,9 +84,14 @@ def __init__(self, normalized_shape: int, dtype: datatype.dtype = datatype.float32, eps: float = 1e-5) -> None:
         self.dtype = dtype
 
         # Replace with random weights when #74 is completed.
-        self.weight = DefaultParameter((normalized_shape,), dtype=dtype)
+        if isinstance(normalized_shape, int):
+            normalized_shape = (normalized_shape,)
+
+        self.normalized_shape = normalized_shape
+
+        self.weight = DefaultParameter(normalized_shape, dtype=dtype)
 
-        self.bias = DefaultParameter((normalized_shape,), dtype=dtype)
+        self.bias = DefaultParameter(normalized_shape, dtype=dtype)
 
         self.eps = eps

@@ -92,9 +104,20 @@ def __call__(self, x: "tripy.Tensor") -> "tripy.Tensor":
             A tensor of the same shape as the input.
         """
         from tripy.frontend.trace.ops.reduce import mean, var
+        from tripy.frontend.shape import Shape
         from tripy.frontend.trace.ops.unary_elementwise import rsqrt
+        from tripy.common.exception import raise_error
+
+        # The mean and the variance are computed over the last D dimensions.
+        D = len(self.normalized_shape)
+
+        if x.shape[-D:] != self.normalized_shape:
+            raise_error("Unexpected input shape",
+                [f"The input's last {D} dimensions must have a shape of {self.normalized_shape} but received {x.shape[-D:].data()}"]
+            )
 
-        mean_val = mean(x, dim=-1, keepdim=True)
-        var_val = var(x, dim=-1, keepdim=True, correction=0) + self.eps
+        reduce_dims = tuple(-i for i in range(D, 0, -1))
+        mean_val = mean(x, dim=reduce_dims, keepdim=True)
+        var_val = var(x, dim=reduce_dims, keepdim=True, correction=0) + self.eps
         x = (x - mean_val) * rsqrt(var_val)
         return self.weight * x + self.bias
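
As a sanity check on the reduction logic above, the same last-`D`-dimensions normalization can be reproduced in plain PyTorch (a hedged sketch, independent of Tripy's internals; `torch.nn.functional.layer_norm` serves as the reference):

import torch

x = torch.arange(8, dtype=torch.float32).reshape(2, 2, 2)
normalized_shape = (2, 2)
D = len(normalized_shape)
reduce_dims = tuple(-i for i in range(D, 0, -1))  # (-2, -1) when D == 2

# Biased variance (correction=0), matching the var(..., correction=0) call above.
mean = x.mean(dim=reduce_dims, keepdim=True)
var = x.var(dim=reduce_dims, keepdim=True, correction=0)
manual = (x - mean) * torch.rsqrt(var + 1e-5)

# layer_norm normalizes over the trailing dimensions given by normalized_shape.
reference = torch.nn.functional.layer_norm(x, normalized_shape, eps=1e-5)
assert torch.allclose(manual, reference, atol=1e-6)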
