Skip to content

Commit

Permalink
Merge branch 'release-6.1.0'
Browse files Browse the repository at this point in the history
  • Loading branch information
mpenkov committed Aug 21, 2022
2 parents b64796e + 9d80436 commit 706b7cf
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 15 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Unreleased

# 6.1.0, 21 August 2022

- Add cert parameter to http transport params (PR [#703](https://github.com/RaRe-Technologies/smart_open/pull/703), [@stev-0](https://github.com/stev-0))
- Allow passing additional kwargs for Azure writes (PR [#702](https://github.com/RaRe-Technologies/smart_open/pull/702), [@ddelange](https://github.com/ddelange))

# 6.0.0, 24 April 2022

This release deprecates the old `ignore_ext` parameter.
Expand Down
9 changes: 4 additions & 5 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -348,15 +348,14 @@ Since going over all (or select) keys in an S3 bucket is a very common operation
.. code-block:: python
>>> from smart_open import s3
>>> # get data corresponding to 2010 and later under "silo-open-data/annual/monthly_rain"
>>> # we use workers=1 for reproducibility; you should use as many workers as you have cores
>>> bucket = 'silo-open-data'
>>> prefix = 'annual/monthly_rain/'
>>> prefix = 'Official/annual/monthly_rain/'
>>> for key, content in s3.iter_bucket(bucket, prefix=prefix, accept_key=lambda key: '/201' in key, workers=1, key_limit=3):
... print(key, round(len(content) / 2**20))
annual/monthly_rain/2010.monthly_rain.nc 13
annual/monthly_rain/2011.monthly_rain.nc 13
annual/monthly_rain/2012.monthly_rain.nc 13
Official/annual/monthly_rain/2010.monthly_rain.nc 13
Official/annual/monthly_rain/2011.monthly_rain.nc 13
Official/annual/monthly_rain/2012.monthly_rain.nc 13
GCS Credentials
---------------
Expand Down
2 changes: 2 additions & 0 deletions help.txt
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ FUNCTIONS
The username for authenticating over HTTP
password: str, optional
The password for authenticating over HTTP
cert: str/tuple, optional
If String, path to ssl client cert file (.pem). If Tuple, (‘cert’, ‘key’)
headers: dict, optional
Any headers to send in the request. If ``None``, the default headers are sent:
``{'Accept-Encoding': 'identity'}``. To use no headers at all,
Expand Down
11 changes: 9 additions & 2 deletions smart_open/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def open(
blob_id,
mode,
client=None, # type: Union[azure.storage.blob.BlobServiceClient, azure.storage.blob.ContainerClient, azure.storage.blob.BlobClient] # noqa
blob_kwargs=None,
buffer_size=DEFAULT_BUFFER_SIZE,
min_part_size=_DEFAULT_MIN_PART_SIZE,
max_concurrency=DEFAULT_MAX_CONCURRENCY,
Expand All @@ -86,10 +87,13 @@ def open(
The mode for opening the object. Must be either "rb" or "wb".
client: azure.storage.blob.BlobServiceClient, ContainerClient, or BlobClient
The Azure Blob Storage client to use when working with azure-storage-blob.
blob_kwargs: dict, optional
Additional parameters to pass to `BlobClient.commit_block_list`.
For writing only.
buffer_size: int, optional
The buffer size to use when performing I/O. For reading only.
min_part_size: int, optional
The minimum part size for multipart uploads. For writing only.
The minimum part size for multipart uploads. For writing only.
max_concurrency: int, optional
The number of parallel connections with which to download. For reading only.
Expand All @@ -111,6 +115,7 @@ def open(
container_id,
blob_id,
client,
blob_kwargs=blob_kwargs,
min_part_size=min_part_size
)
else:
Expand Down Expand Up @@ -387,12 +392,14 @@ def __init__(
container,
blob,
client, # type: Union[azure.storage.blob.BlobServiceClient, azure.storage.blob.ContainerClient, azure.storage.blob.BlobClient] # noqa
blob_kwargs=None,
min_part_size=_DEFAULT_MIN_PART_SIZE,
):
self._is_closed = False
self._container_name = container

self._blob = _get_blob_client(client, container, blob)
self._blob_kwargs = blob_kwargs or {}
# type: azure.storage.blob.BlobClient

self._min_part_size = min_part_size
Expand All @@ -419,7 +426,7 @@ def close(self):
if not self.closed:
if self._current_part.tell() > 0:
self._upload_part()
self._blob.commit_block_list(self._block_list)
self._blob.commit_block_list(self._block_list, **self._blob_kwargs)
self._block_list = []
self._is_closed = True
logger.debug("successfully closed")
Expand Down
20 changes: 15 additions & 5 deletions smart_open/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def open_uri(uri, mode, transport_params):
return open(uri, mode, **kwargs)


def open(uri, mode, kerberos=False, user=None, password=None, headers=None, timeout=None):
def open(uri, mode, kerberos=False, user=None, password=None, cert=None,
headers=None, timeout=None):
"""Implement streamed reader from a web site.
Supports Kerberos and Basic HTTP authentication.
Expand All @@ -66,6 +67,8 @@ def open(uri, mode, kerberos=False, user=None, password=None, headers=None, time
The username for authenticating over HTTP
password: str, optional
The password for authenticating over HTTP
cert: str/tuple, optional
if String, path to ssl client cert file (.pem). If Tuple, (‘cert’, ‘key’)
headers: dict, optional
Any headers to send in the request. If ``None``, the default headers are sent:
``{'Accept-Encoding': 'identity'}``. To use no headers at all,
Expand All @@ -80,7 +83,8 @@ def open(uri, mode, kerberos=False, user=None, password=None, headers=None, time
if mode == constants.READ_BINARY:
fobj = SeekableBufferedInputBase(
uri, mode, kerberos=kerberos,
user=user, password=password, headers=headers, timeout=timeout,
user=user, password=password, cert=cert,
headers=headers, timeout=timeout,
)
fobj.name = os.path.basename(urllib.parse.urlparse(uri).path)
return fobj
Expand All @@ -90,7 +94,8 @@ def open(uri, mode, kerberos=False, user=None, password=None, headers=None, time

class BufferedInputBase(io.BufferedIOBase):
def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
kerberos=False, user=None, password=None, headers=None, timeout=None):
kerberos=False, user=None, password=None, cert=None,
headers=None, timeout=None):
if kerberos:
import requests_kerberos
auth = requests_kerberos.HTTPKerberosAuth()
Expand All @@ -112,6 +117,7 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
self.response = requests.get(
url,
auth=auth,
cert=cert,
stream=True,
headers=self.headers,
timeout=self.timeout,
Expand Down Expand Up @@ -204,13 +210,15 @@ def readinto(self, b):
class SeekableBufferedInputBase(BufferedInputBase):
"""
Implement seekable streamed reader from a web site.
Supports Kerberos and Basic HTTP authentication.
Supports Kerberos, client certificate and Basic HTTP authentication.
"""

def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
kerberos=False, user=None, password=None, headers=None, timeout=None):
kerberos=False, user=None, password=None, cert=None,
headers=None, timeout=None):
"""
If Kerberos is True, will attempt to use the local Kerberos credentials.
If cert is set, will try to use a client certificate
Otherwise, will try to use "basic" HTTP authentication via username/password.
If none of those are set, will connect unauthenticated.
Expand All @@ -230,6 +238,7 @@ def __init__(self, url, mode='r', buffer_size=DEFAULT_BUFFER_SIZE,
else:
self.headers = headers

self.cert = cert
self.timeout = timeout

self.buffer_size = buffer_size
Expand Down Expand Up @@ -325,6 +334,7 @@ def _partial_request(self, start_pos=None):
self.url,
auth=self.auth,
stream=True,
cert=self.cert,
headers=self.headers,
timeout=self.timeout,
)
Expand Down
32 changes: 30 additions & 2 deletions smart_open/tests/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ def __init__(self, container_client, name):
self.__contents = io.BytesIO()
self._staged_contents = {}

def commit_block_list(self, block_list):
def commit_block_list(self, block_list, metadata=None):
data = b''.join([self._staged_contents[block_blob['id']] for block_blob in block_list])
self.__contents = io.BytesIO(data)
self.set_blob_metadata(dict(size=len(data)))
metadata = metadata or {}
metadata.update({"size": len(data)})
self.set_blob_metadata(metadata)
self._container_client.register_blob_client(self)

def delete_blob(self):
Expand Down Expand Up @@ -590,6 +592,32 @@ def test_write_container_client(self):
))
assert output == [test_string]

def test_write_blob_client(self):
"""Does writing into Azure Blob Storage work correctly?"""
test_string = u"žluťoučký koníček".encode('utf8')
blob_name = "test_write_blob_client_%s" % BLOB_NAME

container_client = CLIENT.get_container_client(CONTAINER_NAME)
blob_client = container_client.get_blob_client(blob_name)

with smart_open.open(
"azure://%s/%s" % (CONTAINER_NAME, blob_name),
"wb",
transport_params={
"client": blob_client, "blob_kwargs": {"metadata": {"name": blob_name}}
},
) as fout:
fout.write(test_string)

self.assertEqual(blob_client.get_blob_properties()["name"], blob_name)

output = list(smart_open.open(
"azure://%s/%s" % (CONTAINER_NAME, blob_name),
"rb",
transport_params=dict(client=CLIENT),
))
self.assertEqual(output, [test_string])

def test_incorrect_input(self):
"""Does azure write fail on incorrect input?"""
blob_name = "test_incorrect_input_%s" % BLOB_NAME
Expand Down
12 changes: 12 additions & 0 deletions smart_open/tests/test_smart_open.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,18 @@ def test_http_pass(self):
self.assertTrue('Authorization' in actual_request.headers)
self.assertTrue(actual_request.headers['Authorization'].startswith('Basic '))

@responses.activate
def test_http_cert(self):
"""Does cert parameter get passed to requests"""
responses.add(responses.GET, "http://127.0.0.1/index.html",
body='line1\nline2', stream=True)
cert_path = '/path/to/my/cert.pem'
tp = dict(cert=cert_path)
smart_open.open("http://127.0.0.1/index.html", transport_params=tp)
self.assertEqual(len(responses.calls), 1)
actual_request = responses.calls[0].request
self.assertEqual(cert_path, actual_request.req_kwargs['cert'])

@responses.activate
def _test_compressed_http(self, suffix, query):
"""Can open <suffix> via http?"""
Expand Down
2 changes: 1 addition & 1 deletion smart_open/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = '6.0.0'
__version__ = '6.1.0'


if __name__ == '__main__':
Expand Down

0 comments on commit 706b7cf

Please sign in to comment.