"""
Repositories are arbitrary key-value stores. They are the data part of pydatatask.
You can store your data in any way you desire and as long as you can write a Repository class to describe it, it can be
used to drive a pipeline.
The notion of the "value" part of the key-value store abstraction is defined very, very loosely. The repository base
class doesn't have an interface to get or store values, only to query for and delete keys. Instead, you have to know
which repository subclass you're working with, and use its interfaces. For example, `MetadataRepository` assumes that
its values are structured objects and loads them fully into memory, and `BlobRepository` provides a streaming interface
to a flat address space.
"""
from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    AsyncIterable,
    Awaitable,
    Callable,
    Coroutine,
    Dict,
    List,
    Literal,
    Optional,
    overload,
)
from abc import ABC, abstractmethod
from collections import Counter
from pathlib import Path
import asyncio
import base64
import hashlib
import inspect
import io
import logging
import os
import string

from kubernetes_asyncio.client import V1Pod
from types_aiobotocore_s3.client import S3Client
import aiofiles.os
import aiohttp.client_exceptions
import aioshutil
import botocore.exceptions
import docker_registry_client_async
import dxf
import motor.motor_asyncio
import yaml

from .utils import AReadStream, AReadText, AWriteStream, AWriteText, roundrobin
if TYPE_CHECKING:
from .task import ExecutorTask, KubeTask
l = logging.getLogger(__name__)
__all__ = (
"Repository",
"BlobRepository",
"MetadataRepository",
"FileRepositoryBase",
"FileRepository",
"DirectoryRepository",
"S3BucketRepository",
"S3BucketInfo",
"MongoMetadataRepository",
"InProcessMetadataRepository",
"InProcessBlobStream",
"InProcessBlobRepository",
"DockerRepository",
"LiveKubeRepository",
"ExecutorLiveRepo",
"AggregateOrRepository",
"AggregateAndRepository",
"BlockingRepository",
"YamlMetadataRepository",
"YamlMetadataFileRepository",
"YamlMetadataS3Repository",
"RelatedItemRepository",
)
def job_getter(f):
    """
    Mark a non-abstract async method whose first parameter is a job identifier.

    `RelatedItemRepository` uses this marker to know which methods need their job
    identifier automatically translated to a related one.

    :raises TypeError: if ``f`` is not a coroutine function.
    """
    if not inspect.iscoroutinefunction(f):
        raise TypeError("only async functions can be job_getters")
    f.is_job_getter = True
    return f
class Repository(ABC):
    """
    A repository is a key-value store where the keys are names of jobs. Since the
    values have unspecified semantics, the only operations you can do on a generic
    repository are query for keys.

    A repository can be async-iterated to get a listing of its members.
    """

    # Characters allowed anywhere in a job id, and at its first/last position.
    CHARSET = CHARSET_START_END = string.ascii_letters + string.digits

    @classmethod
    def is_valid_job_id(cls, job: str):
        """
        Determine whether the given job identifier is valid, i.e. that it contains
        only valid characters (numbers and letters by default).
        """
        if not 0 < len(job) < 64:
            return False
        if job[0] not in cls.CHARSET_START_END:
            return False
        if job[-1] not in cls.CHARSET_START_END:
            return False
        return all(c in cls.CHARSET for c in job)

    async def filter_jobs(self, iterator: AsyncIterable[str]) -> AsyncIterable[str]:
        """
        Apply `is_valid_job_id` as a filter to an async iterator, warning about
        anything that gets dropped.
        """
        async for candidate in iterator:
            if not self.is_valid_job_id(candidate):
                l.warning("Skipping %s %s - not a valid job id", self, repr(candidate))
                continue
            yield candidate

    def __aiter__(self):
        # Direct iteration is always routed through the validity filter.
        return self.filter_jobs(self.unfiltered_iter())

    async def contains(self, item):
        """
        Determine whether the given job identifier is present in this repository.

        The default implementation is quite inefficient; please override this if
        possible.
        """
        async for present in self:
            if present == item:
                return True
        return False

    @abstractmethod
    async def unfiltered_iter(self) -> AsyncGenerator[str, None]:
        """
        The core method of Repository. Implement this to produce an iterable of
        every string which could potentially be a job identifier present in this
        repository. When the repository is iterated directly, this will be filtered
        by `filter_jobs`.
        """
        raise NotImplementedError
        # The unreachable yield makes this an async generator function, so
        # overrides and callers get the right protocol.
        # noinspection PyUnreachableCode
        yield None  # pylint: disable=unreachable

    @abstractmethod
    async def info(self, job) -> Any:
        """
        Returns an arbitrary piece of data related to job. Notably, this is used
        during templating.

        This should do something meaningful even if the repository does not contain
        the requested job.
        """
        raise NotImplementedError

    @abstractmethod
    async def delete(self, job):
        """
        Delete the given job from the repository. This should succeed even if the
        job is not present in this repository.
        """
        raise NotImplementedError

    async def info_all(self) -> Dict[str, Any]:
        """
        Produce a mapping from every job present in the repository to its
        corresponding info. The default implementation is somewhat inefficient;
        please override it if there is a more effective way to load all info.
        """
        result: Dict[str, Any] = {}
        async for job in self:
            result[job] = await self.info(job)
        return result

    async def validate(self):
        """
        Override this method to raise an exception if for any reason the repository
        is misconfigured. This will be automatically called by the pipeline on
        opening.
        """

    def map(
        self, func: Callable, filt: Optional[Callable[[str], Awaitable[bool]]] = None, allow_deletes=False
    ) -> "MapRepository":
        """
        Generate a :class:`MapRepository` based on this repository and the given
        parameters.
        """
        return MapRepository(self, func, filt, allow_deletes=allow_deletes)
class MapRepository(Repository):
    """
    A MapRepository is a repository which uses arbitrary functions to map and
    filter results from a base repository.
    """

    def __init__(
        self,
        base: Repository,
        func: Callable[[Any], Coroutine[None, None, Any]],
        filt: Optional[Callable[[str], Awaitable[bool]]] = None,
        allow_deletes=False,
    ):
        """
        :param func: The function to use to translate the base repository's `info`
            results into the mapped `info` results.
        :param filt: Optional: An async function to use to determine whether a given
            key should be considered part of the mapped repository.
        :param allow_deletes: Whether the delete operation will do anything on the
            mapped repository.
        """
        self.base = base
        self.func = func
        self.filter = filt
        self.allow_deletes = allow_deletes

    async def _passes(self, item) -> bool:
        # No filter configured means every key passes.
        return self.filter is None or await self.filter(item)

    async def contains(self, item):
        if not await self._passes(item):
            return False
        return await self.base.contains(item)

    async def delete(self, job):
        if not self.allow_deletes:
            return
        await self.base.delete(job)

    async def unfiltered_iter(self):
        async for key in self.base.unfiltered_iter():
            if await self._passes(key):
                yield key

    async def info(self, job):
        raw = await self.base.info(job)
        return await self.func(raw)

    async def info_all(self) -> Dict[str, Any]:
        mapped = await self.base.info_all()
        # Drop filtered-out keys first, then translate the surviving values.
        dropped = [key for key in mapped if not await self._passes(key)]
        for key in dropped:
            del mapped[key]
        for key, value in mapped.items():
            mapped[key] = await self.func(value)
        return mapped
class BlobRepository(Repository, ABC):
    """
    A blob repository has values which are flat data blobs that can be streamed
    for reading or writing.

    The four overloads below pin the return type of ``open`` to the four
    supported mode strings.
    """

    @overload
    async def open(self, job: str, mode: Literal["r"]) -> AReadText:
        ...

    @overload
    async def open(self, job: str, mode: Literal["rb"]) -> AReadStream:
        ...

    @overload
    async def open(self, job: str, mode: Literal["w"]) -> AWriteText:
        ...

    @overload
    async def open(self, job: str, mode: Literal["wb"]) -> AWriteStream:
        ...

    @abstractmethod
    async def open(self, job, mode="r"):
        """
        Open the given job's value as a stream for reading or writing, in text or
        binary mode.
        """
        raise NotImplementedError
class FileRepositoryBase(Repository, ABC):
    """
    A file repository is a local directory where each job identifier is a filename,
    optionally suffixed with an extension before hitting the filesystem. This is an
    abstract base class for other file repositories which have more to say about
    what is found at these filepaths.
    """

    def __init__(self, basedir, extension="", case_insensitive=False):
        self.basedir = Path(basedir)  # root directory containing the members
        self.extension = extension  # appended to the job id to form a filename
        self.case_insensitive = case_insensitive  # whether extension matching ignores case

    def __repr__(self):
        return f'<{type(self).__name__} {self.basedir / ("*" + self.extension)}>'

    def fullpath(self, job) -> Path:
        """
        Construct the full local path of the file corresponding to ``job``.
        """
        return self.basedir / (job + self.extension)

    async def contains(self, item):
        return await aiofiles.os.path.exists(self.fullpath(item))

    async def unfiltered_iter(self):
        ext = self.extension
        for name in await aiofiles.os.listdir(self.basedir):
            if self.case_insensitive:
                matches = name.lower().endswith(ext.lower())
            else:
                matches = name.endswith(ext)
            if not matches:
                continue
            # Strip the extension; when there is none, yield the name unchanged.
            yield name[: -len(ext)] if ext else name

    async def validate(self):
        self.basedir.mkdir(exist_ok=True, parents=True)
        if not os.access(self.basedir, os.W_OK):
            raise PermissionError(f"Cannot write to {self.basedir}")

    @job_getter
    async def info(self, job):
        """
        The templating info provided by a file repository is the full path to the
        corresponding file as a string.
        """
        return str(self.fullpath(job))
class FileRepository(FileRepositoryBase, BlobRepository):
    """
    A file repository whose members are files, treated as streamable blobs.
    """

    @job_getter
    async def open(self, job, mode="r"):
        # NOTE(review): the aiofiles context manager is intentionally returned
        # un-awaited; callers appear to await it or enter it themselves — confirm
        # against call sites before changing.
        if not self.is_valid_job_id(job):
            raise KeyError(job)
        return aiofiles.open(self.fullpath(job), mode)

    async def delete(self, job):
        # A missing file is fine: delete must be idempotent per the base contract.
        try:
            await aiofiles.os.unlink(self.fullpath(job))
        except FileNotFoundError:
            pass
class DirectoryRepository(FileRepositoryBase):
    """
    A file repository whose members are directories.
    """

    def __init__(self, *args, discard_empty=False, **kwargs):
        """
        :param discard_empty: Whether only directories containing at least one
            member should be considered as "present" in the repository.
        """
        super().__init__(*args, **kwargs)
        self.discard_empty = discard_empty

    async def _has_members(self, job) -> bool:
        # True when the job's directory contains at least one entry.
        return bool(await aiofiles.os.listdir(self.fullpath(job)))

    @job_getter
    async def mkdir(self, job):
        """
        Create an empty directory corresponding to ``job``. Do nothing if the
        directory already exists.
        """
        try:
            await aiofiles.os.mkdir(self.fullpath(job))
        except FileExistsError:
            pass

    async def contains(self, item):
        present = await super().contains(item)
        if not self.discard_empty or not present:
            return present
        return await self._has_members(item)

    async def unfiltered_iter(self):
        async for item in super().unfiltered_iter():
            if not self.discard_empty or await self._has_members(item):
                yield item

    async def delete(self, job):
        if await self.contains(job):
            await aioshutil.rmtree(self.fullpath(job))
class S3BucketBinaryWriter:
    """
    A class for streaming (or buffering) byte data to be written to an
    `S3BucketRepository`. All writes are staged in memory and uploaded in a single
    ``put_object`` call when the stream is closed.
    """

    def __init__(self, repo: "S3BucketRepository", job: str):
        self.repo = repo
        self.job = job
        self.buffer = io.BytesIO()  # staging area for the whole object body
        super().__init__()

    async def write(self, data: bytes):
        """
        Write some data to the stream.
        """
        self.buffer.write(data)

    async def close(self):
        """
        Close and flush the data to the bucket.
        """
        # seek() returns the new position, so seeking to the end yields the size.
        size = self.buffer.seek(0, io.SEEK_END)
        self.buffer.seek(0, io.SEEK_SET)
        await self.repo.client.put_object(
            Bucket=self.repo.bucket,
            Key=self.repo.object_name(self.job),
            Body=self.buffer,
            ContentLength=size,
            ContentType=self.repo.mimetype,
        )

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
class S3BucketReader:
    """
    A class for streaming byte data from an `S3BucketRepository`.
    """

    def __init__(self, body):
        self.body = body  # the streaming response body from get_object

    async def read(self, n=None):  # pylint: disable=unused-argument :(
        """
        Read the entire body of the blob. Due to API limitations, we can't read less
        than that at once...
        """
        return await self.body.read()

    async def close(self):
        """
        Close and release the stream.
        """
        self.body.close()

    async def __aenter__(self):
        await self.body.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.body.__aexit__(exc_type, exc_val, exc_tb)
class S3BucketInfo:
    """
    The data structure returned from :meth:`S3BucketRepository.info`.

    :ivar uri: The s3 URI of the current job's resource, e.g.
        ``s3://bucket/prefix/job.ext``. ``str(info)`` will also return this.
    :ivar endpoint: The URL of the API server providing the S3 interface.
    :ivar bucket: The name of the bucket objects are stored in.
    :ivar prefix: How to prefix an object name such that it will fit into this
        repository.
    :ivar suffix: How to suffix an object name such that it will fit into this
        repository.
    """

    def __init__(self, endpoint: str, uri: str, bucket: str, prefix: str, suffix: str):
        self.endpoint = endpoint
        self.uri = uri
        self.bucket = bucket
        self.prefix = prefix
        self.suffix = suffix

    def __str__(self):
        return self.uri
class S3BucketRepository(BlobRepository):
    """
    A repository where keys are paths in a S3 bucket. Provides a streaming interface
    to the corresponding blobs.
    """

    def __init__(
        self,
        client: Callable[[], S3Client],
        bucket: str,
        prefix: str = "",
        suffix: str = "",
        mimetype: str = "application/octet-stream",
        incluster_endpoint: Optional[str] = None,
    ):
        """
        :param client: A callable returning an aiobotocore S3 client connected and
            authenticated to the server you wish to store things on.
        :param bucket: The name of the bucket from which to load and store.
        :param prefix: A prefix to put on the job name before translating it into a
            bucket path. If this is meant to be a directory name it should end with
            a slash character.
        :param suffix: A suffix to put on the job name before translating it into a
            bucket path. If this is meant to be a file extension it should start
            with a dot.
        :param mimetype: The MIME type to set the content when adding data.
        :param incluster_endpoint: Optional: An endpoint URL to provide as the
            result of info() queries instead of extracting the URL from ``client``.
        """
        self._client = client
        self.bucket = bucket
        self.prefix = prefix
        self.suffix = suffix
        self.mimetype = mimetype
        self.incluster_endpoint = incluster_endpoint

    @property
    def client(self):
        """
        The aiobotocore S3 client. This will raise an error if the client comes from
        a session which is not opened.
        """
        return self._client()

    def __repr__(self):
        return f"<{type(self).__name__} {self.bucket}/{self.prefix}*{self.suffix}>"

    def object_name(self, job):
        """
        Return the object name for the given job.
        """
        return f"{self.prefix}{job}{self.suffix}"

    async def contains(self, item):
        # head_object raises ClientError for anything not present.
        try:
            await self.client.head_object(Bucket=self.bucket, Key=self.object_name(item))
        except botocore.exceptions.ClientError:
            return False
        return True

    async def unfiltered_iter(self):
        # Pre-compute the slice end: strip the suffix, or nothing if there is none.
        trim = -len(self.suffix) if self.suffix else None
        paginator = self.client.get_paginator("list_objects")
        async for page in paginator.paginate(Bucket=self.bucket, Prefix=self.prefix):
            for entry in page.get("Contents", []):
                key = entry["Key"]
                if key.endswith(self.suffix):
                    yield key[len(self.prefix) : trim]

    async def validate(self):
        try:
            await self.client.head_bucket(Bucket=self.bucket)
        except botocore.exceptions.ClientError as e:
            # A 404 means the bucket does not exist yet; create it. Anything else
            # is a real configuration problem and should propagate.
            if "404" not in str(e):
                raise
            await self.client.create_bucket(Bucket=self.bucket)

    @job_getter
    async def open(self, job, mode="r"):
        if not self.is_valid_job_id(job):
            raise KeyError(job)
        if mode in ("w", "wb"):
            writer = S3BucketBinaryWriter(self, job)
            return AWriteText(writer) if mode == "w" else writer
        if mode in ("r", "rb"):
            response = await self.client.get_object(Bucket=self.bucket, Key=self.object_name(job))
            reader = S3BucketReader(response["Body"])
            return AReadText(reader) if mode == "r" else reader
        raise ValueError(mode)

    @job_getter
    async def info(self, job):
        """
        Return an `S3BucketInfo` corresponding to the given job.
        """
        # NOTE(review): reaching into the client's private _endpoint attribute is a
        # hack; there appears to be no public accessor for the endpoint URL.
        return S3BucketInfo(
            self.incluster_endpoint or self.client._endpoint.host,
            f"s3://{self.bucket}/{self.object_name(job)}",
            self.bucket,
            self.prefix,
            self.suffix,
        )

    async def delete(self, job):
        await self.client.delete_object(Bucket=self.bucket, Key=self.object_name(job))
class DockerRepository(Repository):
    """
    A docker repository is, well, an actual docker repository hosted in some
    registry somewhere. Keys translate to tags on this repository.
    """

    def __init__(
        self,
        registry: Callable[[], docker_registry_client_async.DockerRegistryClientAsync],
        domain: str,
        repository: str,
    ):
        """
        :param registry: A callable returning a
            `docker_registry_client_async <https://pypi.org/project/docker-registry-client-async/>`_
            client object with appropriate authentication information.
        :param domain: The registry domain to connect to, e.g. ``index.docker.io``.
        :param repository: The repository to store images in within the domain,
            e.g. ``myname/myrepo``.
        """
        self._registry = registry
        self.domain = domain
        self.repository = repository

    @property
    def registry(self) -> docker_registry_client_async.DockerRegistryClientAsync:
        """
        The ``docker_registry_client_async`` client object. If this is provided by
        an unopened session, raise an error.
        """
        return self._registry()

    def __repr__(self):
        return f"<DockerRepository {self.domain}/{self.repository}:*>"

    async def unfiltered_iter(self):
        try:
            image = docker_registry_client_async.ImageName(self.repository, endpoint=self.domain)
            tags = (await self.registry.get_tags(image)).tags["tags"]
            if tags is None:
                return
            for tag in tags:
                yield tag
        except aiohttp.client_exceptions.ClientResponseError as e:
            # A 404 just means the repository has no tags yet; treat it as empty.
            if e.status != 404:
                raise

    @job_getter
    async def info(self, job):
        """
        The info provided by a docker repository is a dict with two keys,
        "withdomain" and "withoutdomain". e.g.:

        .. code::

            { "withdomain": "docker.example.com/myname/myrepo:job", "withoutdomain": "myname/myrepo:job" }
        """
        return {
            "withdomain": f"{self.domain}/{self.repository}:{job}",
            "withoutdomain": f"{self.repository}:{job}",
        }

    def _dxf_auth(self, dxf_obj, response):
        # HACK: dxf does its own authentication, so dig the matching credentials
        # out of the registry client's credential store by domain pattern.
        for pattern, credentials in self.registry.credentials.items():
            if pattern.fullmatch(self.domain):
                result = credentials
                break
        else:
            raise PermissionError("Missing credentials for %s" % self.domain)
        if self.registry.ssl:
            # Credentials are base64("user:pass"); split only on the first colon
            # so passwords containing a colon survive.
            username, password = base64.b64decode(result).decode().split(":", 1)
            dxf_obj.authenticate(username, password, response)
        else:
            # dxf has no insecure-auth entry point; inject the header directly.
            dxf_obj._headers = {"Authorization": "Basic " + result}

    async def delete(self, job):
        """
        Delete the tag corresponding to ``job`` from the registry.

        The dxf library is synchronous, so the actual work runs on the default
        executor to avoid blocking the event loop.
        """
        await asyncio.get_running_loop().run_in_executor(None, self._delete_inner, job)

    def _delete_inner(self, job):
        # Registries don't allow deleting a tag directly: repoint the tag at a
        # throwaway blob, then delete the alias.
        random_data = os.urandom(16)
        random_digest = "sha256:" + hashlib.sha256(random_data).hexdigest()
        d = dxf.DXF(
            host=self.domain,
            repo=self.repository,
            auth=self._dxf_auth,
            insecure=not self.registry.ssl,
        )
        d.push_blob(data=random_data, digest=random_digest)
        d.set_alias(job, random_digest)
        d.del_alias(job)
class LiveKubeRepository(Repository):
    """
    A repository where keys translate to ``job`` labels on running kube pods. This
    repository is constructed automatically by a `KubeTask` or subclass and is
    linked as the ``live`` repository. Do not construct this class manually.
    """

    def __init__(self, task: "KubeTask"):
        self.task = task

    def __repr__(self):
        return f"<LiveKubeRepository task={self.task.name}>"

    async def pods(self) -> List[V1Pod]:
        """
        A list of live pod objects corresponding to this repository.
        """
        return await self.task.podman.query(task=self.task.name)

    async def unfiltered_iter(self):
        for pod in await self.pods():
            yield pod.metadata.labels["job"]

    async def contains(self, item):
        matching = await self.task.podman.query(task=self.task.name, job=item)
        return bool(matching)

    @job_getter
    async def info(self, job):
        """
        Cannot template with live kube info. Implement this if you have something in
        mind.
        """
        return None

    async def delete(self, job):
        """
        Deleting a job from this repository will delete the pod.
        """
        pods = await self.task.podman.query(job=job, task=self.task.name)
        for pod in pods:  # there... really should be only one
            await self.task.delete(pod)
class AggregateAndRepository(Repository):
    """
    A repository which is said to contain a job if all its children also contain
    that job
    """

    def __init__(self, **children: Repository):
        assert children
        self.children = children

    async def unfiltered_iter(self):
        # Round-robin through every child, yielding a key as soon as it has been
        # seen in all of them.
        tally: Counter = Counter()
        required = len(self.children)
        streams = [child.unfiltered_iter() for child in self.children.values()]
        async for key in roundrobin(streams):
            tally[key] += 1
            if tally[key] == required:
                yield key

    async def contains(self, item):
        for child in self.children.values():
            if not await child.contains(item):
                return False
        return True

    @job_getter
    async def info(self, job):
        """
        The info provided by an aggregate And repository is a dict mapping each
        child's name to that child's info.
        """
        return {name: await child.info(job) for name, child in self.children.items()}

    async def delete(self, job):
        """
        Deleting a job from an aggregate And repository deletes the job from all of
        its children.
        """
        for child in self.children.values():
            await child.delete(job)
class AggregateOrRepository(Repository):
    """
    A repository which is said to contain a job if any of its children also contain
    that job
    """

    def __init__(self, **children: Repository):
        assert children
        self.children = children

    async def unfiltered_iter(self):
        # Deduplicate keys across children, preserving first-seen order.
        yielded = set()
        for child in self.children.values():
            async for key in child.unfiltered_iter():
                if key not in yielded:
                    yielded.add(key)
                    yield key

    async def contains(self, item):
        for child in self.children.values():
            if await child.contains(item):
                return True
        return False

    @job_getter
    async def info(self, job):
        """
        The info provided by an aggregate Or repository is a dict mapping each
        child's name to that child's info.
        """
        return {name: await child.info(job) for name, child in self.children.items()}

    async def delete(self, job):
        """
        Deleting a job from an aggregate Or repository deletes the job from all of
        its children.
        """
        for child in self.children.values():
            await child.delete(job)
class BlockingRepository(Repository):
    """
    A class that is said to contain a job if ``source`` contains it and ``unless``
    does not contain it
    """

    def __init__(self, source: Repository, unless: Repository, enumerate_unless=True):
        """
        :param source: The repository whose keys are candidates for membership.
        :param unless: The repository whose keys are excluded from membership.
        :param enumerate_unless: If True, iteration snapshots the full key set of
            ``unless`` up front; if False, ``unless.contains`` is queried once per
            candidate key.
        """
        self.source = source
        self.unless = unless
        self.enumerate_unless = enumerate_unless

    async def unfiltered_iter(self):
        if self.enumerate_unless:
            blocked = {key async for key in self.unless.unfiltered_iter()}
            async for item in self.source.unfiltered_iter():
                if item not in blocked:
                    yield item
        else:
            async for item in self.source.unfiltered_iter():
                # BUGFIX: contains() is a coroutine and must be awaited. The
                # previous bare call produced an always-truthy coroutine object,
                # which silently blocked every key in this mode.
                if not await self.unless.contains(item):
                    yield item

    async def contains(self, item):
        return await self.source.contains(item) and not await self.unless.contains(item)

    @job_getter
    async def info(self, job):
        return await self.source.info(job)

    async def delete(self, job):
        await self.source.delete(job)
class ExecutorLiveRepo(Repository):
    """
    A repository where keys translate to running jobs in an ExecutorTask. This
    repository is constructed automatically and is linked as the ``live``
    repository. Do not construct this class manually.
    """

    def __init__(self, task: "ExecutorTask"):
        self.task = task

    def __repr__(self):
        return f"<{type(self).__name__} task={self.task.name}>"

    async def unfiltered_iter(self):
        for key in self.task.rev_jobs:
            yield key

    async def contains(self, item):
        return item in self.task.rev_jobs

    async def info(self, job):
        """
        There is no templating info for an `ExecutorLiveRepo`.
        """
        return None

    async def delete(self, job):
        """
        Deleting a job from the repository will cancel the corresponding task.
        """
        await self.task.cancel(job)
class InProcessBlobStream:
    """
    A stream returned from an `BlobRepository.open` call from
    `InProcessBlobRepository`. Do not construct this manually.
    """

    def __init__(self, repo: "InProcessBlobRepository", job: str):  # pylint: disable=missing-function-docstring
        self.repo = repo
        self.job = job
        # Seed the working buffer with any existing value stored for this job.
        self.data = io.BytesIO(repo.data.get(job, b""))

    async def read(self, n: Optional[int] = None) -> bytes:
        """
        Read up to ``n`` bytes from the stream.
        """
        return self.data.read(n)

    async def write(self, data: bytes):
        """
        Write ``data`` to the stream.
        """
        self.data.write(data)

    async def close(self):
        """
        Close and release the stream, syncing the data back to the repository.
        """
        self.repo.data[self.job] = self.data.getvalue()

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()
class InProcessBlobRepository(BlobRepository):
    """
    An incredibly simple blob repository which stores all its values in a dict, and
    will let them vanish when the process terminates.
    """

    def __init__(self, data: Optional[Dict[str, bytes]] = None):
        """
        :param data: Optional: an initial mapping from job identifiers to blob
            bytes.
        """
        self.data = data if data is not None else {}

    def __repr__(self):
        return f"<{type(self).__name__}>"

    @job_getter
    async def info(self, job):
        """
        There is no templating info for an `InProcessBlobRepository`.
        """
        return None

    @job_getter
    async def open(self, job, mode="r"):
        if not self.is_valid_job_id(job):
            raise KeyError(job)
        stream = InProcessBlobStream(self, job)
        if mode == "r":
            return AReadText(stream)
        elif mode == "w":
            return AWriteText(stream)
        else:
            return stream

    async def unfiltered_iter(self):
        for item in self.data:
            yield item

    async def contains(self, item):
        return item in self.data

    async def delete(self, job):
        # The Repository contract requires delete to succeed even when the job is
        # not present; ``del`` would raise KeyError here, so pop with a default.
        self.data.pop(job, None)