From 3b549307ffc72a726354336bddf771a65c6682c4 Mon Sep 17 00:00:00 2001 From: Kevin Michel Date: Mon, 19 Apr 2021 16:32:47 +0200 Subject: [PATCH] honour connect_timeout and PGCONNECT_TIMEOUT The wait_select callback previously installed to enable Ctrl+C during long queries was breaking configurable connection timeout : https://github.com/psycopg/psycopg2/issues/944 This was replaced with a more visible async connection and a manual call to a custom wait_select with support for timeout. The timeout mimics default libpq behavior and reads the connect_timeout connection parameter with a fallback on PGCONNECT_TIMEOUT environment variable (and a default of 0: no timeout). A secondary benefit is to allow importing PGMigrate inside another project without PGMigrate altering the global set_wait_callback. --- aiven_db_migrate/migrate/pgmigrate.py | 15 +++++++---- aiven_db_migrate/migrate/pgutils.py | 37 +++++++++++++++++++++++++++ test/test_pg_migrate.py | 23 +++++++++++++++++ 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/aiven_db_migrate/migrate/pgmigrate.py b/aiven_db_migrate/migrate/pgmigrate.py index 49222d6..d0fae35 100644 --- a/aiven_db_migrate/migrate/pgmigrate.py +++ b/aiven_db_migrate/migrate/pgmigrate.py @@ -4,7 +4,7 @@ PGDataDumpFailedError, PGDataNotFoundError, PGMigrateValidationFailedError, PGSchemaDumpFailedError, PGTooMuchDataError ) from aiven_db_migrate.migrate.pgutils import ( - create_connection_string, find_pgbin_dir, get_connection_info, validate_pg_identifier_length + create_connection_string, find_pgbin_dir, get_connection_info, validate_pg_identifier_length, wait_select ) from aiven_db_migrate.migrate.version import __version__ from concurrent import futures @@ -31,8 +31,6 @@ import threading import time -# https://www.psycopg.org/docs/faq.html#faq-interrupt-query -psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select) MAX_CLI_LEN = 2097152 # getconf ARG_MAX @@ -136,6 +134,12 @@ def conn_str(self, *, dbname: str = None) -> str: conn_info["application_name"] = conn_info["application_name"] + "/" + self.mangle_db_name(conn_info["dbname"]) return create_connection_string(conn_info) + def connect_timeout(self): + try: + return int(self.conn_info.get("connect_timeout", os.environ.get("PGCONNECT_TIMEOUT", "")), 10) + except ValueError: + return None + @contextmanager def _cursor(self, *, dbname: str = None) -> RealDictCursor: conn: psycopg2.extensions.connection = None @@ -146,8 +150,8 @@ def _cursor(self, *, dbname: str = None) -> RealDictCursor: # from multiple threads; allow only one connection at time self.conn_lock.acquire() try: - conn = psycopg2.connect(**conn_info) - conn.autocommit = True + conn = psycopg2.connect(**conn_info, async_=True) + wait_select(conn, self.connect_timeout()) yield conn.cursor(cursor_factory=RealDictCursor) finally: if conn is not None: @@ -166,6 +170,7 @@ def c( results: List[Dict[str, Any]] = [] with self._cursor(dbname=dbname) as cur: cur.execute(query, args) + wait_select(cur.connection) if return_rows: results = cur.fetchall() if return_rows > 0 and len(results) != return_rows: diff --git a/aiven_db_migrate/migrate/pgutils.py b/aiven_db_migrate/migrate/pgutils.py index a49db50..4dd4af4 100644 --- a/aiven_db_migrate/migrate/pgutils.py +++ b/aiven_db_migrate/migrate/pgutils.py @@ -4,6 +4,10 @@ from typing import Any, Dict from urllib.parse import parse_qs, urlparse +import psycopg2 +import select +import time + def find_pgbin_dir(pgversion: str) -> Path: def _pgbin_paths(): @@ -105,3 +109,36 @@ def parse_connection_string_url(url: str) -> Dict[str, str]: for k, v in parse_qs(p.query).items(): fields[k] = v[-1] return fields + + +def wait_select(conn, timeout=None): + start_time = time.monotonic() + poll = select.poll() + while True: + if timeout is not None and timeout > 0: + time_left = start_time + timeout - time.monotonic() + if time_left <= 0: + raise TimeoutError("wait_select: timeout after {} seconds".format(timeout)) + else: + time_left = 1 + state = conn.poll() + try: + if state == psycopg2.extensions.POLL_OK: + return + elif state == psycopg2.extensions.POLL_READ: + poll.register(conn.fileno(), select.POLLIN) + elif state == psycopg2.extensions.POLL_WRITE: + poll.register(conn.fileno(), select.POLLOUT) + else: + raise conn.OperationalError("wait_select: invalid poll state") + try: + # When the remote address does not exist at all, poll.poll() waits its full timeout without any event. + # However, in the same conditions, conn.poll() raises a psycopg2 exception almost immediately. + # It is better to fail quickly instead of waiting the full timeout, so we keep our poll.poll() below 1sec. + poll.poll(min(1.0, time_left) * 1000) + finally: + poll.unregister(conn.fileno()) + except KeyboardInterrupt: + conn.cancel() + # the loop will be broken by a server error + continue diff --git a/test/test_pg_migrate.py b/test/test_pg_migrate.py index b1d3420..a243d13 100644 --- a/test/test_pg_migrate.py +++ b/test/test_pg_migrate.py @@ -6,8 +6,10 @@ from test.utils import random_string, Timer from typing import Any, Dict, Optional +import os import psycopg2 import pytest +import time class PGMigrateTest: @@ -154,6 +156,27 @@ def test_migrate_invalid_conn_str(self): PGMigrate(source_conn_info=source_conn_info, target_conn_info=target_conn_info).migrate() assert str(err.value) == "Invalid source or target connection string" + def test_migrate_connect_timeout_parameter(self): + for source_conn_info in ("host=example.org connect_timeout=1", "postgresql://example.org?connect_timeout=1"): + start_time = time.monotonic() + with pytest.raises(TimeoutError): + PGMigrate(source_conn_info=source_conn_info, target_conn_info=self.target.conn_info()).migrate() + end_time = time.monotonic() + assert end_time - start_time < 2 + + def test_migrate_connect_timeout_environment(self): + start_time = time.monotonic() + original_timeout = os.environ.get("PGCONNECT_TIMEOUT") + try: + os.environ["PGCONNECT_TIMEOUT"] = "1" + with pytest.raises(TimeoutError): + PGMigrate(source_conn_info="host=example.org", target_conn_info=self.target.conn_info()).migrate() + end_time = time.monotonic() + assert end_time - start_time < 2 + finally: + if original_timeout is not None: + os.environ["PGCONNECT_TIMEOUT"] = original_timeout + def test_migrate_same_server(self): source_conn_info = target_conn_info = self.target.conn_info() with pytest.raises(PGMigrateValidationFailedError) as err: