Skip to content

Commit

Permalink
Add tp.pad
Browse files Browse the repository at this point in the history
  • Loading branch information
yizhuoz004 committed Sep 11, 2024
1 parent 85a7197 commit 973dd71
Show file tree
Hide file tree
Showing 6 changed files with 252 additions and 0 deletions.
31 changes: 31 additions & 0 deletions tripy/tests/frontend/trace/ops/test_pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import tripy as tp
from tripy.frontend.trace.ops import Pad


class TestPad:
def test_op_func(self):
a = tp.Tensor([1, 2, 3, 4])
a = tp.pad(a, (1, 1))
assert isinstance(a, tp.Tensor)
assert isinstance(a.trace_tensor.producer, Pad)

def test_infer_rank(self):
a = tp.Tensor([1, 2, 3, 4])
a = tp.pad(a, (1, 1))
assert a.trace_tensor.rank == 1
45 changes: 45 additions & 0 deletions tripy/tests/integration/test_pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
import numpy as np

import tripy as tp


class TestPad:
@pytest.mark.parametrize(
"padding_sizes, padding_value",
[
(((0, 1), (2, 0)), 0),
(((1, 2), (2, 3)), 1),
],
)
def test_pad_constant(self, padding_sizes, padding_value):
inp = np.arange(4, dtype=np.int32).reshape((2, 2))

out = tp.pad(tp.Tensor(inp), padding_sizes, padding_value=padding_value)
expected = np.pad(inp, padding_sizes, constant_values=padding_value)

assert np.array_equal(np.from_dlpack(tp.copy(out, device=tp.device("cpu"))), expected)

def test_pad_tensor(self):
inp = np.arange(6, dtype=np.float32).reshape((2, 3))

inp_tp = tp.Tensor(inp)
out = tp.pad(tp.Tensor(inp), ((0, inp_tp.shape[0]), (inp_tp.shape[1], 0)))
expected = np.pad(inp, ((0, 2), (3, 0)))

assert np.array_equal(np.from_dlpack(tp.copy(out, device=tp.device("cpu"))), expected)
1 change: 1 addition & 0 deletions tripy/tripy/flat_ir/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from tripy.flat_ir.ops.maximum import MaxOp
from tripy.flat_ir.ops.minimum import MinOp
from tripy.flat_ir.ops.mul import MulOp
from tripy.flat_ir.ops.pad import DynamicPadOp
from tripy.flat_ir.ops.plugin import PluginOp
from tripy.flat_ir.ops.pow import PowOp
from tripy.flat_ir.ops.reduce import ArgMinMaxOp, ReduceOp
Expand Down
29 changes: 29 additions & 0 deletions tripy/tripy/flat_ir/ops/pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass

from mlir_tensorrt.compiler.dialects import stablehlo

from tripy.flat_ir.ops.base import BaseFlatIROp


@dataclass(repr=False)
class DynamicPadOp(BaseFlatIROp):
def to_mlir(self, operands):
output = stablehlo.dynamic_pad(self.outputs[0].to_mlir(), *operands)
return [output]
1 change: 1 addition & 0 deletions tripy/tripy/frontend/trace/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from tripy.frontend.trace.ops.iota import Iota
from tripy.frontend.trace.ops.matmul import MatrixMultiplication
from tripy.frontend.trace.ops.permute import Permute, Transpose
from tripy.frontend.trace.ops.pad import Pad
from tripy.frontend.trace.ops.plugin import Plugin
from tripy.frontend.trace.ops.quantize import Quantize
from tripy.frontend.trace.ops.reduce import ArgMinMax, Reduce
Expand Down
145 changes: 145 additions & 0 deletions tripy/tripy/frontend/trace/ops/pad.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
#
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from dataclasses import dataclass
from typing import Sequence, Union, Tuple
from tripy import export, constraints
from tripy.frontend.trace.ops import utils as op_utils
from tripy.frontend.trace.ops.base import BaseTraceOp
from tripy.common.exception import raise_error


@dataclass(repr=False)
class Pad(BaseTraceOp):

padding_value: Union[int, float]

infer_shape_output_idxs = op_utils.ShapeOutputIdxPolicies.never_return_shape

def infer_dtypes(self):
self.outputs[0].dtype = self.inputs[0].dtype

def infer_rank(self):
self.outputs[0].rank = self.inputs[0].rank

def to_flat_ir(self, inputs, outputs):
from tripy.common.datatype import int32
from tripy.flat_ir.ops import ConstantOp, DynamicPadOp
from tripy.flat_ir.tensor import FlatIRTensor

pad_val_tensor = FlatIRTensor.build(
shape=(),
rank=0,
dtype=outputs[0].dtype,
device=outputs[0].device,
reason_details=[f"create the constant value tensor (containing {self.padding_value}) for a pad operation"],
)
ConstantOp.build([], [pad_val_tensor], data=self.padding_value)

# interior_padding is not supported
# create the default value
pad_size_shape = (inputs[0].rank,)
interior_pad_tensor = FlatIRTensor.build(
shape=pad_size_shape,
rank=1,
dtype=int32,
device=outputs[0].device,
reason_details=[f"create the default value for interior_padding argument."],
)
ConstantOp.build([], [interior_pad_tensor], data=[0] * inputs[0].rank)

# [operand, pad_val, low, high, interior]
inputs.insert(1, pad_val_tensor)
inputs.append(interior_pad_tensor)
# set padding size tensors' shape
# because stablehlo requires static shapes
inputs[2].shape = pad_size_shape
inputs[3].shape = pad_size_shape
DynamicPadOp.build(inputs, outputs)


def _convert_pad_sizes(padding_sizes):
from tripy.common.datatype import int32
from tripy.frontend.tensor import Tensor
from tripy.frontend.trace.ops.concatenate import concatenate
from tripy.frontend.trace.ops.unsqueeze import unsqueeze

if not any(isinstance(e, Tensor) for e in padding_sizes):
return Tensor(padding_sizes, dtype=int32)

sizes_1d = []
for size in padding_sizes:
if isinstance(size, Tensor):
assert size.rank == 0
sizes_1d.append(unsqueeze(size, 0))
else:
sizes_1d.append(Tensor([size], dtype=int32))
return concatenate(sizes_1d, 0)


@export.public_api(document_under="operations/functions")
@constraints.dtype_info(
dtype_variables={
"T1": ["float32", "float16", "bfloat16", "float8", "int4", "int8", "int32", "int64", "bool"],
},
dtype_constraints={"input": "T1", constraints.RETURN_VALUE: "T1"},
)
def pad(input: "tripy.Tensor", padding_sizes: Sequence[Tuple], padding_value: Union[int, float] = 0) -> "tripy.Tensor":
r"""
Pads `input` with `padding_value` of given `padding_sizes`.
Args:
input: The input tensor.
padding_sizes: A sequence of padding sizes of each dimension. Its length must equal to the rank
of `input`. Each element of `padding_size` is a tuple of integers or scalars `(low, high)`,
which represents the padding size before the lowest index and after the highest index at
the corresponding dimension.
padding_value: The padding value.
Returns:
The padded tensor.
.. code-block:: python
:linenos:
:caption: Constant padding.
input = tp.reshape(tp.arange(6, dtype=tp.float32), (2, 3))
output = tp.pad(input, ((1, 0), (0, 1)))
input_np = np.arange(6, dtype=np.float32).reshape((2, 3)) # doc: omit
expected = np.pad(input_np, ((1, 0), (0, 1))) # doc: omit
assert np.array_equal(cp.from_dlpack(output).get(), expected)
"""
from tripy.frontend.tensor import Tensor

if len(padding_sizes) != input.rank:
raise_error(
"`padding_sizes` length must equal to the rank of `input`.",
[f"Got padding_sizes={padding_sizes}, ", f" input's rank={input.rank}"],
)

padding_low, padding_high = list(zip(*padding_sizes))
return Pad.build(
[
input,
_convert_pad_sizes(padding_low),
_convert_pad_sizes(padding_high),
],
padding_value,
)

0 comments on commit 973dd71

Please sign in to comment.