Skip to content

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
erindru committed Sep 11, 2024
1 parent 1da1fd1 commit 6038a60
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 46 deletions.
50 changes: 5 additions & 45 deletions sqlglot/dialects/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,7 @@
from sqlglot import exp
from sqlglot.dialects.trino import Trino
from sqlglot.dialects.hive import Hive
from sqlglot.tokens import Token, TokenType


def _parse_as_hive(raw_tokens: t.List[Token]) -> bool:
if len(raw_tokens) > 0:
first_token = raw_tokens[0]
if first_token.token_type == TokenType.CREATE:
# CREATE is Hive (except for CREATE VIEW and CREATE TABLE... AS SELECT)
return not any(t.token_type in (TokenType.VIEW, TokenType.SELECT) for t in raw_tokens)

# ALTER and DROP are Hive
return first_token.token_type in (TokenType.ALTER, TokenType.DROP)
return False
from sqlglot.tokens import TokenType


def _generate_as_hive(expression: exp.Expression) -> bool:
Expand Down Expand Up @@ -81,13 +69,6 @@ class Tokenizer(Trino.Tokenizer):
"UNLOAD": TokenType.COMMAND,
}

class HiveParser(Hive.Parser):
"""
Parse queries for the Athena Hive execution engine
"""

pass

class Parser(Trino.Parser):
"""
Parse queries for the Athena Trino execution engine
Expand All @@ -98,30 +79,6 @@ class Parser(Trino.Parser):
TokenType.USING: lambda self: self._parse_as_command(self._prev),
}

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self._hive_parser = Athena.HiveParser(*args, **kwargs)

def parse(
self, raw_tokens: t.List[Token], sql: t.Optional[str] = None
) -> t.List[t.Optional[exp.Expression]]:
if _parse_as_hive(raw_tokens):
return self._hive_parser.parse(raw_tokens, sql)

return super().parse(raw_tokens, sql)

class HiveGenerator(Hive.Generator):
"""
Generating queries for the Athena Hive execution engine
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

self._identifier_start = "`"
self._identifier_end = "`"

class Generator(Trino.Generator):
"""
Generate queries for the Athena Trino execution engine
Expand All @@ -139,7 +96,10 @@ class Generator(Trino.Generator):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._hive_generator = Athena.HiveGenerator(*args, **kwargs)

hive_kwargs = {**kwargs, "dialect": "hive"}

self._hive_generator = Hive.Generator(*args, **hive_kwargs)

def generate(self, expression: exp.Expression, copy: bool = True) -> str:
if _generate_as_hive(expression):
Expand Down
30 changes: 29 additions & 1 deletion tests/dialects/test_athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ def test_athena(self):
check_command_warning=True,
)

self.validate_identity(
"/* leading comment */CREATE SCHEMA foo",
write_sql="/* leading comment */ CREATE SCHEMA `foo`",
identify=True,
)
self.validate_identity(
"/* leading comment */SELECT * FROM foo",
write_sql='/* leading comment */ SELECT * FROM "foo"',
identify=True,
)

def test_ddl(self):
# Hive-like, https://docs.aws.amazon.com/athena/latest/ug/create-table.html
self.validate_identity("CREATE EXTERNAL TABLE foo (id INT) COMMENT 'test comment'")
Expand Down Expand Up @@ -53,11 +64,13 @@ def test_ddl(self):
self.validate_identity(
"CREATE TABLE foo WITH (table_type='ICEBERG', external_location='s3://foo/') AS SELECT * FROM a"
)
self.validate_identity(
"CREATE TABLE foo AS WITH foo AS (SELECT a, b FROM bar) SELECT * FROM foo"
)

def test_ddl_quoting(self):
self.validate_identity("CREATE SCHEMA `foo`")
self.validate_identity("CREATE SCHEMA foo")
self.validate_identity("CREATE SCHEMA foo", write_sql="CREATE SCHEMA `foo`", identify=True)

self.validate_identity("CREATE EXTERNAL TABLE `foo` (`id` INT) LOCATION 's3://foo/'")
self.validate_identity("CREATE EXTERNAL TABLE foo (id INT) LOCATION 's3://foo/'")
Expand Down Expand Up @@ -107,6 +120,14 @@ def test_ddl_quoting(self):
'ALTER TABLE "foo" DROP COLUMN "id"', write_sql="ALTER TABLE `foo` DROP COLUMN `id`"
)

self.validate_identity(
'CREATE TABLE "foo" AS WITH "foo" AS (SELECT "a", "b" FROM "bar") SELECT * FROM "foo"'
)
self.validate_identity(
'CREATE TABLE `foo` AS WITH `foo` AS (SELECT "a", `b` FROM "bar") SELECT * FROM "foo"',
write_sql='CREATE TABLE "foo" AS WITH "foo" AS (SELECT "a", "b" FROM "bar") SELECT * FROM "foo"',
)

def test_dml_quoting(self):
self.validate_identity("SELECT a AS foo FROM tbl")
self.validate_identity('SELECT "a" AS "foo" FROM "tbl"')
Expand Down Expand Up @@ -139,3 +160,10 @@ def test_dml_quoting(self):
write_sql='DELETE FROM "foo" WHERE "id" > 10',
identify=True,
)

self.validate_identity("WITH foo AS (SELECT a, b FROM bar) SELECT * FROM foo")
self.validate_identity(
"WITH foo AS (SELECT a, b FROM bar) SELECT * FROM foo",
write_sql='WITH "foo" AS (SELECT "a", "b" FROM "bar") SELECT * FROM "foo"',
identify=True,
)

0 comments on commit 6038a60

Please sign in to comment.