Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow skipping utf8 conversion in Python3 #548

Merged
merged 14 commits into from
Aug 29, 2024
80 changes: 51 additions & 29 deletions impala/hiveserver2.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ def rollback(self):
raise NotSupportedError

def cursor(self, user=None, configuration=None, convert_types=True,
dictify=False, fetch_error=True, close_finished_queries=True):
dictify=False, fetch_error=True, close_finished_queries=True,
convert_strings_to_unicode=True):
"""Get a cursor from the HiveServer2 (HS2) connection.

Parameters
Expand All @@ -96,6 +97,12 @@ def cursor(self, user=None, configuration=None, convert_types=True,
When `False`, timestamps and decimal values will not be converted
to Python `datetime` and `Decimal` values. (These conversions are
expensive.) Only applies when using HS2 protocol versions > 6.
convert_strings_to_unicode : bool, optional
When `True`, the following types will be converted to unicode:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be mentioned that these are the types that are transmitted as strings in HS2 protocol

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

STRING, LIST, MAP, STRUCT, UNIONTYPE, NULL, VARCHAR, CHAR, TIMESTAMP,
DECIMAL, DATE.
When `False`, conversion will occur only for types expected by
convert_types in python3: TIMESTAMP, DECIMAL, DATE.
dictify : bool, optional
When `True` cursor will return key value pairs instead of rows.
fetch_error : bool, optional
Expand Down Expand Up @@ -151,7 +158,8 @@ def cursor(self, user=None, configuration=None, convert_types=True,

cursor = cursor_class(session, convert_types=convert_types,
fetch_error=fetch_error,
close_finished_queries=close_finished_queries)
close_finished_queries=close_finished_queries,
convert_strings_to_unicode=convert_strings_to_unicode)

if self.default_db is not None:
log.info('Using database %s as default', self.default_db)
Expand All @@ -168,9 +176,11 @@ class HiveServer2Cursor(Cursor):
# HiveServer2Cursor objects are associated with a Session
# they are instantiated with alive session_handles

def __init__(self, session, convert_types=True, fetch_error=True, close_finished_queries=True):
def __init__(self, session, convert_types=True, fetch_error=True, close_finished_queries=True,
convert_strings_to_unicode=True):
self.session = session
self.convert_types = convert_types
self.convert_strings_to_unicode = convert_strings_to_unicode
self.fetch_error = fetch_error
self.close_finished_queries = close_finished_queries

Expand Down Expand Up @@ -570,7 +580,8 @@ def fetchcbatch(self):
batch = (self._last_operation.fetch(
self.description,
self.buffersize,
convert_types=self.convert_types))
convert_types=self.convert_types,
convert_strings_to_unicode=self.convert_strings_to_unicode))
if len(batch) == 0:
return None
return batch
Expand Down Expand Up @@ -620,7 +631,8 @@ def fetchcolumnar(self):
batch = (self._last_operation.fetch(
self.description,
self.buffersize,
convert_types=self.convert_types))
convert_types=self.convert_types,
convert_strings_to_unicode=self.convert_strings_to_unicode))
if len(batch) == 0:
break
batches.append(batch)
Expand Down Expand Up @@ -659,8 +671,9 @@ def _ensure_buffer_is_filled(self):
log.debug('_ensure_buffer_is_filled: buffer empty and op is active '
'=> fetching more data')
self._buffer = self._last_operation.fetch(self.description,
self.buffersize,
convert_types=self.convert_types)
self.buffersize,
convert_types=self.convert_types,
convert_strings_to_unicode=self.convert_strings_to_unicode)
if len(self._buffer) > 0:
return
if not self._buffer.expect_more_rows:
Expand Down Expand Up @@ -1012,7 +1025,8 @@ def pop_to_preallocated_list(self, output_list, count, offset=0, stride=1):

class CBatch(Batch):

def __init__(self, trowset, expect_more_rows, schema, convert_types=True):
def __init__(self, trowset, expect_more_rows, schema, convert_types=True,
convert_strings_to_unicode=True):
self.expect_more_rows = expect_more_rows
self.schema = schema
tcols = [_TTypeId_to_TColumnValue_getters[schema[i][1]](col)
Expand Down Expand Up @@ -1040,14 +1054,33 @@ def __init__(self, trowset, expect_more_rows, schema, convert_types=True):

# STRING columns are read as binary and decoded here to be able to handle
# non-valid utf-8 strings in Python 3.

if six.PY3:
self._convert_strings_to_unicode(type_, is_null, values)
if convert_strings_to_unicode:
self._convert_strings_to_unicode(type_, is_null, values,
types=["STRING", "LIST", "MAP", "STRUCT", "UNIONTYPE", "NULL", "VARCHAR", "CHAR", "TIMESTAMP", "DECIMAL", "DATE"])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These type lists could be moved to constants, e.g. HS2_STRING_TYPES and CONVERTED_TYPES

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Absolutely, fixed

elif convert_types:
self._convert_strings_to_unicode(type_, is_null, values, types=["TIMESTAMP", "DECIMAL", "DATE"])

if convert_types:
values = self._convert_values(type_, is_null, values)

self.columns.append(Column(type_, values, is_null))

def _convert_strings_to_unicode(self, type_, is_null, values, types):
if type_ in types:
for i in range(len(values)):
if is_null[i]:
values[i] = None
continue
try:
# Do similar handling of non-valid UTF-8 strings as Thriftpy2:
# https://github.com/Thriftpy/thriftpy2/blob/8e218b3fd89c597c2e83d129efecfe4d280bdd89/thriftpy2/protocol/binary.py#L241
# If decoding fails then keep the original bytearray.
values[i] = values[i].decode("UTF-8")
except UnicodeDecodeError:
pass

def _convert_values(self, type_, is_null, values):
# pylint: disable=consider-using-enumerate
if type_ == 'TIMESTAMP':
Expand All @@ -1062,20 +1095,6 @@ def _convert_values(self, type_, is_null, values):
values[i] = (None if is_null[i] else _parse_date(values[i]))
return values

def _convert_strings_to_unicode(self, type_, is_null, values):
if type_ in ["STRING", "LIST", "MAP", "STRUCT", "UNIONTYPE", "DECIMAL", "DATE", "TIMESTAMP", "NULL", "VARCHAR", "CHAR"]:
for i in range(len(values)):
if is_null[i]:
values[i] = None
continue
try:
# Do similar handling of non-valid UTF-8 strings as Thriftpy2:
# https://github.com/Thriftpy/thriftpy2/blob/8e218b3fd89c597c2e83d129efecfe4d280bdd89/thriftpy2/protocol/binary.py#L241
# If decoding fails then keep the original bytearray.
values[i] = values[i].decode("UTF-8")
except UnicodeDecodeError:
pass

def __len__(self):
return self.remaining_rows

Expand Down Expand Up @@ -1412,7 +1431,7 @@ def get_log(self, max_rows=1024, orientation=TFetchOrientation.FETCH_NEXT):
resp = self._rpc('FetchResults', req, False)
schema = [('Log', 'STRING', None, None, None, None, None)]
log = self._wrap_results(resp.results, resp.hasMoreRows, schema,
convert_types=True)
convert_types=True, convert_strings_to_unicode=True)
log = '\n'.join(l[0] for l in log)
return log

Expand Down Expand Up @@ -1457,7 +1476,7 @@ def get_summary(self):

def fetch(self, schema=None, max_rows=1024,
orientation=TFetchOrientation.FETCH_NEXT,
convert_types=True):
convert_types=True, convert_strings_to_unicode=True):
if not self.has_result_set:
log.debug('fetch_results: has_result_set=False')
return None
Expand All @@ -1473,15 +1492,18 @@ def fetch(self, schema=None, max_rows=1024,
# results are kept around for retry to be successful.
resp = self._rpc('FetchResults', req, False)
return self._wrap_results(resp.results, resp.hasMoreRows, schema,
convert_types=convert_types)
convert_types=convert_types,
convert_strings_to_unicode=convert_strings_to_unicode)

def _wrap_results(self, results, expect_more_rows, schema, convert_types=True):
def _wrap_results(self, results, expect_more_rows, schema, convert_types=True,
convert_strings_to_unicode=True):
if self.is_columnar:
log.debug('fetch_results: constructing CBatch')
return CBatch(results, expect_more_rows, schema, convert_types=convert_types)
return CBatch(results, expect_more_rows, schema, convert_types=convert_types,
convert_strings_to_unicode=convert_strings_to_unicode)
else:
log.debug('fetch_results: constructing RBatch')
# TODO: RBatch ignores 'convert_types'
# TODO: RBatch ignores 'convert_types' and 'convert_strings_to_unicode'
return RBatch(results, expect_more_rows, schema)

@property
Expand Down
6 changes: 6 additions & 0 deletions impala/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,9 @@ def cur(con):
cur = con.cursor()
yield cur
cur.close()

@fixture(scope='session')
def cur_noconv(con):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

convert_types is True here, so I think that cur_no_string_conversion or something similar would be better

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree.

cur = con.cursor(convert_types=True, convert_strings_to_unicode=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here and in test_data_types.py: is there a reason for removing the final new lines?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's by mistake, fixed

yield cur
cur.close()
17 changes: 11 additions & 6 deletions impala/tests/test_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,21 +54,26 @@ def date_table(cur):
table_name = 'tmp_date_table'
ddl = """CREATE TABLE {0} (d date)""".format(table_name)
cur.execute(ddl)
cur.execute('''insert into {0}
values (date "0001-01-01"), (date "1999-9-9")'''.format(table_name))
try:
yield table_name
finally:
cur.execute("DROP TABLE {0}".format(table_name))


@pytest.mark.connect
def test_date_basic(cur, date_table):
def setup_test_date_basic(cur, date_table):
    """Insert and read back a couple of data values in a wide range."""
    # Exercise both extremes: year 1 and a literal with single-digit
    # month/day ("1999-9-9") to check server-side normalization.
    cur.execute('''insert into {0}
         values (date "0001-01-01"), (date "1999-9-9")'''.format(date_table))
    cur.execute('select d from {0} order by d'.format(date_table))
    results = cur.fetchall()
    # NOTE(review): DATE values are expected to arrive as datetime.date
    # objects with either cursor configuration — confirm against the
    # noconv cursor variant as well.
    assert results == [(datetime.date(1, 1, 1),), (datetime.date(1999, 9, 9),)]

@pytest.mark.connect
def test_date_basic(cur, date_table):
    """Run the DATE round-trip check with the default (converting) cursor."""
    setup_test_date_basic(cur, date_table)

@pytest.mark.connect
def test_date_basic_noconv(cur_noconv, date_table):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add more tests for other types - date is a special case where string conversion is done regardless of convert_strings_to_unicode

setup_test_date_basic(cur_noconv, date_table)

@fixture(scope='module')
def timestamp_table(cur):
Expand Down Expand Up @@ -119,4 +124,4 @@ def test_utf8_strings(cur):
cur.execute('select unhex("AA")')
result = cur.fetchone()[0]
assert result == b"\xaa"
assert result.decode("UTF-8", "replace") == u"�"
assert result.decode("UTF-8", "replace") == u"�"