Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow skipping utf8 conversion in Python3 #548

Merged
merged 14 commits into from
Aug 29, 2024
82 changes: 53 additions & 29 deletions impala/hiveserver2.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def rollback(self):
raise NotSupportedError

def cursor(self, user=None, configuration=None, convert_types=True,
dictify=False, fetch_error=True, close_finished_queries=True):
dictify=False, fetch_error=True, close_finished_queries=True,
convert_strings_to_unicode=True):
"""Get a cursor from the HiveServer2 (HS2) connection.

Parameters
Expand All @@ -96,6 +97,12 @@ def cursor(self, user=None, configuration=None, convert_types=True,
When `False`, timestamps and decimal values will not be converted
to Python `datetime` and `Decimal` values. (These conversions are
expensive.) Only applies when using HS2 protocol versions > 6.
convert_strings_to_unicode : bool, optional
When `True`, the following types, which are transmitted as strings
in HS2 protocol, will be converted to unicode: STRING, LIST, MAP,
STRUCT, UNIONTYPE, NULL, VARCHAR, CHAR, TIMESTAMP, DECIMAL, DATE.
When `False`, conversion will occur only for types expected by
convert_types in python3: TIMESTAMP, DECIMAL, DATE.
dictify : bool, optional
When `True` cursor will return key value pairs instead of rows.
fetch_error : bool, optional
Expand Down Expand Up @@ -151,7 +158,8 @@ def cursor(self, user=None, configuration=None, convert_types=True,

cursor = cursor_class(session, convert_types=convert_types,
fetch_error=fetch_error,
close_finished_queries=close_finished_queries)
close_finished_queries=close_finished_queries,
convert_strings_to_unicode=convert_strings_to_unicode)

if self.default_db is not None:
log.info('Using database %s as default', self.default_db)
Expand All @@ -168,9 +176,11 @@ class HiveServer2Cursor(Cursor):
# HiveServer2Cursor objects are associated with a Session
# they are instantiated with alive session_handles

def __init__(self, session, convert_types=True, fetch_error=True, close_finished_queries=True):
    def __init__(self, session, convert_types=True, fetch_error=True, close_finished_queries=True,
                 convert_strings_to_unicode=True):
        """Create a cursor bound to an open HS2 session.

        Parameters
        ----------
        session : object
            Live HiveServer2 session this cursor issues operations against.
        convert_types : bool, optional
            When True, TIMESTAMP/DECIMAL/DATE values are converted to Python
            `datetime`/`Decimal` objects (these conversions are expensive).
        fetch_error : bool, optional
            NOTE(review): semantics not visible in this chunk — presumably
            controls whether fetch errors are raised; confirm against usage.
        close_finished_queries : bool, optional
            NOTE(review): semantics not visible in this chunk — presumably
            closes server-side operations once they finish; confirm.
        convert_strings_to_unicode : bool, optional
            When True, all HS2 string-transmitted types (STRING, LIST, MAP,
            STRUCT, UNIONTYPE, NULL, VARCHAR, CHAR, TIMESTAMP, DECIMAL, DATE)
            are decoded to unicode; when False, decoding happens only for the
            types that `convert_types` needs (TIMESTAMP, DECIMAL, DATE).
        """
        self.session = session
        self.convert_types = convert_types
        self.convert_strings_to_unicode = convert_strings_to_unicode
        self.fetch_error = fetch_error
        self.close_finished_queries = close_finished_queries

Expand Down Expand Up @@ -570,7 +580,8 @@ def fetchcbatch(self):
batch = (self._last_operation.fetch(
self.description,
self.buffersize,
convert_types=self.convert_types))
convert_types=self.convert_types,
convert_strings_to_unicode=self.convert_strings_to_unicode))
if len(batch) == 0:
return None
return batch
Expand Down Expand Up @@ -620,7 +631,8 @@ def fetchcolumnar(self):
batch = (self._last_operation.fetch(
self.description,
self.buffersize,
convert_types=self.convert_types))
convert_types=self.convert_types,
convert_strings_to_unicode=self.convert_strings_to_unicode))
if len(batch) == 0:
break
batches.append(batch)
Expand Down Expand Up @@ -659,8 +671,9 @@ def _ensure_buffer_is_filled(self):
log.debug('_ensure_buffer_is_filled: buffer empty and op is active '
'=> fetching more data')
self._buffer = self._last_operation.fetch(self.description,
self.buffersize,
convert_types=self.convert_types)
self.buffersize,
convert_types=self.convert_types,
convert_strings_to_unicode=self.convert_strings_to_unicode)
if len(self._buffer) > 0:
return
if not self._buffer.expect_more_rows:
Expand Down Expand Up @@ -1012,7 +1025,8 @@ def pop_to_preallocated_list(self, output_list, count, offset=0, stride=1):

class CBatch(Batch):

def __init__(self, trowset, expect_more_rows, schema, convert_types=True):
def __init__(self, trowset, expect_more_rows, schema, convert_types=True,
convert_strings_to_unicode=True):
self.expect_more_rows = expect_more_rows
self.schema = schema
tcols = [_TTypeId_to_TColumnValue_getters[schema[i][1]](col)
Expand All @@ -1023,6 +1037,9 @@ def __init__(self, trowset, expect_more_rows, schema, convert_types=True):

log.debug('CBatch: input TRowSet num_cols=%s num_rows=%s tcols=%s',
num_cols, num_rows, tcols)

HS2_STRING_TYPES = ["STRING", "LIST", "MAP", "STRUCT", "UNIONTYPE", "NULL", "VARCHAR", "CHAR", "TIMESTAMP", "DECIMAL", "DATE"]
CONVERTED_TYPES=["TIMESTAMP", "DECIMAL", "DATE"]

self.columns = []
for j in range(num_cols):
Expand All @@ -1040,14 +1057,32 @@ def __init__(self, trowset, expect_more_rows, schema, convert_types=True):

# STRING columns are read as binary and decoded here to be able to handle
# non-valid utf-8 strings in Python 3.

if six.PY3:
self._convert_strings_to_unicode(type_, is_null, values)
if convert_strings_to_unicode:
self._convert_strings_to_unicode(type_, is_null, values, types=HS2_STRING_TYPES)
elif convert_types:
self._convert_strings_to_unicode(type_, is_null, values, types=CONVERTED_TYPES)

if convert_types:
values = self._convert_values(type_, is_null, values)

self.columns.append(Column(type_, values, is_null))

def _convert_strings_to_unicode(self, type_, is_null, values, types):
if type_ in types:
for i in range(len(values)):
if is_null[i]:
values[i] = None
continue
try:
# Do similar handling of non-valid UTF-8 strings as Thriftpy2:
# https://github.com/Thriftpy/thriftpy2/blob/8e218b3fd89c597c2e83d129efecfe4d280bdd89/thriftpy2/protocol/binary.py#L241
# If decoding fails then keep the original bytearray.
values[i] = values[i].decode("UTF-8")
except UnicodeDecodeError:
pass

def _convert_values(self, type_, is_null, values):
# pylint: disable=consider-using-enumerate
if type_ == 'TIMESTAMP':
Expand All @@ -1062,20 +1097,6 @@ def _convert_values(self, type_, is_null, values):
values[i] = (None if is_null[i] else _parse_date(values[i]))
return values

def _convert_strings_to_unicode(self, type_, is_null, values):
if type_ in ["STRING", "LIST", "MAP", "STRUCT", "UNIONTYPE", "DECIMAL", "DATE", "TIMESTAMP", "NULL", "VARCHAR", "CHAR"]:
for i in range(len(values)):
if is_null[i]:
values[i] = None
continue
try:
# Do similar handling of non-valid UTF-8 strings as Thriftpy2:
# https://github.com/Thriftpy/thriftpy2/blob/8e218b3fd89c597c2e83d129efecfe4d280bdd89/thriftpy2/protocol/binary.py#L241
# If decoding fails then keep the original bytearray.
values[i] = values[i].decode("UTF-8")
except UnicodeDecodeError:
pass

    def __len__(self):
        # Batch length is the count of rows not yet consumed.
        # NOTE(review): remaining_rows is maintained elsewhere in the class
        # (not visible in this chunk) — presumably decremented as rows are
        # popped; confirm against the full class.
        return self.remaining_rows

Expand Down Expand Up @@ -1412,7 +1433,7 @@ def get_log(self, max_rows=1024, orientation=TFetchOrientation.FETCH_NEXT):
resp = self._rpc('FetchResults', req, False)
schema = [('Log', 'STRING', None, None, None, None, None)]
log = self._wrap_results(resp.results, resp.hasMoreRows, schema,
convert_types=True)
convert_types=True, convert_strings_to_unicode=True)
log = '\n'.join(l[0] for l in log)
return log

Expand Down Expand Up @@ -1457,7 +1478,7 @@ def get_summary(self):

def fetch(self, schema=None, max_rows=1024,
orientation=TFetchOrientation.FETCH_NEXT,
convert_types=True):
convert_types=True, convert_strings_to_unicode=True):
if not self.has_result_set:
log.debug('fetch_results: has_result_set=False')
return None
Expand All @@ -1473,15 +1494,18 @@ def fetch(self, schema=None, max_rows=1024,
# results are kept around for retry to be successful.
resp = self._rpc('FetchResults', req, False)
return self._wrap_results(resp.results, resp.hasMoreRows, schema,
convert_types=convert_types)
convert_types=convert_types,
convert_strings_to_unicode=convert_strings_to_unicode)

def _wrap_results(self, results, expect_more_rows, schema, convert_types=True):
def _wrap_results(self, results, expect_more_rows, schema, convert_types=True,
convert_strings_to_unicode=True):
if self.is_columnar:
log.debug('fetch_results: constructing CBatch')
return CBatch(results, expect_more_rows, schema, convert_types=convert_types)
return CBatch(results, expect_more_rows, schema, convert_types=convert_types,
convert_strings_to_unicode=convert_strings_to_unicode)
else:
log.debug('fetch_results: constructing RBatch')
# TODO: RBatch ignores 'convert_types'
# TODO: RBatch ignores 'convert_types' and 'convert_strings_to_unicode'
return RBatch(results, expect_more_rows, schema)

@property
Expand Down
6 changes: 6 additions & 0 deletions impala/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,9 @@ def cur(con):
cur = con.cursor()
yield cur
cur.close()

@fixture(scope='session')
def cur_no_string_conv(con):
cur = con.cursor(convert_types=True, convert_strings_to_unicode=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here and in test_data_types.py: is there a reason for removing the final new lines?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's by mistake, fixed

yield cur
cur.close()
111 changes: 96 additions & 15 deletions impala/tests/test_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

import datetime
import pytest
import sys
from pytest import fixture
from decimal import Decimal

@fixture(scope='module')
def decimal_table(cur):
Expand Down Expand Up @@ -49,48 +51,97 @@ def test_cursor_description_precision_scale(cur, decimal_table):
for (exp, obs) in zip(expected, observed):
assert exp == obs


@fixture(scope='module')
def decimal_table2(cur):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was the existing decimal_table not enough for the tests?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it more consistent with other tests this way, but it's really minor to change, if you insist.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ack

table_name = 'tmp_decimal_table2'
ddl = """CREATE TABLE {0} (val decimal(18, 9))""".format(table_name)
cur.execute(ddl)
cur.execute('''insert into {0}
values (cast(123456789.123456789 as decimal(18, 9))),
(cast(-123456789.123456789 as decimal(18, 9))),
(cast(0.000000001 as decimal(18, 9))),
(cast(-0.000000001 as decimal(18, 9))),
(cast(999999999.999999999 as decimal(18, 9))),
(cast(-999999999.999999999 as decimal(18, 9)))'''.format(table_name))
try:
yield table_name
finally:
cur.execute("DROP TABLE {0}".format(table_name))


def common_test_decimal(cur, decimal_table):
"""Read back a few decimal values in a wide range."""
cur.execute('select val from {0} order by val'.format(decimal_table))
results = cur.fetchall()
assert results == [(Decimal('-999999999.999999999'),),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice to have: it would increase the coverage of the tests if there would be also NULL values in these test tables

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree, added

(Decimal('-123456789.123456789'),),
(Decimal('-0.000000001'),),
(Decimal('0.000000001'),),
(Decimal('123456789.123456789'),),
(Decimal('999999999.999999999'),)]


@pytest.mark.connect
def test_decimal_basic(cur, decimal_table2):
    """Decimal round-trip with the default cursor (string-to-unicode conversion on)."""
    common_test_decimal(cur, decimal_table2)


@pytest.mark.connect
def test_decimal_no_string_conv(cur_no_string_conv, decimal_table2):
    """Decimal round-trip with convert_strings_to_unicode=False (convert_types still True)."""
    common_test_decimal(cur_no_string_conv, decimal_table2)


@fixture(scope='module')
def date_table(cur):
    """Module-scoped fixture: a throwaway DATE table holding two boundary values."""
    name = 'tmp_date_table'
    cur.execute("""CREATE TABLE {0} (d date)""".format(name))
    cur.execute('''insert into {0}
            values (date "0001-01-01"), (date "1999-9-9")'''.format(name))
    try:
        yield name
    finally:
        cur.execute("DROP TABLE {0}".format(name))


@pytest.mark.connect
def test_date_basic(cur, date_table):
"""Insert and read back a couple of data values in a wide range."""
cur.execute('''insert into {0}
values (date "0001-01-01"), (date "1999-9-9")'''.format(date_table))
def common_test_date(cur, date_table):
    """Fetch the DATE column back and check both boundary values."""
    cur.execute('select d from {0} order by d'.format(date_table))
    expected = [(datetime.date(1, 1, 1),), (datetime.date(1999, 9, 9),)]
    assert cur.fetchall() == expected


@pytest.mark.connect
def test_date_basic(cur, date_table):
    """Date round-trip with the default cursor (string-to-unicode conversion on)."""
    common_test_date(cur, date_table)


@pytest.mark.connect
def test_date_no_string_conv(cur_no_string_conv, date_table):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add similar test for decimals?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure

common_test_date(cur_no_string_conv, date_table)


@fixture(scope='module')
def timestamp_table(cur):
    # Module-scoped fixture: create an empty TIMESTAMP table, drop it on teardown.
    # NOTE(review): rows appear to be inserted by the tests themselves here —
    # confirm against callers in the full file.
    table_name = 'tmp_timestamp_table'
    ddl = """CREATE TABLE {0} (ts timestamp)""".format(table_name)
    cur.execute(ddl)
    try:
        yield table_name
    finally:
        # Always drop, even if the using test fails.
        cur.execute("DROP TABLE {0}".format(table_name))


@pytest.mark.connect
def test_timestamp_basic(cur, timestamp_table):
"""Insert and read back a few timestamp values in a wide range."""
cur.execute('''insert into {0}
values (cast("1400-01-01 00:00:00" as timestamp)),
(cast("2014-06-23 13:30:51" as timestamp)),
(cast("2014-06-23 13:30:51.123" as timestamp)),
(cast("2014-06-23 13:30:51.123456" as timestamp)),
(cast("2014-06-23 13:30:51.123456789" as timestamp)),
(cast("9999-12-31 23:59:59" as timestamp))'''.format(timestamp_table))
(cast("9999-12-31 23:59:59" as timestamp))'''.format(table_name))
try:
yield table_name
finally:
cur.execute("DROP TABLE {0}".format(table_name))


def common_test_timestamp(cur, timestamp_table):
"""Read back a few timestamp values in a wide range."""
cur.execute('select ts from {0} order by ts'.format(timestamp_table))
results = cur.fetchall()
assert results == [(datetime.datetime(1400, 1, 1, 0, 0),),
Expand All @@ -101,6 +152,16 @@ def test_timestamp_basic(cur, timestamp_table):
(datetime.datetime(9999, 12, 31, 23, 59, 59),)]


@pytest.mark.connect
def test_timestamp_basic(cur, timestamp_table):
    """Timestamp round-trip with the default cursor (string-to-unicode conversion on)."""
    common_test_timestamp(cur, timestamp_table)


@pytest.mark.connect
def test_timestamp_no_string_conv(cur_no_string_conv, timestamp_table):
    """Timestamp round-trip with convert_strings_to_unicode=False (convert_types still True)."""
    common_test_timestamp(cur_no_string_conv, timestamp_table)


@pytest.mark.connect
def test_utf8_strings(cur):
"""Use STRING/VARCHAR/CHAR values with multi byte unicode code points in a query."""
Expand All @@ -120,3 +181,23 @@ def test_utf8_strings(cur):
result = cur.fetchone()[0]
assert result == b"\xaa"
assert result.decode("UTF-8", "replace") == u"�"


@pytest.mark.connect
def test_string_conv(cur):
    """STRING results come back as unicode when conversion is enabled (default)."""
    cur.execute('select "Test string"')
    row = cur.fetchone()
    assert row[0] == u"Test string"


@pytest.mark.connect
def test_string_no_string_conv(cur_no_string_conv):
cur = cur_no_string_conv
cur.execute('select "Test string"')
result = cur.fetchone()


if sys.version_info[0] < 3:
assert result[0] == u"Test string"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this fail if 'u' was not added? Python does automatic conversion of strings in many cases. We should check the exact type with isinstance.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

else:
assert result[0] == b"Test string"