From d5c478c6c738967f5d97b7780bd18720802c2fa5 Mon Sep 17 00:00:00 2001 From: Raghavendra Date: Tue, 29 Aug 2017 22:30:26 +0530 Subject: [PATCH 1/3] Fix for Support tuple args in lambda #17 --- compiler/block.py | 40 +++++++++++++++++++++++++++++++++++----- compiler/block_test.py | 10 ++++++++++ compiler/stmt.py | 15 ++++++++++++++- compiler/stmt_test.py | 24 ++++++++++++++++++++++++ 4 files changed, 83 insertions(+), 6 deletions(-) diff --git a/compiler/block.py b/compiler/block.py index 423ded2d..7a5461ab 100644 --- a/compiler/block.py +++ b/compiler/block.py @@ -249,6 +249,8 @@ def resolve_name(self, writer, name): if var: if var.type == Var.TYPE_GLOBAL: return self._resolve_global(writer, name) + if var.type == Var.TYPE_TUPLE_PARAM: + return expr.GeneratedLocalVar(name) writer.write_checked_call1('πg.CheckLocal(πF, {}, {})', util.adjust_local_name(name), util.go_str(name)) @@ -263,6 +265,7 @@ class Var(object): TYPE_LOCAL = 0 TYPE_PARAM = 1 TYPE_GLOBAL = 2 + TYPE_TUPLE_PARAM = 3 def __init__(self, name, var_type, arg_index=None): self.name = name @@ -273,6 +276,9 @@ def __init__(self, name, var_type, arg_index=None): elif var_type == Var.TYPE_PARAM: assert arg_index is not None self.init_expr = 'πArgs[{}]'.format(arg_index) + elif var_type == Var.TYPE_TUPLE_PARAM: + assert arg_index is None + self.init_expr = 'nil' else: assert arg_index is None self.init_expr = None @@ -364,16 +370,40 @@ def __init__(self, node): BlockVisitor.__init__(self) self.is_generator = False node_args = node.args - args = [a.arg for a in node_args.args] + args = [] + for arg in node_args.args: + if isinstance(arg, ast.Tuple): + args.append(arg.elts) + else: + args.append(arg.arg) + # args = [a.arg for a in node_args.args] + if node_args.vararg: args.append(node_args.vararg.arg) if node_args.kwarg: args.append(node_args.kwarg.arg) for i, name in enumerate(args): - if name in self.vars: - msg = "duplicate argument '{}' in function definition".format(name) - raise util.ParseError(node, msg) - self.vars[name] = Var(name, Var.TYPE_PARAM, arg_index=i) + if isinstance(name, list): + arg_name = 'τ{}'.format(id(name)) + for el in name: + self._parse_tuple(el, node) + self.vars[arg_name] = Var(arg_name, Var.TYPE_PARAM, i) + else: + self._check_duplicate_args(name, node) + self.vars[name] = Var(name, Var.TYPE_PARAM, arg_index=i) + + def _parse_tuple(self, el, node): + if isinstance(el, ast.Tuple): + for x in el.elts: + self._parse_tuple(x, node) + else: + self._check_duplicate_args(el.arg, node) + self.vars[el.arg] = Var(el.arg, Var.TYPE_TUPLE_PARAM) + + def _check_duplicate_args(self, name, node): + if name in self.vars: + msg = "duplicate argument '{}' in function definition".format(name) + raise util.ParseError(node, msg) def visit_Yield(self, unused_node): # pylint: disable=unused-argument self.is_generator = True diff --git a/compiler/block_test.py b/compiler/block_test.py index 63376997..8c455f73 100644 --- a/compiler/block_test.py +++ b/compiler/block_test.py @@ -26,6 +26,9 @@ from grumpy.compiler import util from grumpy import pythonparser +from compiler.block import Var + + class PackageTest(unittest.TestCase): def testCreate(self): @@ -224,6 +227,13 @@ def testYieldExpr(self): self.assertEqual(sorted(visitor.vars.keys()), ['foo']) self.assertRegexpMatches(visitor.vars['foo'].init_expr, r'UnboundLocal') + def testTupleArgs(self): + func = _ParseStmt('def foo((bar, baz)): pass') + visitor = block.FunctionBlockVisitor(func) + self.assertEqual(len(visitor.vars), 3) + self.assertTrue(len([v for v in visitor.vars if visitor.vars[v].type == Var.TYPE_TUPLE_PARAM]), 1) + self.assertIn('bar', visitor.vars) + self.assertIn('baz', visitor.vars) def _MakeModuleBlock(): importer = imputil.Importer(None, '__main__', '/tmp/foo.py', False) diff --git a/compiler/stmt.py b/compiler/stmt.py index 8daac1db..bf472c0b 100644 --- a/compiler/stmt.py +++ b/compiler/stmt.py @@ -504,6 +504,13 @@ def visit_function_inline(self, node): func_block = block.FunctionBlock(self.block, node.name, func_visitor.vars, func_visitor.is_generator) visitor = StatementVisitor(func_block, self.future_node) + + for arg in node.args.args: + if isinstance(arg, ast.Tuple): + arg_name = 'τ{}'.format(id(arg.elts)) + with visitor.writer.indent_block(): + visitor._tie_target(arg, util.adjust_local_name(arg_name)) + # Indent so that the function body is aligned with the goto labels. with visitor.writer.indent_block(): visitor._visit_each(node.body) # pylint: disable=protected-access @@ -519,9 +526,13 @@ def visit_function_inline(self, node): defaults = [None] * (argc - len(args.defaults)) + args.defaults for i, (a, d) in enumerate(zip(args.args, defaults)): with self.visit_expr(d) if d else expr.nil_expr as default: + if isinstance(a, ast.Tuple): + name = util.go_str('τ{}'.format(id(a.elts))) + else: + name = util.go_str(a.arg) tmpl = '$args[$i] = πg.Param{Name: $name, Def: $default}' self.writer.write_tmpl(tmpl, args=func_args.expr, i=i, - name=util.go_str(a.arg), default=default.expr) + name=name, default=default.expr) flags = [] if args.vararg: flags.append('πg.CodeFlagVarArg') @@ -583,6 +594,8 @@ def visit_function_inline(self, node): def _assign_target(self, target, value): if isinstance(target, ast.Name): self.block.bind_var(self.writer, target.id, value) + elif isinstance(target, ast.arg): + self.block.bind_var(self.writer, target.arg, value) elif isinstance(target, ast.Attribute): with self.visit_expr(target.value) as obj: self.writer.write_checked_call1( diff --git a/compiler/stmt_test.py b/compiler/stmt_test.py index 5f4cb0f5..91ff3cba 100644 --- a/compiler/stmt_test.py +++ b/compiler/stmt_test.py @@ -243,6 +243,30 @@ def foo(a, b): print a, b foo('bar', 'baz')"""))) + def testFunctionDefWithTupleArgs(self): + self.assertEqual((0, "('bar', 'baz')\n"), _GrumpRun(textwrap.dedent("""\ + def foo((a, b)): + print(a, b) + foo(('bar', 'baz'))"""))) + + def testFunctionDefWithNestedTupleArgs(self): + self.assertEqual((0, "('bar', 'baz', 'qux')\n"), _GrumpRun(textwrap.dedent("""\ + def foo(((a, b), c)): + print(a, b, c) + foo((('bar', 'baz'), 'qux'))"""))) + + def testFunctionDefWithMultipleTupleArgs(self): + self.assertEqual((0, "('bar', 'baz')\n"), _GrumpRun(textwrap.dedent("""\ + def foo(((a, ), (b, ))): + print(a, b) + foo((('bar',), ('baz', )))"""))) + + def testFunctionDefTupleArgsInLambda(self): + self.assertEqual((0, "[(3, 2), (4, 3), (12, 1)]\n"), _GrumpRun(textwrap.dedent("""\ + c = {12: 1, 3: 2, 4: 3} + top = sorted(c.items(), key=lambda (k,v): v) + print (top)"""))) + def testFunctionDefGenerator(self): self.assertEqual((0, "['foo', 'bar']\n"), _GrumpRun(textwrap.dedent("""\ def gen(): From 78b8542a950313f192393de6c2a0d03f6ea6e2b0 Mon Sep 17 00:00:00 2001 From: Raghavendra Date: Tue, 29 Aug 2017 22:50:40 +0530 Subject: [PATCH 2/3] Ignoring protected access in pylint --- compiler/stmt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/stmt.py b/compiler/stmt.py index bf472c0b..5e639eb3 100644 --- a/compiler/stmt.py +++ b/compiler/stmt.py @@ -509,7 +509,7 @@ def visit_function_inline(self, node): if isinstance(arg, ast.Tuple): arg_name = 'τ{}'.format(id(arg.elts)) with visitor.writer.indent_block(): - visitor._tie_target(arg, util.adjust_local_name(arg_name)) + visitor._tie_target(arg, util.adjust_local_name(arg_name)) # pylint: disable=protected-access # Indent so that the function body is aligned with the goto labels. with visitor.writer.indent_block(): From fb5fae4a1db1b1246a157895c75c29ff5dd4bc7a Mon Sep 17 00:00:00 2001 From: Raghavendra Date: Tue, 29 Aug 2017 23:00:10 +0530 Subject: [PATCH 3/3] Fixing unit test failure; removing import --- compiler/block_test.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/compiler/block_test.py b/compiler/block_test.py index 8c455f73..5a79d205 100644 --- a/compiler/block_test.py +++ b/compiler/block_test.py @@ -26,8 +26,6 @@ from grumpy.compiler import util from grumpy import pythonparser -from compiler.block import Var - class PackageTest(unittest.TestCase): @@ -231,7 +229,7 @@ def testTupleArgs(self): func = _ParseStmt('def foo((bar, baz)): pass') visitor = block.FunctionBlockVisitor(func) self.assertEqual(len(visitor.vars), 3) - self.assertTrue(len([v for v in visitor.vars if visitor.vars[v].type == Var.TYPE_TUPLE_PARAM]), 1) + self.assertEqual(len([v for v in visitor.vars if visitor.vars[v].type == block.Var.TYPE_TUPLE_PARAM]), 2) self.assertIn('bar', visitor.vars) self.assertIn('baz', visitor.vars)