-
Notifications
You must be signed in to change notification settings - Fork 248
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow skipping utf8 conversion in Python3 #548
Changes from 7 commits
9acc61f
888dbde
3938f5f
afb36fe
03e7f3a
1b8a8a4
c68165a
66a590e
08cd5d7
5bc6155
c9ed747
be3edb2
7f831e7
09816fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -84,7 +84,8 @@ def rollback(self): | |
raise NotSupportedError | ||
|
||
def cursor(self, user=None, configuration=None, convert_types=True, | ||
dictify=False, fetch_error=True, close_finished_queries=True): | ||
dictify=False, fetch_error=True, close_finished_queries=True, | ||
convert_strings_to_unicode=True): | ||
"""Get a cursor from the HiveServer2 (HS2) connection. | ||
|
||
Parameters | ||
|
@@ -96,6 +97,12 @@ def cursor(self, user=None, configuration=None, convert_types=True, | |
When `False`, timestamps and decimal values will not be converted | ||
to Python `datetime` and `Decimal` values. (These conversions are | ||
expensive.) Only applies when using HS2 protocol versions > 6. | ||
convert_strings_to_unicode : bool, optional | ||
When `True`, the following types will be converted to unicode: | ||
STRING, LIST, MAP, STRUCT, UNIONTYPE, NULL, VARCHAR, CHAR, TIMESTAMP, | ||
DECIMAL, DATE. | ||
When `False`, conversion will occur only for types expected by | ||
convert_types in python3: TIMESTAMP, DECIMAL, DATE. | ||
dictify : bool, optional | ||
When `True` cursor will return key value pairs instead of rows. | ||
fetch_error : bool, optional | ||
|
@@ -151,7 +158,8 @@ def cursor(self, user=None, configuration=None, convert_types=True, | |
|
||
cursor = cursor_class(session, convert_types=convert_types, | ||
fetch_error=fetch_error, | ||
close_finished_queries=close_finished_queries) | ||
close_finished_queries=close_finished_queries, | ||
convert_strings_to_unicode=convert_strings_to_unicode) | ||
|
||
if self.default_db is not None: | ||
log.info('Using database %s as default', self.default_db) | ||
|
@@ -168,9 +176,11 @@ class HiveServer2Cursor(Cursor): | |
# HiveServer2Cursor objects are associated with a Session | ||
# they are instantiated with alive session_handles | ||
|
||
def __init__(self, session, convert_types=True, fetch_error=True, close_finished_queries=True): | ||
def __init__(self, session, convert_types=True, fetch_error=True, close_finished_queries=True, | ||
convert_strings_to_unicode=True): | ||
self.session = session | ||
self.convert_types = convert_types | ||
self.convert_strings_to_unicode = convert_strings_to_unicode | ||
self.fetch_error = fetch_error | ||
self.close_finished_queries = close_finished_queries | ||
|
||
|
@@ -570,7 +580,8 @@ def fetchcbatch(self): | |
batch = (self._last_operation.fetch( | ||
self.description, | ||
self.buffersize, | ||
convert_types=self.convert_types)) | ||
convert_types=self.convert_types, | ||
convert_strings_to_unicode=self.convert_strings_to_unicode)) | ||
if len(batch) == 0: | ||
return None | ||
return batch | ||
|
@@ -620,7 +631,8 @@ def fetchcolumnar(self): | |
batch = (self._last_operation.fetch( | ||
self.description, | ||
self.buffersize, | ||
convert_types=self.convert_types)) | ||
convert_types=self.convert_types, | ||
convert_strings_to_unicode=self.convert_strings_to_unicode)) | ||
if len(batch) == 0: | ||
break | ||
batches.append(batch) | ||
|
@@ -659,8 +671,9 @@ def _ensure_buffer_is_filled(self): | |
log.debug('_ensure_buffer_is_filled: buffer empty and op is active ' | ||
'=> fetching more data') | ||
self._buffer = self._last_operation.fetch(self.description, | ||
self.buffersize, | ||
convert_types=self.convert_types) | ||
self.buffersize, | ||
convert_types=self.convert_types, | ||
convert_strings_to_unicode=self.convert_strings_to_unicode) | ||
if len(self._buffer) > 0: | ||
return | ||
if not self._buffer.expect_more_rows: | ||
|
@@ -1012,7 +1025,8 @@ def pop_to_preallocated_list(self, output_list, count, offset=0, stride=1): | |
|
||
class CBatch(Batch): | ||
|
||
def __init__(self, trowset, expect_more_rows, schema, convert_types=True): | ||
def __init__(self, trowset, expect_more_rows, schema, convert_types=True, | ||
convert_strings_to_unicode=True): | ||
self.expect_more_rows = expect_more_rows | ||
self.schema = schema | ||
tcols = [_TTypeId_to_TColumnValue_getters[schema[i][1]](col) | ||
|
@@ -1040,14 +1054,33 @@ def __init__(self, trowset, expect_more_rows, schema, convert_types=True): | |
|
||
# STRING columns are read as binary and decoded here to be able to handle | ||
# non-valid utf-8 strings in Python 3. | ||
|
||
if six.PY3: | ||
self._convert_strings_to_unicode(type_, is_null, values) | ||
if convert_strings_to_unicode: | ||
self._convert_strings_to_unicode(type_, is_null, values, | ||
types=["STRING", "LIST", "MAP", "STRUCT", "UNIONTYPE", "NULL", "VARCHAR", "CHAR", "TIMESTAMP", "DECIMAL", "DATE"]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These type lists could be moved to a constant,e.g HS2_STRING_TYPES and CONVERTED_TYPES There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Absolutely, fixed |
||
elif convert_types: | ||
self._convert_strings_to_unicode(type_, is_null, values, types=["TIMESTAMP", "DECIMAL", "DATE"]) | ||
|
||
if convert_types: | ||
values = self._convert_values(type_, is_null, values) | ||
|
||
self.columns.append(Column(type_, values, is_null)) | ||
|
||
def _convert_strings_to_unicode(self, type_, is_null, values, types): | ||
if type_ in types: | ||
for i in range(len(values)): | ||
if is_null[i]: | ||
values[i] = None | ||
continue | ||
try: | ||
# Do similar handling of non-valid UTF-8 strings as Thriftpy2: | ||
# https://github.com/Thriftpy/thriftpy2/blob/8e218b3fd89c597c2e83d129efecfe4d280bdd89/thriftpy2/protocol/binary.py#L241 | ||
# If decoding fails then keep the original bytearray. | ||
values[i] = values[i].decode("UTF-8") | ||
except UnicodeDecodeError: | ||
pass | ||
|
||
def _convert_values(self, type_, is_null, values): | ||
# pylint: disable=consider-using-enumerate | ||
if type_ == 'TIMESTAMP': | ||
|
@@ -1062,20 +1095,6 @@ def _convert_values(self, type_, is_null, values): | |
values[i] = (None if is_null[i] else _parse_date(values[i])) | ||
return values | ||
|
||
def _convert_strings_to_unicode(self, type_, is_null, values): | ||
if type_ in ["STRING", "LIST", "MAP", "STRUCT", "UNIONTYPE", "DECIMAL", "DATE", "TIMESTAMP", "NULL", "VARCHAR", "CHAR"]: | ||
for i in range(len(values)): | ||
if is_null[i]: | ||
values[i] = None | ||
continue | ||
try: | ||
# Do similar handling of non-valid UTF-8 strings as Thriftpy2: | ||
# https://github.com/Thriftpy/thriftpy2/blob/8e218b3fd89c597c2e83d129efecfe4d280bdd89/thriftpy2/protocol/binary.py#L241 | ||
# If decoding fails then keep the original bytearray. | ||
values[i] = values[i].decode("UTF-8") | ||
except UnicodeDecodeError: | ||
pass | ||
|
||
def __len__(self): | ||
return self.remaining_rows | ||
|
||
|
@@ -1412,7 +1431,7 @@ def get_log(self, max_rows=1024, orientation=TFetchOrientation.FETCH_NEXT): | |
resp = self._rpc('FetchResults', req, False) | ||
schema = [('Log', 'STRING', None, None, None, None, None)] | ||
log = self._wrap_results(resp.results, resp.hasMoreRows, schema, | ||
convert_types=True) | ||
convert_types=True, convert_strings_to_unicode=True) | ||
log = '\n'.join(l[0] for l in log) | ||
return log | ||
|
||
|
@@ -1457,7 +1476,7 @@ def get_summary(self): | |
|
||
def fetch(self, schema=None, max_rows=1024, | ||
orientation=TFetchOrientation.FETCH_NEXT, | ||
convert_types=True): | ||
convert_types=True, convert_strings_to_unicode=True): | ||
if not self.has_result_set: | ||
log.debug('fetch_results: has_result_set=False') | ||
return None | ||
|
@@ -1473,15 +1492,18 @@ def fetch(self, schema=None, max_rows=1024, | |
# results are kept around for retry to be successful. | ||
resp = self._rpc('FetchResults', req, False) | ||
return self._wrap_results(resp.results, resp.hasMoreRows, schema, | ||
convert_types=convert_types) | ||
convert_types=convert_types, | ||
convert_strings_to_unicode=convert_strings_to_unicode) | ||
|
||
def _wrap_results(self, results, expect_more_rows, schema, convert_types=True): | ||
def _wrap_results(self, results, expect_more_rows, schema, convert_types=True, | ||
convert_strings_to_unicode=True): | ||
if self.is_columnar: | ||
log.debug('fetch_results: constructing CBatch') | ||
return CBatch(results, expect_more_rows, schema, convert_types=convert_types) | ||
return CBatch(results, expect_more_rows, schema, convert_types=convert_types, | ||
convert_strings_to_unicode=convert_strings_to_unicode) | ||
else: | ||
log.debug('fetch_results: constructing RBatch') | ||
# TODO: RBatch ignores 'convert_types' | ||
# TODO: RBatch ignores 'convert_types' and 'convert_strings_to_unicode' | ||
return RBatch(results, expect_more_rows, schema) | ||
|
||
@property | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -113,3 +113,9 @@ def cur(con): | |
cur = con.cursor() | ||
yield cur | ||
cur.close() | ||
|
||
@fixture(scope='session') | ||
def cur_noconv(con): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. convert_types is True here, so I think that cur_no_string_conversion or something similar would be better There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree. |
||
cur = con.cursor(convert_types=True, convert_strings_to_unicode=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here and in test_data_types.py: is there a reason for removing the final new lines? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's by mistake, fixed |
||
yield cur | ||
cur.close() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,21 +54,26 @@ def date_table(cur): | |
table_name = 'tmp_date_table' | ||
ddl = """CREATE TABLE {0} (d date)""".format(table_name) | ||
cur.execute(ddl) | ||
cur.execute('''insert into {0} | ||
values (date "0001-01-01"), (date "1999-9-9")'''.format(table_name)) | ||
try: | ||
yield table_name | ||
finally: | ||
cur.execute("DROP TABLE {0}".format(table_name)) | ||
|
||
|
||
@pytest.mark.connect | ||
def test_date_basic(cur, date_table): | ||
def setup_test_date_basic(cur, date_table): | ||
"""Insert and read back a couple of data values in a wide range.""" | ||
cur.execute('''insert into {0} | ||
values (date "0001-01-01"), (date "1999-9-9")'''.format(date_table)) | ||
cur.execute('select d from {0} order by d'.format(date_table)) | ||
results = cur.fetchall() | ||
assert results == [(datetime.date(1, 1, 1),), (datetime.date(1999, 9, 9),)] | ||
|
||
@pytest.mark.connect | ||
def test_date_basic(cur, date_table): | ||
setup_test_date_basic(cur, date_table) | ||
|
||
@pytest.mark.connect | ||
def test_date_basic_noconv(cur_noconv, date_table): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add tests more for types - date is a special case where string conversion is done regardless of convert_strings_to_unicode |
||
setup_test_date_basic(cur_noconv, date_table) | ||
|
||
@fixture(scope='module') | ||
def timestamp_table(cur): | ||
|
@@ -119,4 +124,4 @@ def test_utf8_strings(cur): | |
cur.execute('select unhex("AA")') | ||
result = cur.fetchone()[0] | ||
assert result == b"\xaa" | ||
assert result.decode("UTF-8", "replace") == u"�" | ||
assert result.decode("UTF-8", "replace") == u"�" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It could be mentioned that these are the types that are transmitted as strings in HS2 protocol
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added.