From 60c789de0786380a2ac1884d5c67c19cdc41a517 Mon Sep 17 00:00:00 2001 From: SMIT MALKAN Date: Tue, 8 Aug 2023 02:49:26 +0530 Subject: [PATCH] `ReplaceLambdaWithMethodReference` should not replace ambiguous references (#148) * added ambiguous method reference validation check (#96) * Restore some original tests & functionality --------- Co-authored-by: Tim te Beek --- .../ReplaceLambdaWithMethodReference.java | 297 +++++++++--------- .../ReplaceLambdaWithMethodReferenceTest.java | 28 +- 2 files changed, 172 insertions(+), 153 deletions(-) diff --git a/src/main/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReference.java b/src/main/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReference.java index 7ce596b65..9637cf582 100644 --- a/src/main/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReference.java +++ b/src/main/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReference.java @@ -73,184 +73,191 @@ private static class ReplaceLambdaWithMethodReferenceKotlinVisitor extends Kotli private static class ReplaceLambdaWithMethodReferenceJavaVisitor extends JavaVisitor { @Override - public J visitLambda(J.Lambda lambda, ExecutionContext executionContext) { - J.Lambda l = (J.Lambda) super.visitLambda(lambda, executionContext); - updateCursor(l); + public J visitLambda(J.Lambda lambda, ExecutionContext executionContext) { + J.Lambda l = (J.Lambda) super.visitLambda(lambda, executionContext); + updateCursor(l); - String code = ""; - J body = l.getBody(); - if (body instanceof J.Block && ((J.Block) body).getStatements().size() == 1) { - Statement statement = ((J.Block) body).getStatements().get(0); - if (statement instanceof J.MethodInvocation) { - body = statement; - } else if (statement instanceof J.Return && - (((J.Return) statement).getExpression()) instanceof MethodCall) { - body = ((J.Return) statement).getExpression(); + String code = ""; + J body = l.getBody(); + if (body instanceof J.Block && ((J.Block) body).getStatements().size() == 1) { + Statement statement = ((J.Block) body).getStatements().get(0); + if (statement instanceof J.MethodInvocation) { + body = statement; + } else if (statement instanceof J.Return && + (((J.Return) statement).getExpression()) instanceof MethodCall) { + body = ((J.Return) statement).getExpression(); + } + } else if (body instanceof J.InstanceOf) { + J.InstanceOf instanceOf = (J.InstanceOf) body; + J j = instanceOf.getClazz(); + if ((j instanceof J.Identifier || j instanceof J.FieldAccess) && + instanceOf.getExpression() instanceof J.Identifier) { + J.FieldAccess classLiteral = newClassLiteral(((TypeTree) j).getType(), j instanceof J.FieldAccess); + if (classLiteral != null) { + //noinspection DataFlowIssue + JavaType.FullyQualified rawClassType = ((JavaType.Parameterized) classLiteral.getType()).getType(); + Optional isInstanceMethod = rawClassType.getMethods().stream().filter(m -> m.getName().equals("isInstance")).findFirst(); + if (isInstanceMethod.isPresent()) { + return newInstanceMethodReference(isInstanceMethod.get(), classLiteral, lambda.getType()).withPrefix(lambda.getPrefix()); + } } - } else if (body instanceof J.InstanceOf) { - J.InstanceOf instanceOf = (J.InstanceOf) body; - J j = instanceOf.getClazz(); - if ((j instanceof J.Identifier || j instanceof J.FieldAccess) && - instanceOf.getExpression() instanceof J.Identifier) { - J.FieldAccess classLiteral = newClassLiteral(((TypeTree) j).getType(), j instanceof J.FieldAccess); + } + } else if (body instanceof J.TypeCast) { + if (!(((J.TypeCast) body).getExpression() instanceof J.MethodInvocation)) { + J.ControlParentheses j = ((J.TypeCast) body).getClazz(); + J tree = j.getTree(); + if ((tree instanceof J.Identifier || tree instanceof J.FieldAccess) && + !(j.getType() instanceof JavaType.GenericTypeVariable)) { + J.FieldAccess classLiteral = newClassLiteral(((Expression) tree).getType(), tree instanceof J.FieldAccess); if (classLiteral != null) { //noinspection DataFlowIssue - JavaType.FullyQualified rawClassType = ((JavaType.Parameterized) classLiteral.getType()).getType(); - Optional isInstanceMethod = rawClassType.getMethods().stream().filter(m -> m.getName().equals("isInstance")).findFirst(); - if (isInstanceMethod.isPresent()) { - return newInstanceMethodReference(isInstanceMethod.get(), classLiteral, lambda.getType()).withPrefix(lambda.getPrefix()); + JavaType.FullyQualified classType = ((JavaType.Parameterized) classLiteral.getType()).getType(); + Optional castMethod = classType.getMethods().stream().filter(m -> m.getName().equals("cast")).findFirst(); + if (castMethod.isPresent()) { + return newInstanceMethodReference(castMethod.get(), classLiteral, lambda.getType()).withPrefix(lambda.getPrefix()); } } } - } else if (body instanceof J.TypeCast) { - if (!(((J.TypeCast) body).getExpression() instanceof J.MethodInvocation)) { - J.ControlParentheses j = ((J.TypeCast) body).getClazz(); - J tree = j.getTree(); - if ((tree instanceof J.Identifier || tree instanceof J.FieldAccess) && - !(j.getType() instanceof JavaType.GenericTypeVariable)) { - J.FieldAccess classLiteral = newClassLiteral(((Expression) tree).getType(), tree instanceof J.FieldAccess); - if (classLiteral != null) { - //noinspection DataFlowIssue - JavaType.FullyQualified classType = ((JavaType.Parameterized) classLiteral.getType()).getType(); - Optional castMethod = classType.getMethods().stream().filter(m -> m.getName().equals("cast")).findFirst(); - if (castMethod.isPresent()) { - return newInstanceMethodReference(castMethod.get(), classLiteral, lambda.getType()).withPrefix(lambda.getPrefix()); - } + } + } + + if (body instanceof J.Binary) { + J.Binary binary = (J.Binary) body; + if (isNullCheck(binary.getLeft(), binary.getRight()) || + isNullCheck(binary.getRight(), binary.getLeft())) { + doAfterVisit(new ShortenFullyQualifiedTypeReferences().getVisitor()); + code = J.Binary.Type.Equal.equals(binary.getOperator()) ? "java.util.Objects::isNull" : + "java.util.Objects::nonNull"; + return JavaTemplate.builder(code) + .contextSensitive() + .build() + .apply(getCursor(), l.getCoordinates().replace()); + } + } else if (body instanceof MethodCall) { + MethodCall method = (MethodCall) body; + if (method instanceof J.NewClass) { + J.NewClass nc = (J.NewClass) method; + if (nc.getBody() != null) { + return l; + } else { + if (isAMethodInvocationArgument(l, getCursor()) && nc.getType() instanceof JavaType.Class) { + JavaType.Class clazz = (JavaType.Class) nc.getType(); + boolean hasMultipleConstructors = clazz.getMethods().stream().filter(JavaType.Method::isConstructor).count() > 1; + if (hasMultipleConstructors) { + return l; } } } } - if (body instanceof J.Binary) { - J.Binary binary = (J.Binary) body; - if (isNullCheck(binary.getLeft(), binary.getRight()) || - isNullCheck(binary.getRight(), binary.getLeft())) { + if (multipleMethodInvocations(method) || + !methodArgumentsMatchLambdaParameters(method, lambda) || + method instanceof J.MemberReference) { + return l; + } + + Expression select = + method instanceof J.MethodInvocation ? ((J.MethodInvocation) method).getSelect() : null; + JavaType.Method methodType = method.getMethodType(); + if (methodType != null && !isMethodReferenceAmbiguous(methodType)) { + if (methodType.hasFlags(Flag.Static) || + methodSelectMatchesFirstLambdaParameter(method, lambda)) { doAfterVisit(new ShortenFullyQualifiedTypeReferences().getVisitor()); - code = J.Binary.Type.Equal.equals(binary.getOperator()) ? "java.util.Objects::isNull" : - "java.util.Objects::nonNull"; - return JavaTemplate.builder(code) + return newStaticMethodReference(methodType, true, lambda.getType()).withPrefix(lambda.getPrefix()); + } else if (method instanceof J.NewClass) { + return JavaTemplate.builder("#{}::new") .contextSensitive() .build() - .apply(getCursor(), l.getCoordinates().replace()); - } - } else if (body instanceof MethodCall) { - MethodCall method = (MethodCall) body; - if (method instanceof J.NewClass) { - J.NewClass nc = (J.NewClass) method; - if (nc.getBody() != null) { - return l; - } else { - if (isAMethodInvocationArgument(l, getCursor()) && nc.getType() instanceof JavaType.Class) { - JavaType.Class clazz = (JavaType.Class) nc.getType(); - boolean hasMultipleConstructors = clazz.getMethods().stream().filter(JavaType.Method::isConstructor).count() > 1; - if (hasMultipleConstructors) { - return l; - } - } - } + .apply(getCursor(), l.getCoordinates().replace(), className((J.NewClass) method)); + } else if (select != null) { + return newInstanceMethodReference(methodType, select, lambda.getType()).withPrefix(lambda.getPrefix()); + } else { + String templ = "#{}::#{}"; + return JavaTemplate.builder(templ) + .contextSensitive() + .build() + .apply(getCursor(), l.getCoordinates().replace(), "this", + method.getMethodType().getName()); } + } + } - if (multipleMethodInvocations(method) || - !methodArgumentsMatchLambdaParameters(method, lambda) || - method instanceof J.MemberReference) { - return l; - } + return l; + } - Expression select = - method instanceof J.MethodInvocation ? ((J.MethodInvocation) method).getSelect() : null; - JavaType.Method methodType = method.getMethodType(); - if (methodType != null) { - if (methodType.hasFlags(Flag.Static) || - methodSelectMatchesFirstLambdaParameter(method, lambda)) { - doAfterVisit(new ShortenFullyQualifiedTypeReferences().getVisitor()); + // returns the class name as given in the source code (qualified or unqualified) + private String className(J.NewClass method) { + TypeTree clazz = method.getClazz(); + return clazz instanceof J.ParameterizedType ? ((J.ParameterizedType) clazz).getClazz().toString() : + Objects.toString(clazz); + } - return newStaticMethodReference(methodType, true, lambda.getType()).withPrefix(lambda.getPrefix()); - } else if (method instanceof J.NewClass) { - return JavaTemplate.builder("#{}::new") - .contextSensitive() - .build() - .apply(getCursor(), l.getCoordinates().replace(), className((J.NewClass) method)); - } else if (select != null) { - return newInstanceMethodReference(methodType, select, lambda.getType()).withPrefix(lambda.getPrefix()); - } else { - String templ = "#{}::#{}"; - return JavaTemplate.builder(templ) - .contextSensitive() - .build() - .apply(getCursor(), l.getCoordinates().replace(), "this", - method.getMethodType().getName()); - } - } - } + private boolean multipleMethodInvocations(MethodCall method) { + return method instanceof J.MethodInvocation && + ((J.MethodInvocation) method).getSelect() instanceof J.MethodInvocation; + } - return l; + private boolean methodArgumentsMatchLambdaParameters(MethodCall method, J.Lambda lambda) { + JavaType.Method methodType = method.getMethodType(); + if (methodType == null) { + return false; } - - // returns the class name as given in the source code (qualified or unqualified) - private String className(J.NewClass method) { - TypeTree clazz = method.getClazz(); - return clazz instanceof J.ParameterizedType ? ((J.ParameterizedType) clazz).getClazz().toString() : - Objects.toString(clazz); + boolean static_ = methodType.hasFlags(Flag.Static); + List methodArgs = method.getArguments().stream().filter(a -> !(a instanceof J.Empty)) + .collect(Collectors.toList()); + List lambdaParameters = lambda.getParameters().getParameters() + .stream().filter(J.VariableDeclarations.class::isInstance) + .map(J.VariableDeclarations.class::cast).map(v -> v.getVariables().get(0)) + .collect(Collectors.toList()); + if (methodArgs.isEmpty() && lambdaParameters.isEmpty()) { + return true; } - - private boolean multipleMethodInvocations(MethodCall method) { - return method instanceof J.MethodInvocation && - ((J.MethodInvocation) method).getSelect() instanceof J.MethodInvocation; + if (!static_ && methodSelectMatchesFirstLambdaParameter(method, lambda)) { + methodArgs.add(0, ((J.MethodInvocation) method).getSelect()); } - - private boolean methodArgumentsMatchLambdaParameters(MethodCall method, J.Lambda lambda) { - JavaType.Method methodType = method.getMethodType(); - if (methodType == null) { + if (methodArgs.size() != lambdaParameters.size()) { + return false; + } + for (int i = 0; i < lambdaParameters.size(); i++) { + JavaType lambdaParam = lambdaParameters.get(i).getVariableType(); + if (!(methodArgs.get(i) instanceof J.Identifier)) { return false; } - boolean static_ = methodType.hasFlags(Flag.Static); - List methodArgs = method.getArguments().stream().filter(a -> !(a instanceof J.Empty)) - .collect(Collectors.toList()); - List lambdaParameters = lambda.getParameters().getParameters() - .stream().filter(J.VariableDeclarations.class::isInstance) - .map(J.VariableDeclarations.class::cast).map(v -> v.getVariables().get(0)) - .collect(Collectors.toList()); - if (methodArgs.isEmpty() && lambdaParameters.isEmpty()) { - return true; - } - if (!static_ && methodSelectMatchesFirstLambdaParameter(method, lambda)) { - methodArgs.add(0, ((J.MethodInvocation) method).getSelect()); - } - if (methodArgs.size() != lambdaParameters.size()) { + JavaType methodArgument = ((J.Identifier) methodArgs.get(i)).getFieldType(); + if (lambdaParam != methodArgument) { return false; } - for (int i = 0; i < lambdaParameters.size(); i++) { - JavaType lambdaParam = lambdaParameters.get(i).getVariableType(); - if (!(methodArgs.get(i) instanceof J.Identifier)) { - return false; - } - JavaType methodArgument = ((J.Identifier) methodArgs.get(i)).getFieldType(); - if (lambdaParam != methodArgument) { - return false; - } - } - return true; } + return true; + } - private boolean methodSelectMatchesFirstLambdaParameter(MethodCall method, J.Lambda lambda) { - if (!(method instanceof J.MethodInvocation) || - !(((J.MethodInvocation) method).getSelect() instanceof J.Identifier) || - lambda.getParameters().getParameters().isEmpty() || - !(lambda.getParameters().getParameters().get(0) instanceof J.VariableDeclarations)) { - return false; - } - J.VariableDeclarations firstLambdaParameter = (J.VariableDeclarations) lambda.getParameters() - .getParameters().get(0); - return ((J.Identifier) ((J.MethodInvocation) method).getSelect()).getFieldType() == - firstLambdaParameter.getVariables().get(0).getVariableType(); + private boolean methodSelectMatchesFirstLambdaParameter(MethodCall method, J.Lambda lambda) { + if (!(method instanceof J.MethodInvocation) || + !(((J.MethodInvocation) method).getSelect() instanceof J.Identifier) || + lambda.getParameters().getParameters().isEmpty() || + !(lambda.getParameters().getParameters().get(0) instanceof J.VariableDeclarations)) { + return false; } + J.VariableDeclarations firstLambdaParameter = (J.VariableDeclarations) lambda.getParameters() + .getParameters().get(0); + return ((J.Identifier) ((J.MethodInvocation) method).getSelect()).getFieldType() == + firstLambdaParameter.getVariables().get(0).getVariableType(); + } - private boolean isNullCheck(J j1, J j2) { - return j1 instanceof J.Identifier && j2 instanceof J.Literal && - "null".equals(((J.Literal) j2).getValueSource()); - } + private boolean isNullCheck(J j1, J j2) { + return j1 instanceof J.Identifier && j2 instanceof J.Literal && + "null".equals(((J.Literal) j2).getValueSource()); } + private boolean isMethodReferenceAmbiguous(JavaType.Method _method) { + return _method.getDeclaringType().getMethods().stream() + .filter(meth -> meth.getName().equals(_method.getName())) + .filter(meth -> !meth.getName().equals("println")) + .filter(meth -> !meth.isConstructor()) + .count() > 1; + } + } + private static boolean isAMethodInvocationArgument(J.Lambda lambda, Cursor cursor) { Cursor parent = cursor.dropParentUntil(p -> p instanceof J.MethodInvocation || p instanceof J.CompilationUnit); if (parent.getValue() instanceof J.MethodInvocation) { diff --git a/src/test/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReferenceTest.java b/src/test/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReferenceTest.java index 5cfd9e18d..5dbd5ec89 100644 --- a/src/test/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReferenceTest.java +++ b/src/test/java/org/openrewrite/staticanalysis/ReplaceLambdaWithMethodReferenceTest.java @@ -16,7 +16,6 @@ package org.openrewrite.staticanalysis; - import org.junit.jupiter.api.Test; import org.openrewrite.DocumentExample; import org.openrewrite.Issue; @@ -78,6 +77,25 @@ List method(List l) { ); } + @Issue("https://github.com/openrewrite/rewrite-static-analysis/issues/96") + @Test + void ignoreAmbiguousMethodReference() { + rewriteRun( + //language=java + java( + """ + import java.util.stream.Stream; + + class Test { + Stream method() { + return Stream.of(1, 32, 12, 15, 23).map(x -> Integer.toString(x)); + } + } + """ + ) + ); + } + @Test void containsMultipleStatements() { rewriteRun( @@ -426,7 +444,6 @@ void systemOutPrint() { java( """ import java.util.List; - class Test { void method(List input) { input.forEach(x -> System.out.println(x)); @@ -435,7 +452,6 @@ void method(List input) { """, """ import java.util.List; - class Test { void method(List input) { input.forEach(System.out::println); @@ -453,7 +469,6 @@ void systemOutPrintInBlock() { java( """ import java.util.List; - class Test { void method(List input) { input.forEach(x -> { System.out.println(x); }); @@ -462,7 +477,6 @@ void method(List input) { """, """ import java.util.List; - class Test { void method(List input) { input.forEach(System.out::println); @@ -484,8 +498,8 @@ public class CheckType { } """ ), + //language=java java( - //language=java """ import java.util.List; import java.util.stream.Collectors; @@ -917,7 +931,6 @@ void foo() { s = () -> new java.util.ArrayList(); s = () -> new ArrayList(); s = () -> new java.util.HashSet(); - Function f; f = i -> new ArrayList(i); } @@ -938,7 +951,6 @@ void foo() { s = java.util.ArrayList::new; s = ArrayList::new; s = java.util.HashSet::new; - Function f; f = ArrayList::new; }