Skip to content

Commit

Permalink
feat(snowflake): Fix exp.Pivot FOR IN clause (#4109)
Browse files Browse the repository at this point in the history
  • Loading branch information
VaggelisD committed Sep 11, 2024
1 parent b10255e commit 3cb0041
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 8 deletions.
2 changes: 0 additions & 2 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1917,8 +1917,6 @@ def pivot_sql(self, expression: exp.Pivot) -> str:
direction = self.seg("UNPIVOT" if expression.unpivot else "PIVOT")

field = self.sql(expression, "field")
if field and isinstance(expression.args.get("field"), exp.PivotAny):
field = f"IN ({field})"

include_nulls = expression.args.get("include_nulls")
if include_nulls is not None:
Expand Down
7 changes: 3 additions & 4 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3907,13 +3907,12 @@ def _parse_aliased_expression() -> t.Optional[exp.Expression]:
self.raise_error("Expecting IN (")

if self._match(TokenType.ANY):
expr: exp.PivotAny | exp.In = self.expression(exp.PivotAny, this=self._parse_order())
exprs: t.List[exp.Expression] = ensure_list(exp.PivotAny(this=self._parse_order()))
else:
aliased_expressions = self._parse_csv(_parse_aliased_expression)
expr = self.expression(exp.In, this=value, expressions=aliased_expressions)
exprs = self._parse_csv(_parse_aliased_expression)

self._match_r_paren()
return expr
return self.expression(exp.In, this=value, expressions=exprs)

def _parse_pivot(self) -> t.Optional[exp.Pivot]:
index = self._index
Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def test_snowflake(self):
"SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR quarter IN (SELECT DISTINCT quarter FROM ad_campaign_types_by_quarter WHERE television = TRUE ORDER BY quarter)) ORDER BY empid"
)
self.validate_identity(
"SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR IN (ANY ORDER BY quarter)) ORDER BY empid"
"SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR quarter IN (ANY ORDER BY quarter)) ORDER BY empid"
)
self.validate_identity(
"SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR IN (ANY)) ORDER BY empid"
"SELECT * FROM quarterly_sales PIVOT(SUM(amount) FOR quarter IN (ANY)) ORDER BY empid"
)
self.validate_identity(
"MERGE INTO my_db AS ids USING (SELECT new_id FROM my_model WHERE NOT col IS NULL) AS new_ids ON ids.type = new_ids.type AND ids.source = new_ids.source WHEN NOT MATCHED THEN INSERT VALUES (new_ids.new_id)"
Expand Down

0 comments on commit 3cb0041

Please sign in to comment.