Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jamdagni86 fix for #17 (Support tuple args in lambda) #104

Merged
merged 4 commits into from
Oct 22, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions grumpy-tools-src/grumpy_tools/compiler/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,8 @@ def resolve_name(self, writer, name):
if var:
if var.type == Var.TYPE_GLOBAL:
return self._resolve_global(writer, name)
if var.type == Var.TYPE_TUPLE_PARAM:
return expr.GeneratedLocalVar(name)
writer.write_checked_call1('πg.CheckLocal(πF, {}, {})',
util.adjust_local_name(name),
util.go_str(name))
Expand All @@ -263,6 +265,7 @@ class Var(object):
TYPE_LOCAL = 0
TYPE_PARAM = 1
TYPE_GLOBAL = 2
TYPE_TUPLE_PARAM = 3

def __init__(self, name, var_type, arg_index=None):
self.name = name
Expand All @@ -273,6 +276,9 @@ def __init__(self, name, var_type, arg_index=None):
elif var_type == Var.TYPE_PARAM:
assert arg_index is not None
self.init_expr = 'πArgs[{}]'.format(arg_index)
elif var_type == Var.TYPE_TUPLE_PARAM:
assert arg_index is None
self.init_expr = 'nil'
else:
assert arg_index is None
self.init_expr = None
Expand Down Expand Up @@ -364,16 +370,40 @@ def __init__(self, node):
BlockVisitor.__init__(self)
self.is_generator = False
node_args = node.args
args = [a.arg for a in node_args.args]
args = []
for arg in node_args.args:
if isinstance(arg, ast.Tuple):
args.append(arg.elts)
else:
args.append(arg.arg)
# args = [a.arg for a in node_args.args]

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

args = [arg.elts if isinstance(arg, ast.Tuple) arg.arg for arg in node_args.args]

Copy link
Author

@alanjds alanjds Oct 8, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

args = [(arg.elts if isinstance(arg, ast.Tuple) else arg.arg) for arg in node_args.args]

I am wondering if this version is more readable than the original 5-line-long one

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just the pythonic way, but you'r right . the original is more readable.

if node_args.vararg:
args.append(node_args.vararg.arg)
if node_args.kwarg:
args.append(node_args.kwarg.arg)
for i, name in enumerate(args):
if name in self.vars:
msg = "duplicate argument '{}' in function definition".format(name)
raise util.ParseError(node, msg)
self.vars[name] = Var(name, Var.TYPE_PARAM, arg_index=i)
if isinstance(name, list):
arg_name = 'τ{}'.format(id(name))
for el in name:
self._parse_tuple(el, node)
self.vars[arg_name] = Var(arg_name, Var.TYPE_PARAM, i)
else:
self._check_duplicate_args(name, node)
self.vars[name] = Var(name, Var.TYPE_PARAM, arg_index=i)

def _parse_tuple(self, el, node):
if isinstance(el, ast.Tuple):
for x in el.elts:
self._parse_tuple(x, node)
else:
self._check_duplicate_args(el.arg, node)
self.vars[el.arg] = Var(el.arg, Var.TYPE_TUPLE_PARAM)

def _check_duplicate_args(self, name, node):
if name in self.vars:
msg = "duplicate argument '{}' in function definition".format(name)
raise util.ParseError(node, msg)

def visit_Yield(self, unused_node): # pylint: disable=unused-argument
self.is_generator = True
8 changes: 8 additions & 0 deletions grumpy-tools-src/grumpy_tools/compiler/block_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from grumpy_tools.compiler import util
from grumpy_tools.vendor import pythonparser


class PackageTest(unittest.TestCase):

def testCreate(self):
Expand Down Expand Up @@ -224,6 +225,13 @@ def testYieldExpr(self):
self.assertEqual(sorted(visitor.vars.keys()), ['foo'])
self.assertRegexpMatches(visitor.vars['foo'].init_expr, r'UnboundLocal')

def testTupleArgs(self):
func = _ParseStmt('def foo((bar, baz)): pass')
visitor = block.FunctionBlockVisitor(func)
self.assertEqual(len(visitor.vars), 3)
self.assertEqual(len([v for v in visitor.vars if visitor.vars[v].type == block.Var.TYPE_TUPLE_PARAM]), 2)
self.assertIn('bar', visitor.vars)
self.assertIn('baz', visitor.vars)

def _MakeModuleBlock():
importer = imputil.Importer(None, '__main__', '/tmp/foo.py', False)
Expand Down
15 changes: 14 additions & 1 deletion grumpy-tools-src/grumpy_tools/compiler/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,13 @@ def visit_function_inline(self, node):
func_block = block.FunctionBlock(self.block, node.name, func_visitor.vars,
func_visitor.is_generator)
visitor = StatementVisitor(func_block, self.future_node)

for arg in node.args.args:
if isinstance(arg, ast.Tuple):
arg_name = 'τ{}'.format(id(arg.elts))
with visitor.writer.indent_block():
visitor._tie_target(arg, util.adjust_local_name(arg_name)) # pylint: disable=protected-access

# Indent so that the function body is aligned with the goto labels.
with visitor.writer.indent_block():
visitor._visit_each(node.body) # pylint: disable=protected-access
Expand All @@ -519,9 +526,13 @@ def visit_function_inline(self, node):
defaults = [None] * (argc - len(args.defaults)) + args.defaults
for i, (a, d) in enumerate(zip(args.args, defaults)):
with self.visit_expr(d) if d else expr.nil_expr as default:
if isinstance(a, ast.Tuple):
name = util.go_str('τ{}'.format(id(a.elts)))
else:
name = util.go_str(a.arg)
tmpl = '$args[$i] = πg.Param{Name: $name, Def: $default}'
self.writer.write_tmpl(tmpl, args=func_args.expr, i=i,
name=util.go_str(a.arg), default=default.expr)
name=name, default=default.expr)
flags = []
if args.vararg:
flags.append('πg.CodeFlagVarArg')
Expand Down Expand Up @@ -583,6 +594,8 @@ def visit_function_inline(self, node):
def _assign_target(self, target, value):
if isinstance(target, ast.Name):
self.block.bind_var(self.writer, target.id, value)
elif isinstance(target, ast.arg):
self.block.bind_var(self.writer, target.arg, value)
elif isinstance(target, ast.Attribute):
with self.visit_expr(target.value) as obj:
self.writer.write_checked_call1(
Expand Down
24 changes: 24 additions & 0 deletions grumpy-tools-src/grumpy_tools/compiler/stmt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,30 @@ def foo(a, b):
print a, b
foo('bar', 'baz')""")))

def testFunctionDefWithTupleArgs(self):
self.assertEqual((0, "('bar', 'baz')\n"), _GrumpRun(textwrap.dedent("""\
def foo((a, b)):
print(a, b)
foo(('bar', 'baz'))""")))

def testFunctionDefWithNestedTupleArgs(self):
self.assertEqual((0, "('bar', 'baz', 'qux')\n"), _GrumpRun(textwrap.dedent("""\
def foo(((a, b), c)):
print(a, b, c)
foo((('bar', 'baz'), 'qux'))""")))

def testFunctionDefWithMultipleTupleArgs(self):
self.assertEqual((0, "('bar', 'baz')\n"), _GrumpRun(textwrap.dedent("""\
def foo(((a, ), (b, ))):
print(a, b)
foo((('bar',), ('baz', )))""")))

def testFunctionDefTupleArgsInLambda(self):
self.assertEqual((0, "[(3, 2), (4, 3), (12, 1)]\n"), _GrumpRun(textwrap.dedent("""\
c = {12: 1, 3: 2, 4: 3}
top = sorted(c.items(), key=lambda (k,v): v)
print (top)""")))

def testFunctionDefGenerator(self):
self.assertEqual((0, "['foo', 'bar']\n"), _GrumpRun(textwrap.dedent("""\
def gen():
Expand Down