Enables data type checking in the API
- We already document supported data types via our dtype constraints decorator.
    This change also adds data type checks in the frontend to preempt confusing
    error messages emitted from lower in the stack.

- This also adds a switch in the `config` so that data type checking can be disabled
    if desired (e.g. this is useful in our own data type constraint negative testing);
    see the sketch after this list.

- Fixes bugs in the `int4` implementation and corrects type constraints.

- Removes dtype checking from trace operations and associated helpers since it's now
    all done in the API.
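
A quick sketch of the new behavior. The error message below matches the updated tests in this commit; the `enable_dtype_checking` flag name is an assumption for illustration, since the commit message does not name the new `config` switch:

import numpy as np
import tripy as tp

a = tp.Tensor([[1, 2, 3, 4], [1, 2, 3, 4]])
index = tp.Tensor(np.zeros(1, dtype=np.float32))  # float32 indices are unsupported

# With checking enabled (the default), the frontend now rejects this eagerly
# instead of surfacing a confusing error from lower in the stack:
try:
    b = tp.gather(a, 0, index)
except tp.TripyException as e:
    print(e)  # Unsupported data type for 'gather'. ...

# Assumed flag name: disable checking, e.g. for negative testing of the
# dtype constraints themselves.
tp.config.enable_dtype_checking = False
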
pranavm-nvidia committed Sep 11, 2024
1 parent 21a4eda commit 2d143c2
Showing 48 changed files with 343 additions and 270 deletions.
20 changes: 9 additions & 11 deletions tripy/docs/conf.py
@@ -1,5 +1,5 @@
 #
-# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -28,7 +28,7 @@
 
 import tripy as tp
 from tripy.common.datatype import DATA_TYPES
-from tripy.constraints import FUNC_W_DOC_VERIF, TYPE_VERIFICATION
+from tripy.constraints import TYPE_VERIFICATION
 
 PARAM_PAT = re.compile(":param .*?:")
 
@@ -145,6 +145,8 @@
 def process_docstring(app, what, name, obj, options, lines):
     doc = "\n".join(lines).strip()
     blocks = helper.consolidate_code_blocks(doc)
+    unqual_name = name.split(".")[-1]
+
     # Check signature for functions/methods and class constructors.
     if what in {"function", "method"} or (what == "class" and name in seen_classes):
         signature = inspect.signature(obj)
@@ -162,7 +164,8 @@ def process_docstring(app, what, name, obj, options, lines):
             elif param.kind == inspect.Parameter.VAR_POSITIONAL:
                 pname = "*" + pname
 
-            if pname != "self" or obj.__qualname__ in FUNC_W_DOC_VERIF:
+            # Type annotations are optional for the `self` parameter unless the API has to be type-verified.
+            if pname != "self" or unqual_name in TYPE_VERIFICATION:
                 assert (
                     pname in documented_args
                 ), f"Missing documentation for parameter: '{pname}' in: '{obj}'. Please ensure you've included this in the `Args:` section. Note: Documented parameters were: {documented_args} {doc}"
@@ -178,10 +181,6 @@ def process_docstring(app, what, name, obj, options, lines):
                 assert not inspect.ismodule(
                     param.annotation
                 ), f"Type annotation cannot be a module, but got: '{param.annotation}' for parameter: '{pname}' in: '{obj}'. Please specify a type!"
-            else:
-                assert (
-                    param.annotation == signature.empty
-                ), f"Avoid using type annotations for the `self` parameter since this will corrupt the rendered documentation! Note: Documented parameters were: {documented_args} {doc}"
 
         assert signature.return_annotation != signature.empty, (
             f"Missing return type annotation for: '{obj}'. "
@@ -193,10 +192,7 @@ def process_docstring(app, what, name, obj, options, lines):
             ":returns:" in doc
         ), f"For: {obj}, return value is not documented. Please ensure you've included a `Returns:` section"
 
-    # New docstring logic:
-    # First figure out if object is using the @constraints.dtype_info decorator.
-    unqual_name = name.split(".")[-1]
-    if unqual_name in TYPE_VERIFICATION.keys():
+    if unqual_name in TYPE_VERIFICATION:
         add_text_index = -1
         for index, block in enumerate(blocks):
             if re.search(r".. code-block::", block):
@@ -234,6 +230,7 @@ def process_docstring(app, what, name, obj, options, lines):
                     index += 1
                 blocks.insert(index, dtype_exception_text)
                 break
+
             if re.search(r":param \w+: ", block):
                 param_name = re.match(r":param (\w+): ", block).group(1)
                 # Add dtype constraint to start of each parameter description.
@@ -242,6 +239,7 @@ def process_docstring(app, what, name, obj, options, lines):
                 blocks[index] = (
                     f"{block[0:add_text_index]}[dtype=\ **{TYPE_VERIFICATION[unqual_name].dtype_constraints[param_name]}**\ ] {block[add_text_index:]}"
                 )
+
             if TYPE_VERIFICATION[unqual_name].return_dtype is not None and re.search(r":returns:", block):
                 add_text_index = re.search(r":returns:", block).span()[1] + 1
                 # Add dtype constraint to start of returns description.
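
To illustrate the parameter-decoration step above, a toy sketch (the constraint name `T1` is made up; the real value comes from `TYPE_VERIFICATION[unqual_name].dtype_constraints[param_name]`):

import re

block = ":param input: The input tensor."
constraint = "T1"  # stand-in for the decorator-provided constraint

# Insert the dtype constraint at the start of the parameter description.
add_text_index = re.match(r":param (\w+): ", block).span()[1]
print(f"{block[:add_text_index]}[dtype=\\ **{constraint}**\\ ] {block[add_text_index:]}")
# -> :param input: [dtype=\ **T1**\ ] The input tensor.
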
2 changes: 1 addition & 1 deletion tripy/tests/flat_ir/test_flat_ir.py
@@ -24,7 +24,7 @@ def test_tensor_connectivity(self):
         # When we build up a FlatIR with multiple layers, the tensors/ops
         # should be connected to each other - i.e. the producer/inputs fields
         # should let you walk through the entire FlatIR.
-        inp = tp.Tensor([0])
+        inp = tp.Tensor([0], dtype=tp.float32)
 
         b = tp.tanh(inp)
         out = tp.tanh(b)
2 changes: 1 addition & 1 deletion tripy/tests/frontend/module/test_embedding.py
@@ -33,7 +33,7 @@ def test_incorrect_input_dtype(self):
 
         with helper.raises(
             tp.TripyException,
-            match="Index tensor for gather operation should be of int32 type.",
+            match="Unsupported data type for 'gather'.",
             has_stack_info_for=[a],
         ):
             out = linear(a)
8 changes: 0 additions & 8 deletions tripy/tests/frontend/test_shape.py
@@ -459,14 +459,6 @@ def test_binary_elementwise_broadcast_rejected(self, values):
         ):
             tp.Shape(values).multiply(tp.Tensor([values, values]))
 
-    def test_unary_elementwise_fails_at_run_time(self, values):
-        v = tp.exp(tp.Shape(values))
-        with raises(
-            tp.TripyException,
-            match=("'stablehlo.exponential' op operand #0 must be ranked tensor of"),
-        ):
-            v.eval()
-
     def test_shape_equality(self, other_values):
         a = tp.Shape([4, 5])
         if isinstance(other_values, np.ndarray):
2 changes: 1 addition & 1 deletion tripy/tests/frontend/trace/ops/test_binary_elementwise.py
@@ -104,7 +104,7 @@ def test_mismatched_dtypes_fails(self):
         with helper.raises(
             tp.TripyException,
             # Keep the entire error message here so we'll know if the display becomes horribly corrupted.
-            match=r"For operation: '\+', data types for all inputs must match, but got: \[float32, float16\].",
+            match=r"Mismatched data types for '__add__'.",
             has_stack_info_for=[a, b],
         ):
             c = a + b
2 changes: 1 addition & 1 deletion tripy/tests/frontend/trace/ops/test_convolution.py
@@ -38,7 +38,7 @@ def test_mismatched_dtypes_fails(self, conv_func):
 
         with helper.raises(
             tp.TripyException,
-            match=r"For operation: 'convolution', data types for all inputs must match, but got: \[float32, float16\].",
+            match=r"Mismatched data types for 'convolution'.",
             has_stack_info_for=[input],
         ):
             output = conv_layer(input)
4 changes: 2 additions & 2 deletions tripy/tests/frontend/trace/ops/test_dequantize.py
@@ -39,15 +39,15 @@ def test_invalid_input_dtype(self):
         a = tp.Tensor([1.0, 2.0])
         with helper.raises(
             tp.TripyException,
-            match="Input does not have a valid dtype in dequantize op",
+            match="Unsupported data type for 'dequantize'.",
         ):
             a = tp.dequantize(a, 0.9, tp.float32)
 
     def test_invalid_dequant_dtype(self):
         a = tp.Tensor([2, 4], dtype=tp.int8)
         with helper.raises(
             tp.TripyException,
-            match="Unsupported dtype in dequantize op.",
+            match="Unsupported data type for 'dequantize'.",
         ):
             a = tp.dequantize(a, 0.9, tp.int32)
 
9 changes: 4 additions & 5 deletions tripy/tests/frontend/trace/ops/test_gather.py
@@ -32,16 +32,15 @@ def test_op_func_all_partial(self, axis):
         assert isinstance(out, tp.Tensor)
         assert isinstance(out.trace_tensor.producer, Gather)
 
-    @pytest.mark.parametrize("axis", [0, 1, 2])
-    def test_incorrect_dtype(self, axis):
+    def test_incorrect_dtype(self):
         a = tp.Tensor([[[1, 2, 3, 4], [1, 2, 3, 4]], [[1, 2, 3, 4], [1, 2, 3, 4]]])
         index = tp.Tensor(np.zeros(1, dtype=np.float32))
         with helper.raises(
             tp.TripyException,
-            match="Index tensor for gather operation should be of int32 type.",
-            has_stack_info_for=[a, index],
+            match="Unsupported data type for 'gather'.",
+            has_stack_info_for=[index],
         ):
-            b = tp.gather(a, axis, index)
+            b = tp.gather(a, 0, index)
 
     @pytest.mark.parametrize("index_shape", [(1,), (2, 2)])
     @pytest.mark.parametrize("axis", [0, 1, 2])
4 changes: 3 additions & 1 deletion tripy/tests/frontend/trace/ops/test_matmul.py
@@ -46,7 +46,9 @@ def test_mismatched_dtypes_fails(self):
         a = tp.ones((2, 3), dtype=tp.float32)
         b = tp.ones((3, 2), dtype=tp.float16)
 
-        with helper.raises(tp.TripyException, match="Incompatible input data types.", has_stack_info_for=[a, b]):
+        with helper.raises(
+            tp.TripyException, match="Mismatched data types for '__matmul__'.", has_stack_info_for=[a, b]
+        ):
             c = a @ b
 
     def test_incompatible_1d_shapes_fails(self):
4 changes: 2 additions & 2 deletions tripy/tests/frontend/trace/ops/test_quantize.py
@@ -37,15 +37,15 @@ def test_invalid_input_dtype(self):
         a = tp.Tensor([1, 2], dtype=tp.int32)
         with helper.raises(
             tp.TripyException,
-            match="Input does not have a valid dtype in quantize op.",
+            match="Unsupported data type for 'quantize'.",
         ):
             a = tp.quantize(a, 0.9, tp.int8)
 
     def test_unsupported_quant_dtype(self):
         a = tp.Tensor([1.0, 2.0])
         with helper.raises(
             tp.TripyException,
-            match="Unsupported dtype in quantize op.",
+            match="Unsupported data type for 'quantize'.",
         ):
             a = tp.quantize(a, 0.9, tp.float16)
 
4 changes: 2 additions & 2 deletions tripy/tests/frontend/trace/ops/test_reduce.py
@@ -35,13 +35,13 @@ def test_max(self):
         assert isinstance(a.trace_tensor.producer, Reduce)
 
     def test_all(self):
-        a = tp.ones((2, 3))
+        a = tp.ones((2, 3), dtype=tp.bool)
         a = tp.all(a)
         assert isinstance(a, tp.Tensor)
         assert isinstance(a.trace_tensor.producer, Reduce)
 
     def test_any(self):
-        a = tp.ones((2, 3))
+        a = tp.ones((2, 3), dtype=tp.bool)
         a = tp.any(a)
         assert isinstance(a, tp.Tensor)
         assert isinstance(a.trace_tensor.producer, Reduce)
8 changes: 3 additions & 5 deletions tripy/tests/frontend/trace/ops/test_where.py
@@ -56,17 +56,15 @@ def test_mismatched_input_dtypes(self):
         a = tp.ones((2,), dtype=tp.float32)
         b = tp.ones((2,), dtype=tp.float16)
 
-        with helper.raises(tp.TripyException, match="Incompatible input data types.", has_stack_info_for=[a, b, cond]):
+        with helper.raises(tp.TripyException, match="Mismatched data types for 'where'.", has_stack_info_for=[a, b]):
             c = tp.where(cond, a, b)
 
     def test_condition_is_not_bool(self):
         cond = tp.ones((2,), dtype=tp.float32)
         a = tp.ones((2,), dtype=tp.float32)
         b = tp.ones((2,), dtype=tp.float32)
 
-        with helper.raises(
-            tp.TripyException, match="Condition input must have boolean type.", has_stack_info_for=[a, b, cond]
-        ):
+        with helper.raises(tp.TripyException, match="Unsupported data type for 'where'.", has_stack_info_for=[cond]):
             c = tp.where(cond, a, b)
 
     def test_infer_rank(self):
@@ -91,7 +89,7 @@ def test_condition_is_not_bool(self):
         mask = tp.Tensor([1.0, 2.0, 3.0, 4.0])
 
         with helper.raises(
-            tp.TripyException, match="Condition input must have boolean type.", has_stack_info_for=[a, mask]
+            tp.TripyException, match="Unsupported data type for 'masked_fill'.", has_stack_info_for=[mask]
         ):
             b = tp.masked_fill(a, mask, -1)
13 changes: 13 additions & 0 deletions tripy/tests/helper.py
@@ -72,6 +72,19 @@ def raises(ExcType: type, match: Optional[str] = None, has_stack_info_for: Seque
         assert expected_stack_info in error_msg, f"Missing stack information for tensor:\n{expected_stack_info}"
 
 
+@contextlib.contextmanager
+def config(name: str, value: Any):
+    """
+    Temporarily changes a configuration option.
+    """
+    old_value = getattr(tp.config, name)
+    try:
+        setattr(tp.config, name, value)
+        yield
+    finally:
+        setattr(tp.config, name, old_value)
+
+
 def check_mlir(mlir, expected):
     # Checks a given MLIR module against a string of the expected program.
     # MLIR indents with 2 spaces; we'll replace it with 4 spaces so that it's
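
The new `config` helper above makes the data type checking switch easy to flip for the duration of a test; a usage sketch, again assuming the flag is named `enable_dtype_checking`:

# Temporarily override a tp.config attribute; the old value is restored
# on exit even if the body raises.
with helper.config("enable_dtype_checking", False):
    ...  # negative tests for dtype constraints run here with checks disabled
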
28 changes: 22 additions & 6 deletions tripy/tests/spec_verification/object_builders.py
@@ -24,11 +24,14 @@
 
 def tensor_builder(init, dtype, namespace):
     if init is None:
-        return tp.ones(dtype=namespace[dtype], shape=(3, 2))
+        out = tp.ones(dtype=namespace[dtype], shape=(3, 2))
+        out.eval()
+        return out
     elif not isinstance(init, tp.Tensor):
         assert dtype == None
         return init
-    return tp.cast(init, dtype=namespace[dtype])
+    out = tp.cast(init, dtype=namespace[dtype])
+    out.eval()
+    return out
 
 
 def dtype_builder(init, dtype, namespace):
@@ -37,9 +40,12 @@ def dtype_builder(init, dtype, namespace):
 
 def tensor_list_builder(init, dtype, namespace):
     if init is None:
-        return [tp.ones(shape=(3, 2), dtype=namespace[dtype]) for _ in range(2)]
+        out = [tp.ones(shape=(3, 2), dtype=namespace[dtype]) for _ in range(2)]
     else:
-        return [tp.cast(tens, dtype=namespace[dtype]) for tens in init]
+        out = [tp.cast(tens, dtype=namespace[dtype]) for tens in init]
+    for t in out:
+        t.eval()
+    return out
 
 
 def device_builder(init, dtype, namespace):
@@ -87,11 +93,12 @@ def default_builder(init, dtype, namespace):
     "full_like": {"value": 1},
     "flip": {"dim": 1},
     "gather": {"dim": 0, "index": tp.Tensor([1])},
-    "iota": {"shape": tp.Tensor([3])},
+    "iota": {"shape": tp.Tensor([4])},
     "__matmul__": {"self": tp.ones((2, 3))},
     "transpose": {"dim0": 0, "dim1": 1},
     "permute": {"perm": [1, 0]},
     "quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0},
+    "dequantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0},
     "sum": {"dim": 0},
     "all": {"dim": 0},
     "any": {"dim": 0},
@@ -111,6 +118,15 @@
     "zeros": {"shape": tp.Tensor([3, 2])},
     "arange": {"start": 0, "stop": 5},
     "repeat": {"repeats": 2, "dim": 0},
+    "convolution": {
+        "input": tp.ones((1, 3, 5, 5)),
+        "weight": tp.ones((1, 3, 3, 3)),
+        "padding": ((0, 0), (0, 0)),
+        "stride": [1, 1],
+        "groups": 1,
+        "lhs_dilation": [1, 1],
+        "rhs_dilation": [1, 1],
+    },
 }
 