From 0e7940837b63218363523db4ebe744e1c6bde5b7 Mon Sep 17 00:00:00 2001 From: tlento Date: Wed, 12 Jul 2023 01:31:54 -0400 Subject: [PATCH] Add support for validate_sql method to BigQuery In CLI contexts MetricFlow will issue dry run queries as part of its warehouse validation operations, and so we are adding a validate_sql method to all adapters. This commit adds support for the validate_sql method to BigQuery. It does so by creating a BigQuery-specific `dry_run` method on the BigQueryConnectionManager. This simply passes through the input SQL with the `dry_run` QueryJobParameter flag set to True. This will result in BigQuery computing and returning a cost estimate for the query, or raising an exception in the event the query is not valid. Note: constructing the response object involves some repetitive value extraction from the QueryResult returned by BigQuery. While I would ordinarily prefer to tidy this up first, we are pressed for time, and so we postpone that cleanup in order to keep this change as isolated as possible. 
--- .../unreleased/Features-20230712-014350.yaml | 6 +++ dbt/adapters/bigquery/connections.py | 45 ++++++++++++++++++- dbt/adapters/bigquery/impl.py | 10 +++++ tests/functional/adapter/utils/test_utils.py | 27 +++++++++++ 4 files changed, 86 insertions(+), 2 deletions(-) create mode 100644 .changes/unreleased/Features-20230712-014350.yaml diff --git a/.changes/unreleased/Features-20230712-014350.yaml b/.changes/unreleased/Features-20230712-014350.yaml new file mode 100644 index 000000000..9bd47f49b --- /dev/null +++ b/.changes/unreleased/Features-20230712-014350.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add validate_sql to BigQuery adapter and dry_run to BigQueryConnectionManager +time: 2023-07-12T01:43:50.36167-04:00 +custom: + Author: tlento + Issue: "805" diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py index 06da1ff90..8662da1de 100644 --- a/dbt/adapters/bigquery/connections.py +++ b/dbt/adapters/bigquery/connections.py @@ -428,7 +428,13 @@ def get_table_from_response(cls, resp): column_names = [field.name for field in resp.schema] return agate_helper.table_from_data_flat(resp, column_names) - def raw_execute(self, sql, use_legacy_sql=False, limit: Optional[int] = None): + def raw_execute( + self, + sql, + use_legacy_sql=False, + limit: Optional[int] = None, + dry_run: bool = False, + ): conn = self.get_thread_connection() client = conn.handle @@ -446,7 +452,11 @@ def raw_execute(self, sql, use_legacy_sql=False, limit: Optional[int] = None): if active_user: labels["dbt_invocation_id"] = active_user.invocation_id - job_params = {"use_legacy_sql": use_legacy_sql, "labels": labels} + job_params = { + "use_legacy_sql": use_legacy_sql, + "labels": labels, + "dry_run": dry_run, + } priority = conn.credentials.priority if priority == Priority.Batch: @@ -554,6 +564,37 @@ def execute( return response, table + def dry_run(self, sql: str) -> BigQueryAdapterResponse: + """Run the given sql statement with the `dry_run` job parameter 
set. + + This will allow BigQuery to validate the SQL and immediately return job cost + estimates, which we capture in the BigQueryAdapterResponse. Invalid SQL + will result in an exception. + """ + sql = self._add_query_comment(sql) + query_job, _ = self.raw_execute(sql, dry_run=True) + + # TODO: Factor this repetitive block out into a factory method on + # BigQueryAdapterResponse + message = f"Ran dry run query for statement of type {query_job.statement_type}" + bytes_billed = query_job.total_bytes_billed + processed_bytes = self.format_bytes(query_job.total_bytes_processed) + location = query_job.location + project_id = query_job.project + job_id = query_job.job_id + slot_ms = query_job.slot_millis + + return BigQueryAdapterResponse( + _message=message, + code="DRY RUN", + bytes_billed=bytes_billed, + bytes_processed=processed_bytes, + location=location, + project_id=project_id, + job_id=job_id, + slot_ms=slot_ms, + ) + @staticmethod def _bq_job_link(location, project_id, job_id) -> str: return f"https://console.cloud.google.com/bigquery?project={project_id}&j=bq:{location}:{job_id}&page=queryresults" diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index 353be08d8..f53cd4084 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -2,6 +2,7 @@ import threading from typing import Dict, List, Optional, Any, Set, Union, Type +from dbt.contracts.connection import AdapterResponse from dbt.contracts.graph.nodes import ColumnLevelConstraint, ModelLevelConstraint, ConstraintType # type: ignore from dbt.dataclass_schema import dbtClassMixin, ValidationError @@ -1024,3 +1025,12 @@ def render_model_constraint(cls, constraint: ModelLevelConstraint) -> Optional[s def debug_query(self): """Override for DebugTask method""" self.execute("select 1 as id") + + def validate_sql(self, sql: str) -> AdapterResponse: + """Submit the given SQL to the engine for validation, but not execution. 
+ + This submits the query with the `dry_run` flag set True. + + :param str sql: The sql to validate + """ + return self.connections.dry_run(sql) diff --git a/tests/functional/adapter/utils/test_utils.py b/tests/functional/adapter/utils/test_utils.py index 6fb2d05d2..dc42c4db0 100644 --- a/tests/functional/adapter/utils/test_utils.py +++ b/tests/functional/adapter/utils/test_utils.py @@ -1,4 +1,7 @@ +import random + import pytest +from google.api_core.exceptions import NotFound from dbt.tests.adapter.utils.test_array_append import BaseArrayAppend from dbt.tests.adapter.utils.test_array_concat import BaseArrayConcat @@ -24,6 +27,7 @@ from dbt.tests.adapter.utils.test_safe_cast import BaseSafeCast from dbt.tests.adapter.utils.test_split_part import BaseSplitPart from dbt.tests.adapter.utils.test_string_literal import BaseStringLiteral +from dbt.tests.adapter.utils.test_validate_sql import BaseValidateSqlMethod from tests.functional.adapter.utils.fixture_array_append import ( models__array_append_actual_sql, models__array_append_expected_sql, @@ -167,3 +171,26 @@ class TestSplitPart(BaseSplitPart): class TestStringLiteral(BaseStringLiteral): pass + + +class TestValidateSqlMethod(BaseValidateSqlMethod): + pass + + +class TestDryRunMethod: + """Test connection manager dry run method operation.""" + + def test_dry_run_method(self, project) -> None: + """Test dry run method on a DDL statement. + + This allows us to demonstrate that no SQL is executed. + """ + with project.adapter.connection_named("_test"): + client = project.adapter.connections.get_thread_connection().handle + random_suffix = "".join(random.choices([str(i) for i in range(10)], k=10)) + table_name = f"test_dry_run_{random_suffix}" + table_id = "{}.{}.{}".format(project.database, project.test_schema, table_name) + res = project.adapter.connections.dry_run(f"CREATE TABLE {table_id} (x INT64)") + assert res.code == "DRY RUN" + with pytest.raises(expected_exception=NotFound): + client.get_table(table_id)