From b123a7576e87b4d3e825ba91894f1201c2398448 Mon Sep 17 00:00:00 2001 From: Naoki Takezoe Date: Tue, 2 Jul 2024 23:51:58 +0900 Subject: [PATCH] sql (fix): Support TIMESTAMP AT TIME ZONE literal (#3576) --- .../scala/wvlet/airframe/sql/model/Expression.scala | 8 ++++++++ .../wvlet/airframe/sql/parser/SQLInterpreter.scala | 12 ++++++++++-- .../wvlet/airframe/sql/parser/SQLGeneratorTest.scala | 8 ++++++++ 3 files changed, 26 insertions(+), 2 deletions(-) diff --git a/airframe-sql/src/main/scala/wvlet/airframe/sql/model/Expression.scala b/airframe-sql/src/main/scala/wvlet/airframe/sql/model/Expression.scala index 32fe96be27..eb4d229e65 100644 --- a/airframe-sql/src/main/scala/wvlet/airframe/sql/model/Expression.scala +++ b/airframe-sql/src/main/scala/wvlet/airframe/sql/model/Expression.scala @@ -1157,6 +1157,14 @@ object Expression { override def sqlExpr = s"TIMESTAMP '${value}'" override def toString = s"Literal(TIMESTAMP '${value}')" } + case class TimestampWithTimeZoneLiteral(value: String, timezone: String, nodeLocation: Option[NodeLocation]) + extends Literal + with LeafExpression { + override def dataType: DataType = DataType.TimestampType(TimestampField.TIMESTAMP, false) + override def stringValue: String = value + override def sqlExpr = s"TIMESTAMP '${value}' AT TIME ZONE '${timezone}'" + override def toString = s"Literal(TIMESTAMP '${value}' AT '${timezone}')" + } case class DecimalLiteral(value: String, nodeLocation: Option[NodeLocation]) extends Literal with LeafExpression { override def dataType: DataType = DataType.DecimalType(TypeVariable("precision"), TypeVariable("scale")) override def stringValue: String = value diff --git a/airframe-sql/src/main/scala/wvlet/airframe/sql/parser/SQLInterpreter.scala b/airframe-sql/src/main/scala/wvlet/airframe/sql/parser/SQLInterpreter.scala index 65d0df424e..925443edff 100644 --- a/airframe-sql/src/main/scala/wvlet/airframe/sql/parser/SQLInterpreter.scala +++ b/airframe-sql/src/main/scala/wvlet/airframe/sql/parser/SQLInterpreter.scala @@ -471,12 +471,20 @@ class SQLInterpreter(withNodeLocation: Boolean = true) extends SqlBaseBaseVisito // TODO Parse decimal-type precision properly case "decimal" => DecimalLiteral(v, getLocation(ctx)) case "char" => CharLiteral(v, getLocation(ctx)) - case other => - GenericLiteral(tpe, v, getLocation(ctx)) + case other => GenericLiteral(tpe, v, getLocation(ctx)) } } } + override def visitAtTimeZone(ctx: AtTimeZoneContext): Expression = { + val v = expression(ctx.timeZoneSpecifier()).asInstanceOf[StringLiteral].value + + expression(ctx.valueExpression()) match { + case t: TimestampLiteral => TimestampWithTimeZoneLiteral(t.value, v, t.nodeLocation) + case other => other + } + } + override def visitBasicStringLiteral(ctx: BasicStringLiteralContext): StringLiteral = { StringLiteral(unquote(ctx.STRING().getText), getLocation(ctx)) } diff --git a/airframe-sql/src/test/scala/wvlet/airframe/sql/parser/SQLGeneratorTest.scala b/airframe-sql/src/test/scala/wvlet/airframe/sql/parser/SQLGeneratorTest.scala index 68d92e2805..1377625742 100644 --- a/airframe-sql/src/test/scala/wvlet/airframe/sql/parser/SQLGeneratorTest.scala +++ b/airframe-sql/src/test/scala/wvlet/airframe/sql/parser/SQLGeneratorTest.scala @@ -186,4 +186,12 @@ class SQLGeneratorTest extends AirSpec { val sql = SQLGenerator.print(resolvedPlan).toLowerCase sql shouldBe "select * from (select * from a) join a using (id)" } + + test("print TIMESTAMP AT TIME ZONE") { + val resolvedPlan = + SQLAnalyzer.analyze("SELECT TIMESTAMP '1992-02-01 00:00 UTC' AT TIME ZONE 'Asia/Tokyo'", "default", demoCatalog) + + val sql = SQLGenerator.print(resolvedPlan) + sql shouldBe "SELECT TIMESTAMP '1992-02-01 00:00 UTC' AT TIME ZONE 'Asia/Tokyo'" + } }