Skip to content

Commit

Permalink
Use epath.Path in downloader.py
Browse files — browse the repository at this point in the history
PiperOrigin-RevId: 676837944
  • Loading branch information
fineguy authored and The TensorFlow Datasets Authors committed Sep 20, 2024
1 parent 3b0dab2 commit fc31737
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 87 deletions.
121 changes: 58 additions & 63 deletions tensorflow_datasets/core/download/download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,21 +297,21 @@ def __getstate__(self):
return state

@property
def _downloader(self) -> downloader._Downloader:
  """Returns the download backend, creating and caching it on first use."""
  # Built lazily so the manager can be constructed (and pickled) without
  # instantiating the backend.
  if not self.__downloader:
    self.__downloader = get_downloader(
        max_simultaneous_downloads=self._max_simultaneous_downloads
    )
  return self.__downloader

@property
def _extractor(self) -> extractor._Extractor:
  """Returns the extraction backend, creating and caching it on first use."""
  if not self.__extractor:
    self.__extractor = extractor.get_extractor()
  return self.__extractor

@property
def downloaded_size(self) -> int:
  """Returns the total size of downloaded files.

  Sums the `size` field of every url info recorded so far; 0 when nothing
  has been downloaded yet.
  """
  return sum(url_info.size for url_info in self._recorded_url_infos.values())

Expand All @@ -331,6 +331,22 @@ def _record_url_infos(self):
self._recorded_url_infos,
)

def _get_manually_downloaded_path(
    self, expected_url_info: checksums.UrlInfo | None
) -> epath.Path | None:
  """Checks if file is already downloaded in manual_dir."""
  # No manual dir configured: nothing to look up.
  if not self._manual_dir:
    return None

  # Without a registered filename we cannot know what to look for.
  if not expected_url_info or not expected_url_info.filename:
    return None

  candidate = self._manual_dir / expected_url_info.filename
  # Only return the path when the file was actually placed there.
  return candidate if candidate.exists() else None

# Synchronize and memoize decorators ensure same resource will only be
# processed once, even if passed twice to download_manager.
@utils.build_synchronize_decorator()
Expand Down Expand Up @@ -363,9 +379,8 @@ def _download(self, resource: Url) -> promise.Promise[epath.Path]:
# * In `manual_dir` (manually downloaded data)
# * In `downloads/url_path` (checksum unknown)
# * In `downloads/checksum_path` (checksum registered)
manually_downloaded_path = _get_manually_downloaded_path(
manual_dir=self._manual_dir,
expected_url_info=expected_url_info,
manually_downloaded_path = self._get_manually_downloaded_path(
expected_url_info=expected_url_info
)
url_path = self._get_dl_path(url)
checksum_path = (
Expand Down Expand Up @@ -459,12 +474,11 @@ def _register_or_validate_checksums(
# the download isn't cached (re-running build will retrigger a new
# download). This is expected as it might mean the downloaded file
# was corrupted. Note: The tmp file isn't deleted to allow inspection.
_validate_checksums(
self._validate_checksums(
url=url,
path=path,
expected_url_info=expected_url_info,
computed_url_info=computed_url_info,
force_checksums_validation=self._force_checksums_validation,
)

return self._rename_and_get_final_dl_path(
Expand All @@ -476,6 +490,42 @@ def _register_or_validate_checksums(
url_path=url_path,
)

def _validate_checksums(
    self,
    url: str,
    path: epath.Path,
    computed_url_info: checksums.UrlInfo | None,
    expected_url_info: checksums.UrlInfo | None,
) -> None:
  """Validate computed_url_info match expected_url_info."""
  if self._force_checksums_validation:
    # Manually downloaded files have no recorded checksum: compute one now.
    if not computed_url_info:
      computed_url_info = checksums.compute_url_info(path)
    # Strict mode requires that a checksum was registered for this url.
    if not expected_url_info:
      raise ValueError(
          f'Missing checksums url: {url}, yet '
          '`force_checksums_validation=True`. '
          'Did you forget to register checksums?'
      )

  # Only compare when both sides are known.
  both_known = bool(expected_url_info) and bool(computed_url_info)
  if both_known and expected_url_info != computed_url_info:
    raise NonMatchingChecksumError(
        f'Artifact {url}, downloaded to {path}, has wrong checksum:\n'
        f'* Expected: {expected_url_info}\n'
        f'* Got: {computed_url_info}\n'
        'To debug, see: '
        'https://www.tensorflow.org/datasets/overview#fixing_nonmatchingchecksumerror'
    )

def _rename_and_get_final_dl_path(
self,
url: str,
Expand Down Expand Up @@ -707,61 +757,6 @@ def manual_dir(self) -> epath.Path:
return self._manual_dir


def _get_manually_downloaded_path(
    manual_dir: epath.Path | None,
    expected_url_info: checksums.UrlInfo | None,
) -> epath.Path | None:
  """Checks if file is already downloaded in manual_dir."""
  # No manual dir passed: nothing to look up.
  if not manual_dir:
    return None

  # Without a registered filename we cannot know what to look for.
  if not expected_url_info or not expected_url_info.filename:
    return None

  candidate = manual_dir / expected_url_info.filename
  # Only return the path when the file was actually placed there.
  return candidate if candidate.exists() else None


def _validate_checksums(
    url: str,
    path: epath.Path,
    computed_url_info: checksums.UrlInfo | None,
    expected_url_info: checksums.UrlInfo | None,
    force_checksums_validation: bool,
) -> None:
  """Validate computed_url_info match expected_url_info."""
  if force_checksums_validation:
    # Manually downloaded files have no recorded checksum: compute one now.
    if not computed_url_info:
      computed_url_info = checksums.compute_url_info(path)
    # Strict mode requires that a checksum was registered for this url.
    if not expected_url_info:
      raise ValueError(
          f'Missing checksums url: {url}, yet '
          '`force_checksums_validation=True`. '
          'Did you forget to register checksums?'
      )

  # Only compare when both sides are known.
  both_known = bool(expected_url_info) and bool(computed_url_info)
  if both_known and expected_url_info != computed_url_info:
    raise NonMatchingChecksumError(
        f'Artifact {url}, downloaded to {path}, has wrong checksum:\n'
        f'* Expected: {expected_url_info}\n'
        f'* Got: {computed_url_info}\n'
        'To debug, see: '
        'https://www.tensorflow.org/datasets/overview#fixing_nonmatchingchecksumerror'
    )


def _map_promise(map_fn, all_inputs):
"""Map the function into each element and resolve the promise."""
all_promises = tree.map_structure(map_fn, all_inputs) # Apply the function
Expand Down
6 changes: 2 additions & 4 deletions tensorflow_datasets/core/download/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def increase_tqdm(self, dl_result: DownloadResult) -> None:
self._pbar_dl_size.update(dl_result.url_info.size)

def download(
self, url: str, destination_path: str, verify: bool = True
self, url: str, destination_path: epath.Path, verify: bool = True
) -> promise.Promise[concurrent.futures.Future[DownloadResult]]:
"""Download url to given path.
Expand All @@ -239,7 +239,6 @@ def download(
Returns:
Promise obj -> Download result.
"""
destination_path = os.fspath(destination_path)
self._pbar_url.update_total(1)
future = self._executor.submit(
self._sync_download, url, destination_path, verify
Expand All @@ -264,7 +263,7 @@ def _sync_file_copy(
return DownloadResult(path=out_path, url_info=url_info)

def _sync_download(
self, url: str, destination_path: str, verify: bool = True
self, url: str, destination_path: epath.Path, verify: bool = True
) -> DownloadResult:
"""Synchronous version of `download` method.
Expand All @@ -284,7 +283,6 @@ def _sync_download(
Raises:
DownloadError: when download fails.
"""
destination_path = epath.Path(destination_path)
try:
# If url is on a filesystem that gfile understands, use copy. Otherwise,
# use requests (http) or urllib (ftp).
Expand Down
35 changes: 15 additions & 20 deletions tensorflow_datasets/core/download/downloader_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for downloader."""

import hashlib
import io
import os
import tempfile
from typing import Optional
from unittest import mock

from etils import epath
import pytest
import tensorflow as tf
from tensorflow_datasets import testing
from tensorflow_datasets.core.download import downloader
from tensorflow_datasets.core.download import resource as resource_lib
def setUp(self):
  """Creates a fresh downloader, temp dir and canned response per test."""
  super(DownloaderTest, self).setUp()
  # Ensure any mock.patch started by a test is stopped even on failure.
  self.addCleanup(mock.patch.stopall)
  self.downloader = downloader.get_downloader(10, hashlib.sha256)
  self.tmp_dir = tempfile.mkdtemp(dir=tf.compat.v1.test.get_temp_dir())
  # Use pathlib-style epath.Path objects rather than raw path strings.
  self.tmp_dir = epath.Path(self.tmp_dir)
  self.url = 'http://example.com/foo.tar.gz'
  self.resource = resource_lib.Resource(url=self.url)
  self.path = self.tmp_dir / 'foo.tar.gz'
  self.incomplete_path = self.path.with_suffix(
      self.path.suffix + '.incomplete'
  )
  self.response = b'This \nis an \nawesome\n response!'
  self.resp_checksum = hashlib.sha256(self.response).hexdigest()
  self.cookies = {}
def test_ok(self):
  """Plain HTTP download lands at the expected path with the right bytes."""
  promise = self.downloader.download(self.url, self.tmp_dir)
  future = promise.get()
  url_info = future.url_info
  self.assertEqual(self.path, future.path)
  self.assertEqual(url_info.checksum, self.resp_checksum)
  self.assertEqual(self.path.read_bytes(), self.response)
  # The temporary `.incomplete` file must be gone after a successful download.
  self.assertFalse(self.incomplete_path.exists())

def test_drive_no_cookies(self):
  """Google Drive download without warning cookies behaves like plain HTTP."""
  url = 'https://drive.google.com/uc?export=download&id=a1b2bc3'
  promise = self.downloader.download(url, self.tmp_dir)
  future = promise.get()
  url_info = future.url_info
  self.assertEqual(self.path, future.path)
  self.assertEqual(url_info.checksum, self.resp_checksum)
  self.assertEqual(self.path.read_bytes(), self.response)
  # The temporary `.incomplete` file must be gone after a successful download.
  self.assertFalse(self.incomplete_path.exists())

def test_drive(self):
self.cookies = {'foo': 'bar', 'download_warning_a': 'token', 'a': 'b'}
Expand Down Expand Up @@ -129,11 +125,10 @@ def test_ftp(self):
promise = self.downloader.download(url, self.tmp_dir)
future = promise.get()
url_info = future.url_info
self.assertEqual(self.path, os.fspath(future.path))
self.assertEqual(self.path, future.path)
self.assertEqual(url_info.checksum, self.resp_checksum)
with tf.io.gfile.GFile(self.path, 'rb') as result:
self.assertEqual(result.read(), self.response)
self.assertFalse(tf.io.gfile.exists(self.incomplete_path))
self.assertEqual(self.path.read_bytes(), self.response)
self.assertFalse(self.incomplete_path.exists())

def test_ftp_error(self):
error = downloader.urllib.error.URLError('Problem serving file.')
Expand Down

0 comments on commit fc31737

Please sign in to comment.