Source code for pydatatask.proc_manager

"""
A process manager can be used with a ProcessTask to specify where and how the processes should run.
All a process manager needs to specify is how to launch a process and manipulate the filesystem, and the ProcessTask
will set up an appropriate environment for running the task and retrieve the results using this interface.
"""

from typing import (
    TYPE_CHECKING,
    AsyncContextManager,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Set,
    Union,
)
from abc import abstractmethod
from pathlib import Path
import asyncio
import os
import shlex
import signal

import aiofiles
import aioshutil
import asyncssh
import psutil

from .task import STDOUT, _StderrIsStdout

if TYPE_CHECKING:
    from .repository import AReadStream, AReadText, AWriteStream, AWriteText

__all__ = ("AbstractProcessManager", "LocalLinuxManager", "SSHLinuxManager")


[docs]class AbstractProcessManager: """ The base class for process managers. Processes are managed through arbitrary "process identifier" handle strings. """
[docs] @abstractmethod async def get_live_pids(self, hint: Set[str]) -> Set[str]: """ Get a set of the live process identifiers. This may include processes which are not part of the current app/task, but MUST include all the processes in ``hint`` which are still alive. """ raise NotImplementedError
[docs] @abstractmethod async def spawn( self, args: List[str], environ: Dict[str, str], cwd: str, return_code: str, stdin: Optional[str], stdout: Optional[str], stderr: Optional[Union[str, _StderrIsStdout]], ) -> str: """ Launch a process on the target system. This function MUST NOT wait until the process has terminated before returning. :param args: The command line of the process to launch. ``args[0]`` is the executable to run. :param environ: A set of environment variables to add to the target process' environment. :param cwd: The directory to launch the process in. :param return_code: A filepath in which to store the process exit code, as an ascii integer. :param stdin: A filepath from which to read the process' stdin, or None for a null file. :param stdout: A filepath to which to write the process' stdout, or None for a null file. :param stderr: A filepath to which to write the process' stderr, None for a null file, or the constant `pydatatask.task.STDOUT` to interleave it into the stdout stream. :return: The process identifier of the process spawned. """ raise NotImplementedError
[docs] @abstractmethod async def kill(self, pid: str): """ Terminate the process with the given identifier. This should do nothing if the process is not currently running. This should not prevent e.g. the ``return_code`` file from being populated. """ raise NotImplementedError
@property @abstractmethod def basedir(self) -> Path: """ A path on the target system to a directory which can be freely manipulated by the app. """ raise NotImplementedError
[docs] @abstractmethod async def open( self, path: Path, mode: Literal["r", "rb", "w", "wb"] ) -> Union["AReadText", "AWriteText", "AReadStream", "AWriteStream"]: """ Open the file on the target system for reading or writing according to the given mode. """ raise NotImplementedError
[docs] @abstractmethod async def mkdir(self, path: Path): # exist_ok=True, parents=True """ Create a directory with the given path on the target system. This should have the semantics of ``mkdir -p``, i.e. create any necessary parent directories and also succeed if the directory already exists. """ raise NotImplementedError
[docs] @abstractmethod async def rmtree(self, path: Path): """ Remove a directory and all its children with the given path on the target system. """ raise NotImplementedError
[docs]class LocalLinuxManager(AbstractProcessManager): """ A process manager to run tasks on the local linux machine. By default, it will create a directory ``/tmp/pydatatask`` in which to store data. """ def __init__(self, app: str, local_path: Union[Path, str] = "/tmp/pydatatask"): self.local_path = Path(local_path) / app @property def basedir(self) -> Path: return self.local_path
[docs] async def get_live_pids(self, hint: Set[str]) -> Set[str]: return {str(x) for x in psutil.pids()}
[docs] async def spawn(self, args, environ, cwd, return_code, stdin, stdout, stderr): if stdin is None: stdin = "/dev/null" else: stdin = shlex.quote(stdin) if stdout is None: stdout = "/dev/null" else: stdout = shlex.quote(stdout) if stderr is None: stderr = "/dev/null" elif stderr is STDOUT: stderr = "&1" else: stderr = shlex.quote(stderr) if "/" in args[0]: os.chmod(args[0], 0o755) return_code = shlex.quote(return_code) p = await asyncio.create_subprocess_shell( f""" {shlex.join(args)} <{stdin} >{stdout} 2>{stderr} & PID=$! echo $PID wait $PID >/dev/null 2>/dev/null echo $? >{return_code} 2>/dev/null """, stdout=asyncio.subprocess.PIPE, close_fds=True, cwd=cwd, env=environ, ) pid = await p.stdout.readline() return pid.strip()
[docs] async def kill(self, pid: str): os.kill(int(pid), signal.SIGKILL)
[docs] async def mkdir(self, path: Path): path.mkdir(parents=True, exist_ok=True)
[docs] async def rmtree(self, path: Path): await aioshutil.rmtree(path)
[docs] async def open(self, path, mode): return aiofiles.open(path, mode)
class SSHLinuxFile: """ A file returned by `SSHLinuxManager.open`. Probably don't instantiate this directly. """ def __init__(self, path: Union[Path, str], mode: Literal["r", "w", "rb", "wb"], ssh: asyncssh.SSHClientConnection): self.path = Path(path) self.mode = mode self.ssh = ssh self.sftp_mgr: Optional[AsyncContextManager] = None self.sftp: Optional[asyncssh.SFTPClient] = None self.fp_mgr: Optional[AsyncContextManager] = None self.fp: Optional[asyncssh.SFTPClientFile] = None async def __aenter__(self) -> asyncssh.SFTPClientFile: self.sftp_mgr = self.ssh.start_sftp_client() self.sftp = await self.sftp_mgr.__aenter__() self.fp_mgr = self.sftp.open(self.path, self.mode) self.fp = await self.fp_mgr.__aenter__() return self.fp async def __aexit__(self, exc_type, exc_val, exc_tb): await self.fp_mgr.__aexit__(exc_type, exc_val, exc_tb) await self.sftp_mgr.__aexit__(exc_type, exc_val, exc_tb)
[docs]class SSHLinuxManager(AbstractProcessManager): """ A process manager that runs its processes on a remote linux machine, accessed over SSH. The SSH connection is parameterized by an :external:class:`asyncssh.connection.SSHClientConnection` instance. Sample usage: .. code:: python session = pydatatask.Session() @session.resource async def ssh(): async with asyncssh.connect( "localhost", port=self.port, username="weh", password="weh", known_hosts=None ) as s: yield s @session.resource async def procman(): yield pydatatask.SSHLinuxManager(self.test_id, ssh) task = pydatatask.ProcessTask("mytask", procman, ...) """ def __init__( self, app: str, ssh: Callable[[], asyncssh.SSHClientConnection], remote_path: Union[Path, str] = "/tmp/pydatatask", ): self.remote_path = Path(remote_path) / app self._ssh = ssh @property def ssh(self) -> asyncssh.SSHClientConnection: """ The `asyncssh.SSHClientConnection` instance associated. Will fail if the connection is provided by an unopened Session. """ return self._ssh() @property def basedir(self) -> Path: return self.remote_path
[docs] async def open(self, path, mode): return SSHLinuxFile(path, mode, self.ssh)
[docs] async def rmtree(self, path: Path): async with self.ssh.start_sftp_client() as sftp: await sftp.rmtree(path, ignore_errors=True)
[docs] async def mkdir(self, path: Path): async with self.ssh.start_sftp_client() as sftp: await sftp.makedirs(path, exist_ok=True)
[docs] async def kill(self, pid: str): await self.ssh.run(f"kill -9 {pid}")
[docs] async def get_live_pids(self, hint): p = await self.ssh.run("ls /proc") return {x for x in p.stdout.split() if x.isdigit()}
[docs] async def spawn(self, args, environ, cwd, return_code, stdin, stdout, stderr): if stdin is None: stdin = "/dev/null" else: stdin = shlex.quote(stdin) if stdout is None: stdout = "/dev/null" else: stdout = shlex.quote(stdout) if stderr is None: stderr = "/dev/null" elif stderr is STDOUT: stderr = "&1" else: stderr = shlex.quote(stderr) if "/" in args[0]: async with self.ssh.start_sftp_client() as sftp: await sftp.chmod(args[0], 0o755) p = await self.ssh.create_process( f""" cd {shlex.quote(cwd)} {shlex.join(args)} <{stdin} >{stdout} 2>{stderr} & PID=$! echo $PID wait $PID >/dev/null 2>/dev/null echo $? >{shlex.quote(return_code)} 2>/dev/null """, env=environ, ) pid = await p.stdout.readline() return pid.strip()