Skip to content

Commit

Permalink
honour connect_timeout and PGCONNECT_TIMEOUT
Browse files Browse the repository at this point in the history
The wait_select callback previously installed to enable Ctrl+C during
long queries was breaking configurable connection timeout :
psycopg/psycopg2#944

This is replaced with a more visible async connection and a manual
call to a custom wait_select with support for timeout.

The timeout mimics default libpq behavior and reads the connect_timeout
connection parameter with a fallback on PGCONNECT_TIMEOUT environment
variable (and a default of 0: no timeout).

A secondary benefit is to allow importing PGMigrate inside another
project without PGMigrate altering the global set_wait_callback.
  • Loading branch information
kmichel-aiven committed Apr 20, 2021
1 parent 8e3a4db commit c9347e2
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 6 deletions.
24 changes: 18 additions & 6 deletions aiven_db_migrate/migrate/pgmigrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
PGDataDumpFailedError, PGDataNotFoundError, PGMigrateValidationFailedError, PGSchemaDumpFailedError, PGTooMuchDataError
)
from aiven_db_migrate.migrate.pgutils import (
create_connection_string, find_pgbin_dir, get_connection_info, validate_pg_identifier_length
create_connection_string, find_pgbin_dir, get_connection_info, validate_pg_identifier_length, wait_select
)
from aiven_db_migrate.migrate.version import __version__
from concurrent import futures
Expand All @@ -31,8 +31,6 @@
import threading
import time

# https://www.psycopg.org/docs/faq.html#faq-interrupt-query
psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
MAX_CLI_LEN = 2097152 # getconf ARG_MAX


Expand Down Expand Up @@ -136,6 +134,12 @@ def conn_str(self, *, dbname: str = None) -> str:
conn_info["application_name"] = conn_info["application_name"] + "/" + self.mangle_db_name(conn_info["dbname"])
return create_connection_string(conn_info)

def connect_timeout(self):
try:
return int(self.conn_info.get("connect_timeout", os.environ.get("PGCONNECT_TIMEOUT", "")), 10)
except ValueError:
return None

@contextmanager
def _cursor(self, *, dbname: str = None) -> RealDictCursor:
conn: psycopg2.extensions.connection = None
Expand All @@ -146,8 +150,8 @@ def _cursor(self, *, dbname: str = None) -> RealDictCursor:
# from multiple threads; allow only one connection at time
self.conn_lock.acquire()
try:
conn = psycopg2.connect(**conn_info)
conn.autocommit = True
conn = psycopg2.connect(**conn_info, async_=True)
wait_select(conn, self.connect_timeout())
yield conn.cursor(cursor_factory=RealDictCursor)
finally:
if conn is not None:
Expand All @@ -165,7 +169,15 @@ def c(
) -> List[Dict[str, Any]]:
results: List[Dict[str, Any]] = []
with self._cursor(dbname=dbname) as cur:
cur.execute(query, args)
try:
cur.execute(query, args)
wait_select(cur.connection)
except KeyboardInterrupt:
# We wrap the whole execute+wait block to make sure we cancel
# the query in all cases, which we couldn't if KeyboardInterupt
# was only handled inside wait_select.
cur.connection.cancel()
raise
if return_rows:
results = cur.fetchall()
if return_rows > 0 and len(results) != return_rows:
Expand Down
38 changes: 38 additions & 0 deletions aiven_db_migrate/migrate/pgutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from typing import Any, Dict
from urllib.parse import parse_qs, urlparse

import psycopg2
import select
import time


def find_pgbin_dir(pgversion: str) -> Path:
def _pgbin_paths():
Expand Down Expand Up @@ -105,3 +109,37 @@ def parse_connection_string_url(url: str) -> Dict[str, str]:
for k, v in parse_qs(p.query).items():
fields[k] = v[-1]
return fields


# This enables interruptible queries with an approach similar to
# https://www.psycopg.org/docs/faq.html#faq-interrupt-query
# However, to handle timeouts we can't use psycopg2.extensions.set_wait_callback :
# https://github.com/psycopg/psycopg2/issues/944
# Instead we rely on manually calling wait_select after connection and queries.
# Since it's not a wait callback, we do not capture and transform KeyboardInterupt here.
def wait_select(conn, timeout=None):
start_time = time.monotonic()
poll = select.poll()
while True:
if timeout is not None and timeout > 0:
time_left = start_time + timeout - time.monotonic()
if time_left <= 0:
raise TimeoutError("wait_select: timeout after {} seconds".format(timeout))
else:
time_left = 1
state = conn.poll()
if state == psycopg2.extensions.POLL_OK:
return
elif state == psycopg2.extensions.POLL_READ:
poll.register(conn.fileno(), select.POLLIN)
elif state == psycopg2.extensions.POLL_WRITE:
poll.register(conn.fileno(), select.POLLOUT)
else:
raise conn.OperationalError("wait_select: invalid poll state")
try:
# When the remote address does not exist at all, poll.poll() waits its full timeout without any event.
# However, in the same conditions, conn.poll() raises a psycopg2 exception almost immediately.
# It is better to fail quickly instead of waiting the full timeout, so we keep our poll.poll() below 1sec.
poll.poll(min(1.0, time_left) * 1000)
finally:
poll.unregister(conn.fileno())
16 changes: 16 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,7 @@ def inject_pg_fixture(*, name: str, pgversion: str, scope="module"):


SUPPORTED_PG_VERSIONS = ["9.5", "9.6", "10", "11", "12"]
pg_cluster_for_tests: List[str] = list()
pg_source_and_target_for_tests: List[Tuple[str, str]] = list()
pg_source_and_target_for_replication_tests: List[Tuple[str, str]] = list()

Expand Down Expand Up @@ -437,6 +438,10 @@ def generate_fixtures():
pg_source_and_target_for_tests.append((source_name, target_name))
if LooseVersion(source) >= "10":
pg_source_and_target_for_replication_tests.append((source_name, target_name))
for version in set(pg_source_versions).union(pg_target_versions):
fixture_name = "pg{}".format(version.replace(".", ""))
inject_pg_fixture(name=fixture_name, pgversion=version)
pg_cluster_for_tests.append(fixture_name)


generate_fixtures()
Expand All @@ -450,6 +455,17 @@ def test_pg_source_and_target_for_replication_tests():
print(pg_source_and_target_for_replication_tests)


@pytest.fixture(name="pg_cluster", params=pg_cluster_for_tests, scope="function")
def fixture_pg_cluster(request):
"""Returns a fixture parametrized on the union of all source and target pg versions."""
cluster_runner = request.getfixturevalue(request.param)
yield cluster_runner
for cleanup in cluster_runner.cleanups:
cleanup()
cluster_runner.cleanups.clear()
cluster_runner.drop_dbs()


@pytest.fixture(name="pg_source_and_target", params=pg_source_and_target_for_tests, scope="function")
def fixture_pg_source_and_target(request):
source, target = request.param
Expand Down
26 changes: 26 additions & 0 deletions test/test_pg_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) 2021 Aiven, Helsinki, Finland. https://aiven.io/
import signal

from aiven_db_migrate.migrate.pgmigrate import PGCluster
from multiprocessing import Process
from test.conftest import PGRunner
from typing import Tuple

import os
import pytest
import time


def test_interruptible_queries(pg_cluster: PGRunner):
def wait_and_interrupt():
time.sleep(1)
os.kill(os.getppid(), signal.SIGINT)

cluster = PGCluster(conn_info=pg_cluster.conn_info())
interuptor = Process(target=wait_and_interrupt)
interuptor.start()
start_time = time.monotonic()
with pytest.raises(KeyboardInterrupt):
cluster.c("select pg_sleep(100)")
assert time.monotonic() - start_time < 2
interuptor.join()
23 changes: 23 additions & 0 deletions test/test_pg_migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from test.utils import random_string, Timer
from typing import Any, Dict, Optional

import os
import psycopg2
import pytest
import time


class PGMigrateTest:
Expand Down Expand Up @@ -154,6 +156,27 @@ def test_migrate_invalid_conn_str(self):
PGMigrate(source_conn_info=source_conn_info, target_conn_info=target_conn_info).migrate()
assert str(err.value) == "Invalid source or target connection string"

def test_migrate_connect_timeout_parameter(self):
for source_conn_info in ("host=example.org connect_timeout=1", "postgresql://example.org?connect_timeout=1"):
start_time = time.monotonic()
with pytest.raises(TimeoutError):
PGMigrate(source_conn_info=source_conn_info, target_conn_info=self.target.conn_info()).migrate()
end_time = time.monotonic()
assert end_time - start_time < 2

def test_migrate_connect_timeout_environment(self):
start_time = time.monotonic()
original_timeout = os.environ.get("PGCONNECT_TIMEOUT")
try:
os.environ["PGCONNECT_TIMEOUT"] = "1"
with pytest.raises(TimeoutError):
PGMigrate(source_conn_info="host=example.org", target_conn_info=self.target.conn_info()).migrate()
end_time = time.monotonic()
assert end_time - start_time < 2
finally:
if original_timeout is not None:
os.environ["PGCONNECT_TIMEOUT"] = original_timeout

def test_migrate_same_server(self):
source_conn_info = target_conn_info = self.target.conn_info()
with pytest.raises(PGMigrateValidationFailedError) as err:
Expand Down

0 comments on commit c9347e2

Please sign in to comment.