Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Fix SIR cache when call func many times (#21)
Browse files Browse the repository at this point in the history
* Fix SIR cache when call func many times

* Rename ut

* Polish code

* Remove bind

* Rename var and fun
  • Loading branch information
0x45f committed Apr 18, 2023
1 parent 038d464 commit 315b190
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 16 deletions.
44 changes: 33 additions & 11 deletions symbolic_trace/trace_cache_entrance.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,42 @@


def trace_cache(func):
@no_eval_frame
def call_with_cache(*args, **kwargs):
args, kwargs = convert_arguments(args), convert_arguments(kwargs)
args, kwargs = convert_arguments(args), convert_arguments(kwargs)
args, kwargs, outter_names = construct_inner_proxy_tensor(func.__name__, *args, **kwargs)

if frame_enter(func.__name__, args):
return cache_and_return(func.__name__, args)
return cache_and_return(func.__name__, outter_names)
ret = func(*args)
frame_leave(func.__name__, ret)
frame_leave(func.__name__, outter_names, ret)
return ret
return call_with_cache


def construct_inner_proxy_tensor(func_name, *args, **kwargs):
flat_args = paddle.utils.flatten(args)
flat_kwargs = paddle.utils.flatten(kwargs)
outter_names = []
name_i = 0
for i, v in enumerate(flat_args):
if isinstance(v, ProxyTensor):
name = '{}_input_{}'.format(func_name, name_i)
outter_names.append(v.name)
flat_args[i] = ProxyTensor(name, v.meta)
name_i = name_i + 1
for i, v in enumerate(flat_kwargs):
if isinstance(v, ProxyTensor):
name = '{}_input_{}'.format(func_name, name_i)
outter_names.append(v.name)
flat_kwargs[i] = ProxyTensor(name, v.meta)
name_i = name_i + 1

args = paddle.utils.pack_sequence_as(args, flat_args)
kwargs = paddle.utils.pack_sequence_as(kwargs, flat_kwargs)

return args, kwargs, outter_names

@no_eval_frame
# should generate a unique name for every function
def frame_enter(name, inputs):
Expand Down Expand Up @@ -52,7 +78,7 @@ def frame_enter(name, inputs):


@no_eval_frame
def frame_leave(name, outputs):
def frame_leave(name, outter_names, outputs):
key_name = SymbolicTraceContext().sir_key_stack[-1]
SymbolicTraceContext().sir_key_stack.pop()

Expand Down Expand Up @@ -88,24 +114,20 @@ def frame_leave(name, outputs):
return

# at the first time, the inputs and outputs need not change
SymbolicTraceContext().call_SIR(cur_sir.name, cur_sir.inputs, cur_sir.outputs)
SymbolicTraceContext().call_SIR(cur_sir.name, [Symbol(name) for name in outter_names], cur_sir.outputs)
log(1, cur_sir, "\n")
return


@no_eval_frame
def cache_and_return(name, inputs):
def cache_and_return(name, outter_names):
key_name = SymbolicTraceContext().sir_key_stack[-1]
SymbolicTraceContext().sir_key_stack.pop()

# find sir and it's origin_outputs
cached_sir = SymbolicTraceContext().statement_factory[key_name]
origin_outputs = SIRRuntimeCache().get_origin_outputs(key_name)

# gen call_SIR inputs
flat_inputs = paddle.utils.flatten(inputs)
symbol_inputs = [Symbol(x.name) for x in flat_inputs if isinstance(x, ProxyTensor)]

# create return value
outputs = gen_new_proxy_tensor_output(origin_outputs)

Expand All @@ -114,7 +136,7 @@ def cache_and_return(name, inputs):
symbol_outputs = [Symbol(x.name) for x in flat_outputs if isinstance(x, ProxyTensor)]

# add call_SIR
SymbolicTraceContext().call_SIR(cached_sir.name, symbol_inputs, symbol_outputs)
SymbolicTraceContext().call_SIR(cached_sir.name, [Symbol(name) for name in outter_names], symbol_outputs)
return outputs


Expand Down
8 changes: 3 additions & 5 deletions tests/error_test_sir_call.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import unittest
import paddle
from symbolic_trace import symbolic_trace
from symbolic_trace.trace_cache_entrance import frame_enter, frame_leave, cache_and_return
from symbolic_trace.trace_cache_entrance import trace_cache


@trace_cache
def sum(x, y):
if frame_enter("sum", (x, y)):
return cache_and_return("sum", (x, y))
ret = x + y
frame_leave("sum", (ret))
return ret

def main(x, y):
Expand All @@ -21,7 +19,7 @@ def test_return_callable(self):
x = paddle.to_tensor([1.0])
y = paddle.to_tensor([2.0])
ret = symbolic_trace(main)(x, y)
assert (ret.item() == 3.0), "Should be 4.0"
assert (ret.item() == 3.0), "Should be 3.0"

if __name__ == "__main__":
unittest.main()
File renamed without changes.

0 comments on commit 315b190

Please sign in to comment.