Source code for alab_management.task_view.task

"""Define the base class of task, which will be used for defining more tasks."""

import inspect
import time
from abc import ABC, abstractmethod
from inspect import getfullargspec
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional

import gridfs
from bson.objectid import ObjectId
from pydantic import BaseModel, ConfigDict, Field, model_validator

from alab_management.builders.samplebuilder import SampleBuilder
from alab_management.config import AlabOSConfig
from alab_management.task_view.task_enums import TaskPriority
from alab_management.utils.data_objects import get_db

if TYPE_CHECKING:
    from alab_management.builders.experimentbuilder import ExperimentBuilder
    from alab_management.device_view.device import BaseDevice
    from alab_management.lab_view import LabView


# module-level sentinel used to detect "storage_type not provided" in the
# LargeResult constructors below
_UNSET = object()


class LargeResult(BaseModel):
    """
    A Pydantic model for a large result (a file larger than 16 MB).

    Stored in either GridFS or other file systems (e.g., cloud storage such as AWS S3).
    """

    # read the config lazily to avoid evaluating AlabOSConfig at module import time
    storage_type: str = Field(
        default_factory=lambda: AlabOSConfig()["large_result_storage"][
            "default_storage_type"
        ]
    )
    # the path to the local file, used for uploading
    local_path: str | Path | None = None
    # the identifier of the file in the storage system; can be a path or a key
    # (e.g., a gridfs id). Obtained after storing the file, used for retrieving.
    identifier: str | ObjectId | None = None
    # alternative to local_path, used for uploading; local_path has higher priority
    file_like_data: Any | None = None
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # if file_like_data is provided, check that it has a .read() method
    @model_validator(mode="before")
    def check_file_like_data(cls, values):
        """Check if file_like_data has a .read() method."""
        file_like_data = values.get("file_like_data")
        if file_like_data is not None and not hasattr(file_like_data, "read"):
            raise ValueError("file_like_data must have a .read() method")
        return values

    @classmethod
    def from_local_file(cls, local_path: str | Path, storage_type: str = _UNSET):
        """
        Create a LargeResult object from a local file and store it.

        Raises a ValueError if the file fails to be stored.

        Args:
            local_path: the path to the local file
            storage_type: the storage type; defaults to the default storage type in the config

        Returns
        -------
        LargeResult: the LargeResult object
        """
        if storage_type is _UNSET:
            storage_type = AlabOSConfig()["large_result_storage"][
                "default_storage_type"
            ]
        large_file = cls(local_path=local_path, storage_type=storage_type)
        large_file.store()
        return large_file

    @classmethod
    def from_file_like_data(cls, file_like_data: Any, storage_type: str = _UNSET):
        """
        Create a LargeResult object from a file-like object and store it.

        The file-like object must have a .read() method. Raises a ValueError if the
        file fails to be stored.

        Args:
            file_like_data: the file-like data
            storage_type: the storage type; defaults to the default storage type in the config

        Returns
        -------
        LargeResult: the LargeResult object
        """
        if storage_type is _UNSET:
            storage_type = AlabOSConfig()["large_result_storage"][
                "default_storage_type"
            ]
        large_result = cls(file_like_data=file_like_data, storage_type=storage_type)
        large_result.store()
        return large_result

    def store(self):
        """
        Store the large result in the storage system.

        This method blocks until the result is confirmed to be stored. It has a
        timeout regardless of the storage system so that it does not block
        indefinitely.
        """
        if self.storage_type == "gridfs":
            db = get_db()
            fs = gridfs.GridFS(db)
            if self.local_path:
                with open(self.local_path, "rb") as file:
                    file_id = fs.put(file)
            elif self.file_like_data:
                file_id = fs.put(self.file_like_data)
            else:
                raise ValueError(
                    "Either local_path or file_like_data must be provided for storing in gridfs."
                )
            self.identifier = file_id
            # check that the file is stored, waiting for at most 10 seconds
            for _ in range(10):
                if fs.exists(file_id):
                    return
                time.sleep(1)
            raise ValueError(f"File with identifier {file_id} failed to be stored.")
        else:
            raise ValueError("Only gridfs storage is supported for now.")

    def retrieve(self):
        """Retrieve the large result from the storage system."""
        if self.storage_type == "gridfs":
            if self.identifier is None:
                raise ValueError(
                    "Identifier is not provided for retrieving from gridfs."
                )
            db = get_db()
            fs = gridfs.GridFS(db)
            # fs.get raises gridfs.errors.NoFile rather than returning None,
            # so check existence explicitly before reading
            if not fs.exists(self.identifier):
                raise ValueError(
                    f"File with identifier {self.identifier} does not exist."
                )
            return fs.get(self.identifier).read()
        else:
            raise ValueError("Only gridfs storage is supported for now.")

    def check_if_stored(self):
        """Check if the large result is stored in the storage system."""
        if self.storage_type == "gridfs":
            if self.identifier is None:
                return False
            db = get_db()
            fs = gridfs.GridFS(db)
            return fs.exists(self.identifier)
        else:
            raise ValueError("Only gridfs storage is supported for now.")
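
# Illustrative usage sketch (not part of the original module): the LargeResult
# round trip through GridFS, assuming a configured alabos database. The file
# path and function name are hypothetical.
def _example_large_result_round_trip():
    result = LargeResult.from_local_file("/tmp/diffraction_pattern.tiff")
    assert result.check_if_stored()
    # retrieve() returns the stored bytes from GridFS
    return result.retrieve()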

class BaseTask(ABC):
    """
    The abstract class of task.

    All tasks should inherit from this class.
    """

    def __init__(
        self,
        samples: list[str | ObjectId] | None = None,
        task_id: ObjectId | None = None,
        lab_view: Optional["LabView"] = None,
        priority: TaskPriority | int | None = TaskPriority.NORMAL,
        _offline_mode: bool = True,
        *args,
        **kwargs,
    ):
        """
        Args:
            task_id: the identifier of the task
            lab_view: a lab_view corresponding to the task_id
            samples: a list of sample_ids corresponding to samples involved in the task
            _offline_mode: whether the task is run in offline mode. A task is in
                offline mode when you are building an experiment out of it or
                retrieving the task result.

        Here is an example of how to define a custom task:

        .. code-block:: python

          def __init__(self, sample_1: ObjectId, sample_2: Optional[ObjectId],
                       sample_3: Optional[ObjectId], sample_4: Optional[ObjectId],
                       setpoints: List[Tuple[float, float]], *args, **kwargs):
              super(Heating, self).__init__(*args, **kwargs)
              self.setpoints = setpoints
              self.samples = [sample_1, sample_2, sample_3, sample_4]
        """
        self.__offline = _offline_mode
        # whether the task_id was generated here with ObjectId() or passed in
        self._is_taskid_generated = False
        self.__samples = samples or []
        if self.is_offline:
            if task_id is None:
                # if task_id is not provided, generate one
                self._is_taskid_generated = True
                task_id = ObjectId()
            self.task_id = task_id
            current_frame = inspect.currentframe()
            outer_frames = inspect.getouterframes(current_frame)
            subclass_init_frame = outer_frames[1].frame
            self.subclass_kwargs = {
                key: val
                for key, val in inspect.getargvalues(
                    subclass_init_frame
                ).locals.items()
                if key not in ["self", "args", "kwargs", "__class__"]
            }
        else:
            if (task_id is None) or (lab_view is None) or (samples is None):
                raise ValueError(
                    "BaseTask was instantiated with offline mode off -- task_id, "
                    "lab_view, and samples must all be provided!"
                )
            self.task_id = task_id
            self.lab_view = lab_view
            self.logger = self.lab_view.logger
            self.priority = priority
            self.lab_view.priority = priority

    @property
    def is_offline(self) -> bool:
        """Returns True if this task is offline, False if it is a live task."""
        return self.__offline

    @property
    def samples(self) -> list[str]:
        """Returns the list of samples associated with this task."""
        return self.__samples

    @property
    def priority(self) -> int:
        """Returns the priority of this task."""
        if self.is_offline:
            return 0
        return self.lab_view._resource_requester.priority

    @property
    def result_specification(self) -> type[BaseModel] | None:
        """
        Returns a Pydantic model describing the results to be generated by this task.

        If specified, this model will be used by the task actor to validate the
        results after the task is completed. If any error occurs, a warning will be
        printed. If the result contains a LargeResult, it will be ensured to be
        stored in the database.

        Returns
        -------
        BaseModel: a Pydantic model type describing the results to be generated by this task.
        """
        return None

    @priority.setter
    def priority(self, value: int | TaskPriority):
        if value < 0:
            raise ValueError("Priority should be a positive integer")
        if not self.__offline:
            self.lab_view._resource_requester.priority = int(value)

    def set_message(self, message: str):
        """Sets the task message to be displayed on the dashboard."""
        self._message = message
        if not self.__offline:
            self.lab_view._task_view.set_message(task_id=self.task_id, message=message)

    def get_message(self):
        """Gets the task message to be displayed on the dashboard."""
        return self._message

    def validate(self) -> bool:
        """
        Validate the task. This function will be called before the task is executed.

        Should return False if the task has values that make it impossible to
        execute. For example, a ``Heating`` subclass of ``BaseTask`` might return
        False if the set temperature is too high for the furnace.

        By default, this function returns True unless it is overridden by a subclass.
        """
        return True

    @abstractmethod
    def run(self):
        """
        Run the task. In this function, you can request lab resources from the lab
        manager and log data to the database with the logger.

        ``request_resources`` will not return until all the requested resources are
        available, so the task will be pending until it gets the requested
        resources, which prevents conflicts in resource allocation. When a task is
        granted the requested devices and sample positions, it also takes ownership
        of these resources, i.e., other tasks cannot use the devices or request the
        sample positions this task holds. We use a context manager to manage the
        ownership of the resources: when a task is completed, all the devices and
        sample positions are released automatically.

        Here is an example of how to define the task:

        .. code-block:: python

          # Request devices and sample positions from the lab manager. The `$`
          # stands for the name of the assigned device in the sample positions
          # we try to request; 4 is the number of sample positions.
          with self.lab_view.request_resources({Furnace: [("$.inside", 4)]}) as devices_and_positions:
              devices, sample_positions = devices_and_positions
              furnace = devices[Furnace]
              inside_furnace = sample_positions[Furnace]["$.inside"]

              for sample in self.samples:
                  # in a task, we can call other tasks, which will share the same
                  # task id, requested devices, and sample positions.
                  moving_task = Moving(sample=sample,
                                       task_id=self.task_id,
                                       dest=inside_furnace[0],
                                       lab_view=self.lab_view,
                                       logger=self.logger)
                  moving_task.run()

              # send a command to the device
              furnace.run_program(self.setpoints)

              while furnace.is_running():
                  # log the device data, here the current temperature of the furnace
                  self.logger.log_device_signal({
                      "device": furnace.name,
                      "temperature": furnace.get_temperature(),
                  })
        """
        raise NotImplementedError(
            "The .run method must be implemented by the subclass of BaseTask."
        )

    def run_subtask(
        self,
        task: type["BaseTask"],
        samples: list[str] | str | None = None,
        **kwargs,
    ):
        """Run a subtask of this current task. Returns the result, if any, of the subtask."""
        samples = samples or self.samples
        if isinstance(samples, str):
            samples = [samples]
        return self.lab_view.run_subtask(task=task, samples=samples, **kwargs)

    @classmethod
    def from_kwargs(
        cls, samples: list[str | ObjectId], task_id: ObjectId, **subclass_kwargs
    ) -> "BaseTask":
        """
        Used to create a new task object from the provided arguments.

        This is used in the `add_to` and `ExperimentBuilder.add_task` methods to
        create a new task object and validate it before adding it to an experiment
        or sample builder.
        """
        task_obj = cls(
            samples=samples,
            task_id=task_id,
            _offline_mode=True,
            **subclass_kwargs,
        )
        return task_obj

    def add_to(
        self,
        samples: SampleBuilder | list[SampleBuilder],
    ):
        """Used to add this task to a SampleBuilder's task list during experiment construction.

        Args:
            samples (SampleBuilder | list[SampleBuilder]): one or more SampleBuilders
                which will have this task appended to their task lists.
        """
        if not self.__offline:
            raise RuntimeError(
                "Cannot add a live BaseTask instance to a SampleBuilder. BaseTask must be instantiated with "
                "`_offline_mode=True` to enable this method."
            )
        if isinstance(samples, SampleBuilder):
            samples = [samples]

        experiment: ExperimentBuilder = samples[0].experiment
        task_id = self.task_id
        task_obj = self.__class__.from_kwargs(
            samples=[sample.name for sample in samples],
            task_id=ObjectId(task_id),
            **self.subclass_kwargs,
        )
        if not task_obj.validate():
            raise ValueError(
                "Task input validation failed!"
                + (
                    f"\nError message: {task_obj.get_message()}"
                    if task_obj.get_message()
                    else ""
                )
            )
        experiment.add_task(
            task_id=str(task_id),
            task_name=self.__class__.__name__,
            task_kwargs=self.subclass_kwargs,
            samples=samples,
        )
        for sample in samples:
            sample.add_task(task_id=str(task_id))
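
# Illustrative sketch (not part of the original module): a minimal BaseTask
# subclass. The class name, the 1100 C limit, and the setpoint format are
# hypothetical. Offline, the __init__ arguments captured by BaseTask (here
# `setpoints`) become `subclass_kwargs`, which `add_to` forwards to
# `ExperimentBuilder.add_task`.
class _ExampleHeating(BaseTask):
    def __init__(self, setpoints: list[tuple[float, float]], *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.setpoints = setpoints

    def validate(self) -> bool:
        # reject programs hotter than a hypothetical furnace limit of 1100 C
        return all(temperature <= 1100 for temperature, _ in self.setpoints)

    def run(self):
        # a live task would call self.lab_view.request_resources(...) here
        self.set_message(f"Heating with {len(self.setpoints)} setpoints")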

_task_registry: dict[str, type[BaseTask]] = {}

# maps a device type (or a position name, or None for "any device") to the
# sample position name(s) a reroute task supports
SUPPORTED_SAMPLE_POSITIONS_TYPE = dict[type["BaseDevice"] | str | None, str | list[str]]
_reroute_task_registry: list[
    dict[str, type[BaseTask] | SUPPORTED_SAMPLE_POSITIONS_TYPE]
] = []

def add_task(task: type[BaseTask]):
    """Register a task."""
    if task.__name__ in _task_registry:
        raise KeyError(f"Duplicated operation name {task.__name__}")
    _task_registry[task.__name__] = task

def get_all_tasks() -> dict[str, type[BaseTask]]:
    """Get all the tasks in the registry."""
    return _task_registry.copy()

def get_task_by_name(name: str) -> type[BaseTask]:
    """Get a task by name."""
    return _task_registry[name]

def add_reroute_task(
    supported_sample_positions: SUPPORTED_SAMPLE_POSITIONS_TYPE,
    task: type[BaseTask],
    **kwargs,
):
    """Register a reroute task."""
    if task.__name__ not in _task_registry:
        raise KeyError(
            f"Task {task.__name__} is not registered! Register with `add_task` before registering as a reroute task."
        )
    if "sample" not in getfullargspec(task).args:
        raise ValueError(
            f"Task {task.__name__} does not have `sample` as a parameter! "
            "Reroute tasks must accept a `sample` parameter that specifies the name or sample_id of the sample to be "
            "rerouted."
        )
    _reroute_task_registry.append(
        {
            "supported_sample_positions": supported_sample_positions,
            "task": task,
            "kwargs": kwargs,
        }
    )
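
# Illustrative sketch (not part of the original module): how tasks enter the
# registries. `_ExampleMoving` and its position names are hypothetical. A
# reroute task must already be registered via `add_task` and must accept a
# `sample` parameter naming the sample to be rerouted.
class _ExampleMoving(BaseTask):
    def __init__(self, sample: str, dest: str = "default_position", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sample = sample
        self.dest = dest

    def run(self):
        self.set_message(f"Moving {self.sample} to {self.dest}")


def _example_register_tasks():
    add_task(_ExampleMoving)
    # reroute samples away from the (hypothetical) "furnace_table" position;
    # the None key means the position is not tied to a specific device type
    add_reroute_task(
        supported_sample_positions={None: "furnace_table"},
        task=_ExampleMoving,
        dest="furnace_table",
    )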