Source code for trx.fetcher
# -*- coding: utf-8 -*-
"""Test data management for downloading and verifying test assets."""
import hashlib
import logging
import os
import shutil
import urllib.request
TEST_DATA_REPO = "tee-ar-ex/trx-test-data"
TEST_DATA_TAG = "v0.1.0"
# GitHub release API entrypoint for metadata (asset list, sizes, etc.).
TEST_DATA_API_URL = (
    f"https://api.github.com/repos/{TEST_DATA_REPO}/releases/tags/{TEST_DATA_TAG}"
)
# Direct download base for release assets.
TEST_DATA_BASE_URL = (
    f"https://github.com/{TEST_DATA_REPO}/releases/download/{TEST_DATA_TAG}"
)
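# A minimal sketch (illustrative, not part of the original module) of how the
# release metadata could be listed from TEST_DATA_API_URL with only the
# standard library, e.g. to check for renamed assets. Field names follow the
# GitHub REST API ("assets", "name"); the expected output assumes the release
# ships exactly the four archives recorded in get_testing_files_dict() below.
#
#   >>> import json, urllib.request
#   >>> with urllib.request.urlopen(TEST_DATA_API_URL) as response:
#   ...     release = json.load(response)
#   >>> sorted(asset["name"] for asset in release["assets"])
#   ['DSI.zip', 'gold_standard.zip', 'memmap_test_data.zip', 'trx_from_scratch.zip']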
def get_home():
    """Return a user-writeable file-system location to put files.

    Returns
    -------
    str
        Path to the TRX home directory.
    """
    if "TRX_HOME" in os.environ:
        trx_home = os.environ["TRX_HOME"]
    else:
        trx_home = os.path.join(os.path.expanduser("~"), ".tee_ar_ex")
    return trx_home
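# A minimal usage sketch (illustrative): the cache location can be redirected
# through the TRX_HOME environment variable; "/tmp/trx_cache" is only an
# example value.
#
#   >>> import os
#   >>> os.environ["TRX_HOME"] = "/tmp/trx_cache"
#   >>> get_home()
#   '/tmp/trx_cache'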
def get_testing_files_dict():
    """Return a dictionary linking each zip file to its GitHub release URL
    and checksums.

    Assets are hosted under the v0.1.0 release of tee-ar-ex/trx-test-data.
    If URLs change, check TEST_DATA_API_URL to discover the latest asset
    locations.

    Returns
    -------
    dict
        Mapping of filenames to (url, md5, sha256) tuples.
    """
    return {
        "DSI.zip": (
            f"{TEST_DATA_BASE_URL}/DSI.zip",
            "b847f053fc694d55d935c0be0e5268f7",  # md5
            "1b09ce8b4b47b2600336c558fdba7051218296e8440e737364f2c4b8ebae666c",
        ),
        "memmap_test_data.zip": (
            f"{TEST_DATA_BASE_URL}/memmap_test_data.zip",
            "03f7651a0f9e3eeabee9aed0ad5f69e1",  # md5
            "98ba89d7a9a7baa2d37956a0a591dce9bb4581bd01296ad5a596706ee90a52ef",
        ),
        "trx_from_scratch.zip": (
            f"{TEST_DATA_BASE_URL}/trx_from_scratch.zip",
            "d9f220a095ce7f027772fcd9451a2ee5",  # md5
            "f98ab6da6a6065527fde4b0b6aa40f07583e925d952182e9bbd0febd55c0f6b2",
        ),
        "gold_standard.zip": (
            f"{TEST_DATA_BASE_URL}/gold_standard.zip",
            "57e3f9951fe77245684ede8688af3ae8",  # md5
            "35a0b633560cc2b0d8ecda885aa72d06385499e0cd1ca11a956b0904c3358f01",
        ),
    }
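# A minimal usage sketch (illustrative): looking up the download URL and the
# two expected digests for a single archive.
#
#   >>> url, expected_md5, expected_sha256 = get_testing_files_dict()["DSI.zip"]
#   >>> url.endswith("/DSI.zip")
#   True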
def md5sum(filename):
    """Compute the MD5 checksum of a file.

    Parameters
    ----------
    filename : str
        Path to file to hash.

    Returns
    -------
    str
        Hexadecimal MD5 digest.
    """
    h = hashlib.md5()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(128 * h.block_size), b""):
            h.update(chunk)
    return h.hexdigest()
def sha256sum(filename):
    """Compute the SHA256 checksum of a file.

    Parameters
    ----------
    filename : str
        Path to file to hash.

    Returns
    -------
    str
        Hexadecimal SHA256 digest.
    """
    h = hashlib.sha256()
    with open(filename, "rb") as f:
        for chunk in iter(lambda: f.read(128 * h.block_size), b""):
            h.update(chunk)
    return h.hexdigest()
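# A minimal usage sketch (illustrative): verifying an archive that is already
# present under get_home() against the digests recorded in
# get_testing_files_dict(); assumes "DSI.zip" has been downloaded beforehand.
#
#   >>> import os
#   >>> path = os.path.join(get_home(), "DSI.zip")
#   >>> _, expected_md5, expected_sha256 = get_testing_files_dict()["DSI.zip"]
#   >>> md5sum(path) == expected_md5
#   True
#   >>> sha256sum(path) == expected_sha256
#   True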
def fetch_data(files_dict, keys=None):  # noqa: C901
    """Download files to the TRX home folder and check their checksums.

    Parameters
    ----------
    files_dict : dict
        For each file in `files_dict` the value should be (url, md5) or
        (url, md5, sha256). The file is downloaded from url if it does not
        already exist locally. Zip files are automatically unpacked into a
        folder named after the archive.
    keys : list of str or str or None, optional
        Subset of keys from ``files_dict`` to download. When None, all
        keys are downloaded.

    Raises
    ------
    ValueError
        If the MD5 (or, when provided, SHA256) checksum of the file does not
        match the expected value. The downloaded file is not deleted when
        this error is raised.
    """
    trx_home = get_home()

    if not os.path.exists(trx_home):
        os.makedirs(trx_home)

    if keys is None:
        keys = files_dict.keys()
    elif isinstance(keys, str):
        keys = [keys]

    for f in keys:
        file_entry = files_dict[f]
        if len(file_entry) == 2:
            url, expected_md5 = file_entry
            expected_sha = None
        else:
            url, expected_md5, expected_sha = file_entry

        full_path = os.path.join(trx_home, f)

        logging.info("Downloading {} to {}".format(f, trx_home))
        if not os.path.exists(full_path):
            urllib.request.urlretrieve(url, full_path)

        actual_md5 = md5sum(full_path)
        if expected_md5 != actual_md5:
            raise ValueError(
                f"Md5sum for {f} does not match. "
                "Please remove the file to download it again: " + full_path
            )

        if expected_sha is not None:
            actual_sha = sha256sum(full_path)
            if expected_sha != actual_sha:
                raise ValueError(
                    f"SHA256 for {f} does not match. "
                    "Please remove the file to download it again: " + full_path
                )

        if f.endswith(".zip"):
            dst_dir = os.path.join(trx_home, f[:-4])
            shutil.unpack_archive(full_path, extract_dir=dst_dir, format="zip")
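# A minimal usage sketch (illustrative): downloading, verifying and unpacking a
# single archive into the TRX home directory. Passing keys=None instead would
# fetch every archive listed in get_testing_files_dict().
#
#   >>> fetch_data(get_testing_files_dict(), keys="gold_standard.zip")
#   >>> import os
#   >>> os.path.isdir(os.path.join(get_home(), "gold_standard"))
#   True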