Skip to content

Commit

Permalink
honour connect_timeout and PGCONNECT_TIMEOUT
Browse files Browse the repository at this point in the history
The wait_select callback previously installed to enable Ctrl+C during
long queries was breaking the configurable connection timeout:
psycopg/psycopg2#944

It was replaced with a more explicit async connection and a manual
call to a custom wait_select that supports a timeout.

The timeout mimics the default libpq behavior: it reads the connect_timeout
connection parameter, falling back to the PGCONNECT_TIMEOUT environment
variable (and defaulting to 0: no timeout).

A secondary benefit is to allow importing PGMigrate inside another
project without PGMigrate altering the global set_wait_callback.
  • Loading branch information
kmichel-aiven committed Apr 19, 2021
1 parent 8e3a4db commit 3b54930
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 5 deletions.
15 changes: 10 additions & 5 deletions aiven_db_migrate/migrate/pgmigrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
PGDataDumpFailedError, PGDataNotFoundError, PGMigrateValidationFailedError, PGSchemaDumpFailedError, PGTooMuchDataError
)
from aiven_db_migrate.migrate.pgutils import (
create_connection_string, find_pgbin_dir, get_connection_info, validate_pg_identifier_length
create_connection_string, find_pgbin_dir, get_connection_info, validate_pg_identifier_length, wait_select
)
from aiven_db_migrate.migrate.version import __version__
from concurrent import futures
Expand All @@ -31,8 +31,6 @@
import threading
import time

# https://www.psycopg.org/docs/faq.html#faq-interrupt-query
psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
MAX_CLI_LEN = 2097152 # getconf ARG_MAX


Expand Down Expand Up @@ -136,6 +134,12 @@ def conn_str(self, *, dbname: str = None) -> str:
conn_info["application_name"] = conn_info["application_name"] + "/" + self.mangle_db_name(conn_info["dbname"])
return create_connection_string(conn_info)

def connect_timeout(self):
    """Return the configured connection timeout in seconds, or None.

    The ``connect_timeout`` entry of ``self.conn_info`` takes precedence;
    when that key is absent the ``PGCONNECT_TIMEOUT`` environment variable
    is used instead. Values that cannot be read as a base-10 integer
    (including a missing or empty setting) yield ``None``.
    """
    raw = self.conn_info.get("connect_timeout", os.environ.get("PGCONNECT_TIMEOUT", ""))
    # An already-numeric setting must not go through int(value, 10): passing
    # an explicit base with a non-string argument raises TypeError.
    if isinstance(raw, int) and not isinstance(raw, bool):
        return raw
    try:
        return int(raw, 10)
    except (TypeError, ValueError):
        # TypeError covers None and other non-string values; ValueError covers
        # empty or non-numeric strings.
        return None

@contextmanager
def _cursor(self, *, dbname: str = None) -> RealDictCursor:
conn: psycopg2.extensions.connection = None
Expand All @@ -146,8 +150,8 @@ def _cursor(self, *, dbname: str = None) -> RealDictCursor:
# from multiple threads; allow only one connection at time
self.conn_lock.acquire()
try:
conn = psycopg2.connect(**conn_info)
conn.autocommit = True
conn = psycopg2.connect(**conn_info, async_=True)
wait_select(conn, self.connect_timeout())
yield conn.cursor(cursor_factory=RealDictCursor)
finally:
if conn is not None:
Expand All @@ -166,6 +170,7 @@ def c(
results: List[Dict[str, Any]] = []
with self._cursor(dbname=dbname) as cur:
cur.execute(query, args)
wait_select(cur.connection)
if return_rows:
results = cur.fetchall()
if return_rows > 0 and len(results) != return_rows:
Expand Down
37 changes: 37 additions & 0 deletions aiven_db_migrate/migrate/pgutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
from typing import Any, Dict
from urllib.parse import parse_qs, urlparse

import psycopg2
import select
import time


def find_pgbin_dir(pgversion: str) -> Path:
def _pgbin_paths():
Expand Down Expand Up @@ -105,3 +109,36 @@ def parse_connection_string_url(url: str) -> Dict[str, str]:
for k, v in parse_qs(p.query).items():
fields[k] = v[-1]
return fields


def wait_select(conn, timeout=None):
    """Wait until ``conn.poll()`` reports that the connection is ready.

    Each iteration polls the connection and, when it reports that it needs
    to read or write, waits on its file descriptor for at most one second
    before polling again.

    :param conn: connection object exposing ``poll()``, ``fileno()``,
        ``cancel()`` and ``OperationalError``.
    :param timeout: overall deadline in seconds; ``None`` or a value that is
        not greater than zero disables the deadline.
    :raises TimeoutError: when the deadline has elapsed before the
        connection became ready.
    :raises conn.OperationalError: when ``conn.poll()`` returns a state that
        is none of POLL_OK, POLL_READ or POLL_WRITE.
    """
    started_at = time.monotonic()
    poller = select.poll()
    has_deadline = timeout is not None and timeout > 0
    while True:
        # Without a deadline the wait below is simply capped at one second.
        remaining = 1
        if has_deadline:
            remaining = started_at + timeout - time.monotonic()
            if remaining <= 0:
                raise TimeoutError("wait_select: timeout after {} seconds".format(timeout))
        state = conn.poll()
        try:
            if state == psycopg2.extensions.POLL_OK:
                return
            if state == psycopg2.extensions.POLL_READ:
                event_mask = select.POLLIN
            elif state == psycopg2.extensions.POLL_WRITE:
                event_mask = select.POLLOUT
            else:
                raise conn.OperationalError("wait_select: invalid poll state")
            poller.register(conn.fileno(), event_mask)
            try:
                # If the remote address does not exist at all, poller.poll() sits out its whole
                # timeout without seeing an event, whereas conn.poll() raises a psycopg2 exception
                # almost immediately under the same conditions. Capping each wait at one second
                # lets that failure surface quickly instead of after the full timeout.
                poller.poll(min(1.0, remaining) * 1000)
            finally:
                poller.unregister(conn.fileno())
        except KeyboardInterrupt:
            conn.cancel()
            # A server error is expected to end the loop on a later iteration.
            continue
23 changes: 23 additions & 0 deletions test/test_pg_migrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
from test.utils import random_string, Timer
from typing import Any, Dict, Optional

import os
import psycopg2
import pytest
import time


class PGMigrateTest:
Expand Down Expand Up @@ -154,6 +156,27 @@ def test_migrate_invalid_conn_str(self):
PGMigrate(source_conn_info=source_conn_info, target_conn_info=target_conn_info).migrate()
assert str(err.value) == "Invalid source or target connection string"

def test_migrate_connect_timeout_parameter(self):
    """Check that connect_timeout=1 in the source connection info aborts migrate().

    Both the keyword/value and the URL form of the connection string are
    exercised; each attempt must raise TimeoutError in under two seconds.
    """
    source_variants = (
        "host=example.org connect_timeout=1",
        "postgresql://example.org?connect_timeout=1",
    )
    for conn_str in source_variants:
        began = time.monotonic()
        with pytest.raises(TimeoutError):
            PGMigrate(source_conn_info=conn_str, target_conn_info=self.target.conn_info()).migrate()
        elapsed = time.monotonic() - began
        assert elapsed < 2

def test_migrate_connect_timeout_environment(self):
    """Check that PGCONNECT_TIMEOUT=1 in the environment aborts migrate().

    The source connection info carries no connect_timeout, so the timeout
    can only come from the environment variable. The attempt must raise
    TimeoutError in under two seconds.
    """
    start_time = time.monotonic()
    original_timeout = os.environ.get("PGCONNECT_TIMEOUT")
    try:
        os.environ["PGCONNECT_TIMEOUT"] = "1"
        with pytest.raises(TimeoutError):
            PGMigrate(source_conn_info="host=example.org", target_conn_info=self.target.conn_info()).migrate()
        end_time = time.monotonic()
        assert end_time - start_time < 2
    finally:
        # Restore the environment exactly: when the variable was not set
        # before the test it must be removed again, otherwise the value "1"
        # leaks into every test that runs afterwards.
        if original_timeout is None:
            os.environ.pop("PGCONNECT_TIMEOUT", None)
        else:
            os.environ["PGCONNECT_TIMEOUT"] = original_timeout

def test_migrate_same_server(self):
source_conn_info = target_conn_info = self.target.conn_info()
with pytest.raises(PGMigrateValidationFailedError) as err:
Expand Down

0 comments on commit 3b54930

Please sign in to comment.