"""
TaskLauncher is the core module of the system,
which actually executes the tasks.
"""
import concurrent
import time
from concurrent.futures import Future
from datetime import datetime
from threading import Thread
from traceback import print_exc
from typing import Any, cast
import dill
from bson import ObjectId
from pydantic import BaseModel, model_validator
from pydantic.root_model import RootModel
from alab_management.device_view.device import BaseDevice
from alab_management.device_view.device_view import DeviceView
from alab_management.resource_manager.enums import _EXTRA_REQUEST, RequestStatus
from alab_management.sample_view.sample import SamplePosition
from alab_management.sample_view.sample_view import SamplePositionRequest
from alab_management.task_view import TaskPriority
from alab_management.utils.data_objects import DocumentNotUpdatedError, get_collection
# Maps a sample-position prefix to the number of such positions requested.
_SampleRequestDict = dict[str, int]

# The raw request sent by the task process. Keys identify the device the
# positions belong to: a ``BaseDevice`` subclass (request by type), a ``str``
# (request by device name), or ``None`` (positions not tied to any device).
_ResourceRequestDict = dict[type[BaseDevice] | str | None, _SampleRequestDict]
class RequestCanceledError(Exception):
    """
    Signal that a resource request was canceled.

    It is set on the future of a waiting request once the request is observed
    in the canceled state.
    """
# From Python 3.11 on, ``concurrent.futures.TimeoutError`` and the builtin
# ``TimeoutError`` are the same class, so the base class of our timeout
# exception has to be decided at import time.
if concurrent.futures.TimeoutError is not TimeoutError:

    class CombinedTimeoutError(
        TimeoutError, concurrent.futures.TimeoutError
    ):  # pylint: disable=duplicate-bases
        """
        Timeout error deriving from both timeout exception classes.

        An ``except`` clause for either ``TimeoutError`` or
        ``concurrent.futures.TimeoutError`` will catch this exception.
        """

else:
    # The two classes are identical; listing the same base twice is illegal,
    # so simply alias the builtin.
    CombinedTimeoutError = TimeoutError
class DeviceRequest(BaseModel):
    """Pydantic model describing which device a resource request refers to."""

    # How ``content`` should be interpreted. ``ResourcesRequest.preprocess``
    # only accepts "name", "type", or the ``_EXTRA_REQUEST`` marker here.
    identifier: str
    # The value matching ``identifier``: ``ResourceRequester`` fills in the
    # device name, the device class ``__name__``, or ``_EXTRA_REQUEST``.
    content: str
class ResourceRequestItem(BaseModel):
    """Pydantic model for one entry of a resource request: a device plus its positions."""

    # The device this entry is about.
    device: DeviceRequest
    # The sample positions requested together with that device.
    sample_positions: list[SamplePositionRequest]
class ResourcesRequest(RootModel):
    """
    This class is used to validate the resource request. Each request should have a format like this.

    .. code-block::

        [
            {
                "device": {
                    "identifier": "name" or "type" or "nodevice",
                    "content": string corresponding to identifier
                },
                "sample_positions": [
                    {
                        "prefix": prefix of sample position,
                        "number": integer number of such positions requested.
                    },
                    ...
                ]
            },
            ...
        ].

    See Also
    --------
    :py:class:`SamplePositionRequest <alab_management.sample_view.sample_view.SamplePositionRequest>`
    """

    root: list[ResourceRequestItem]

    @model_validator(mode="before")
    @classmethod
    def preprocess(cls, values):
        """
        Check the device identifiers and normalize every request entry.

        Runs before field validation. Each entry is rebuilt so that only the
        ``device.identifier``, ``device.content`` and ``sample_positions``
        keys are forwarded to the model; any other keys are dropped.

        Args:
            values: iterable of raw request dicts in the format shown in the
                class docstring.

        Returns
        -------
            A new list of normalized request dicts.

        Raises
        ------
            ValueError: if a device identifier is not ``"name"``, ``"type"``
                or ``_EXTRA_REQUEST``.
            KeyError: if a required key is missing from an entry.
        """
        allowed_identifiers = ("name", "type", _EXTRA_REQUEST)
        new_values = []
        for request_dict in values:
            device = request_dict["device"]
            if device["identifier"] not in allowed_identifiers:
                raise ValueError(
                    f"device identifier must be one of 'name', 'type', or {_EXTRA_REQUEST}"
                )
            new_values.append(
                {
                    "device": {
                        "identifier": device["identifier"],
                        "content": device["content"],
                    },
                    "sample_positions": request_dict["sample_positions"],
                }
            )
        return new_values
class RequestMixin:
    """Simple wrapper for the request collection."""

    def __init__(self):
        self._request_collection = get_collection("requests")

    def update_request_status(
        self,
        request_id: ObjectId,
        status: RequestStatus,
        original_status: RequestStatus | list[RequestStatus] | None = None,
    ):
        """
        Update the status of a request by request_id.

        Args:
            request_id: ``_id`` of the request document.
            status: the status to write (stored by its ``name``).
            original_status: optional precondition. When given, the update is
                only applied if the request currently has this status (or, for
                a list, any of these statuses).

        Returns
        -------
            The result object returned by ``update_one``.

        Raises
        ------
            DocumentNotUpdatedError: if no document was modified.
        """
        query: dict[str, Any] = {"_id": request_id}
        if isinstance(original_status, list):
            query["status"] = {"$in": [item.name for item in original_status]}
        elif original_status is not None:
            query["status"] = original_status.name

        value_returned = self._request_collection.update_one(
            query, {"$set": {"status": status.name}}
        )

        if value_returned.modified_count == 0:
            # Build the explanation without assuming ``original_status`` is a
            # single enum member: it may also be None or a list.
            if original_status is None:
                reason = (
                    "because no matching request was modified "
                    "(it may not exist or may already have that status)."
                )
            elif isinstance(original_status, list):
                expected = ", ".join(item.name for item in original_status)
                reason = f"because it is not in one of these statuses: {expected}."
            else:
                reason = f"because it is not in {original_status.name} status."
            raise DocumentNotUpdatedError(
                f"Request {request_id} was not updated to {status.name}, {reason}"
            )
        return value_returned

    def get_request(self, request_id: ObjectId, **kwargs) -> dict[str, Any] | None:
        """
        Get a request by request_id.

        Extra keyword arguments (e.g. ``projection``) are forwarded to
        ``find_one``. Returns ``None`` when no such request exists.
        """
        return self._request_collection.find_one(
            {"_id": request_id}, **kwargs
        )  # DB_ACCESS_OUTSIDE_VIEW

    def get_requests_by_status(self, status: RequestStatus):
        """Get all requests whose stored status equals ``status.name``."""
        return self._request_collection.find(
            {"status": status.name}
        )  # DB_ACCESS_OUTSIDE_VIEW

    def get_requests_by_task_id(self, task_id: ObjectId):
        """Get all requests by task_id."""
        return self._request_collection.find({"task_id": task_id})
class ResourceRequester(RequestMixin):
    """
    Class for request lab resources easily. This class will insert a request into the database,
    and then the task manager will read from the database and assign the resources.

    A background daemon thread polls the status of every request this object
    is waiting on and resolves the matching future once the request becomes
    fulfilled, errored or canceled.

    It is used in :py:class:`~alab_management.lab_view.LabView`.
    """

    def __init__(
        self,
        task_id: ObjectId,
    ):
        # RequestMixin.__init__ binds ``self._request_collection``.
        super().__init__()
        # request _id -> {"f": Future, "device_str_to_request": {...}}
        self._waiting: dict[ObjectId, dict[str, Any]] = {}
        self.task_id = task_id
        self.device_view = DeviceView()
        self.priority: int | TaskPriority = (
            TaskPriority.NORMAL
        )  # will usually be overwritten by BaseTask instantiation.

        self._stop = False
        self._thread = Thread(
            target=self._check_request_status_loop, name="CheckRequestStatus"
        )
        self._thread.daemon = True
        self._thread.start()

    def __close__(self):
        """Ask the polling thread to stop and wait for it to exit."""
        self._stop = True
        # ``__del__`` may run on a partially initialised object (``__init__``
        # raised before the thread was created), so do not assume the attribute.
        thread = getattr(self, "_thread", None)
        if thread is None:
            return
        try:
            thread.join()
        except RuntimeError:
            # ``Thread.join`` raises RuntimeError when called on the current
            # thread or on a thread that was never started; in both cases
            # there is nothing to wait for.
            pass

    __del__ = __close__

    def request_resources(
        self,
        resource_request: _ResourceRequestDict,
        timeout: float | None = None,
        priority: TaskPriority | int | None = None,
    ) -> dict[str, Any]:
        """
        Request lab resources.

        Write the request into the database, and then the task manager will read from the
        database and assign the resources.

        Args:
            resource_request: mapping of device (a ``BaseDevice`` subclass, a
                device name, or ``None``) to ``{position prefix: number}``.
            timeout: seconds to wait for the request to be resolved; ``None``
                waits indefinitely.
            priority: priority stored with the request; defaults to
                ``self.priority``.

        Returns
        -------
            A dict with the keys ``"devices"``, ``"sample_positions"`` and
            ``"request_id"``.

        Raises
        ------
            ValueError: if a key of ``resource_request`` is not a device name,
                a ``BaseDevice`` subclass or ``None``.
            CombinedTimeoutError: if the request was not fulfilled in time (the
                request is marked CANCELED before raising).
        """
        if priority is None:
            priority = self.priority

        formatted_resource_request, device_str_to_request = (
            self._format_resource_request(resource_request)
        )

        f: Future = Future()
        insert_result = self._request_collection.insert_one(
            {
                "request": formatted_resource_request,
                "status": RequestStatus.PENDING.name,
                "task_id": self.task_id,
                "priority": int(priority),
                "submitted_at": datetime.now(),
            }
        )  # DB_ACCESS_OUTSIDE_VIEW
        _id: ObjectId = cast(ObjectId, insert_result.inserted_id)
        # From now on the polling thread watches this request.
        self._waiting[_id] = {"f": f, "device_str_to_request": device_str_to_request}

        try:
            result = self.get_concurrent_result(f, timeout=timeout)
        except concurrent.futures.TimeoutError as e:
            # if the request is not fulfilled, cancel it to make sure the resources are released
            request = self._request_collection.find_one_and_update(
                {"_id": _id, "status": {"$ne": RequestStatus.FULFILLED.name}},
                {"$set": {"status": RequestStatus.CANCELED.name}},
            )
            if request is not None:
                raise CombinedTimeoutError(
                    f"Request {_id} timed out after {timeout} seconds."
                ) from e
            # The request became fulfilled in the meantime (wrong timeout):
            # wait for the polling thread to deliver the result normally.
            result = self.get_concurrent_result(f)

        return {
            **self._post_process_requested_resource(
                devices=result["devices"],
                sample_positions=result["sample_positions"],
                resource_request=resource_request,
            ),
            "request_id": result["request_id"],
        }

    @staticmethod
    def _format_resource_request(
        resource_request: _ResourceRequestDict,
    ) -> tuple[list[dict[str, Any]], dict[str, type[BaseDevice] | str | None]]:
        """Convert a raw request into its validated database form plus a reverse device lookup."""
        formatted_resource_request = []
        # Maps the string stored in the database back to the key the caller used.
        device_str_to_request: dict[str, type[BaseDevice] | str | None] = {}

        for device, position_dict in resource_request.items():
            if device is None:
                identifier = _EXTRA_REQUEST
                content = _EXTRA_REQUEST
            elif isinstance(device, str):
                identifier = "name"
                content = device
            # ``issubclass`` raises TypeError for non-class arguments, so make
            # sure we have a class first and fall through to the ValueError.
            elif isinstance(device, type) and issubclass(device, BaseDevice):
                identifier = "type"
                content = device.__name__
            else:
                raise ValueError(
                    "device must be a name of a specific device, a class of type BaseDevice, or None"
                )
            device_str_to_request[content] = device

            positions = [
                dict(SamplePositionRequest(prefix=prefix, number=number))
                for prefix, number in position_dict.items()
            ]  # immediate dict conversion - SamplePositionRequest is only used to check request format.

            formatted_resource_request.append(
                {
                    "device": {
                        "identifier": identifier,
                        "content": content,
                    },
                    "sample_positions": positions,
                }
            )

        # Validate the whole request, then dump it to plain JSON-compatible data.
        validated = ResourcesRequest(root=formatted_resource_request)  # type: ignore
        return validated.model_dump(mode="json"), device_str_to_request

    @staticmethod
    def get_concurrent_result(f: Future, timeout: float | None = None):
        """
        Get the result of a future with a timeout.

        If the request is canceled, we will catch a RequestCanceledError and hang the program.
        The hanged program will be killed by the abort exception in the task actor, which will
        be handled in the task actor to clean up the lab.
        """
        try:
            return f.result(timeout=timeout)
        except RequestCanceledError:
            # if there is an abort signal, we will just hang the program
            while True:
                # abort signal here. It should be handled in the task actor
                time.sleep(1)

    def release_resources(self, request_id: ObjectId):
        """
        Release a request by request_id.

        Blocks until the request is in the RELEASED, CANCELED or ERROR status.

        Raises
        ------
            ValueError: if no request with ``request_id`` exists.
            DocumentNotUpdatedError: if the status changed between reading it
                and marking the request as NEED_RELEASE.
        """
        request = self.get_request(request_id)
        if request is None:
            raise ValueError(f"Request {request_id} does not exist.")

        # For the requests that were CANCELED or ERROR, but have assigned resources, release them
        if request["status"] in [RequestStatus.CANCELED.name, RequestStatus.ERROR.name]:
            if ("assigned_devices" not in request) and (
                "assigned_sample_positions" not in request
            ):
                # If it doesn't have assigned resources, just leave it as CANCELED or ERROR
                return
            self.update_request_status(
                request_id,
                RequestStatus.NEED_RELEASE,
                original_status=[RequestStatus.CANCELED, RequestStatus.ERROR],
            )
        # For the requests that were fulfilled, definitely have assigned resources, release them
        elif request["status"] == RequestStatus.FULFILLED.name:
            self.update_request_status(
                request_id,
                RequestStatus.NEED_RELEASE,
                original_status=RequestStatus.FULFILLED,
            )

        # wait for the request to be released or canceled or errored during the release
        while self.get_request(request_id, projection=["status"])["status"] not in [
            RequestStatus.RELEASED.name,
            RequestStatus.CANCELED.name,
            RequestStatus.ERROR.name,
        ]:
            time.sleep(0.5)

    def release_all_resources(self):
        """
        Release all requests by task_id, used for error recovery.

        For the requests that are not fulfilled, they will be marked as CANCELED.

        For the request that have been fulfilled, they will be marked as NEED_RELEASE.

        For the request that have been errored, release assigned resources.
        """
        # For the requests that are not fulfilled yet, mark them as CANCELED.
        # Without this step a PENDING request would keep the wait loop below
        # spinning, because nothing here would ever move it to a final status.
        self._request_collection.update_many(
            {
                "task_id": self.task_id,
                "status": RequestStatus.PENDING.name,
            },
            {
                "$set": {
                    "status": RequestStatus.CANCELED.name,
                }
            },
        )

        # For the requests that were fulfilled, definitely have assigned resources, release them
        self._request_collection.update_many(
            {
                "task_id": self.task_id,
                "status": RequestStatus.FULFILLED.name,
            },
            {
                "$set": {
                    "status": RequestStatus.NEED_RELEASE.name,
                }
            },
        )

        # For the requests that were CANCELED or ERROR, but have assigned resources, release them
        for request in self.get_requests_by_task_id(self.task_id):
            if request["status"] in [
                RequestStatus.CANCELED.name,
                RequestStatus.ERROR.name,
            ] and (
                ("assigned_devices" in request)
                or ("assigned_sample_positions" in request)
            ):
                self.update_request_status(
                    request["_id"],
                    RequestStatus.NEED_RELEASE,
                    original_status=[RequestStatus.CANCELED, RequestStatus.ERROR],
                )

        # wait for all the requests to be released or canceled or errored during the release
        while any(
            (
                request["status"]
                not in [
                    RequestStatus.RELEASED.name,
                    RequestStatus.CANCELED.name,
                    RequestStatus.ERROR.name,
                ]
            )
            for request in self.get_requests_by_task_id(self.task_id)
        ):
            time.sleep(0.5)

    def _check_request_status_loop(self):
        """
        Poll the status of every waiting request until ``self._stop`` is set.

        Runs in the background thread started by ``__init__``. Any exception is
        printed and re-raised, which terminates the thread.
        """
        while not self._stop:
            try:
                # Iterate over a snapshot: the handlers pop from ``_waiting``.
                for request_id in self._waiting.copy():
                    status = self.get_request(request_id=request_id, projection=["status"])["status"]  # type: ignore
                    if status == RequestStatus.FULFILLED.name:
                        self._handle_fulfilled_request(request_id=request_id)
                    elif status == RequestStatus.ERROR.name:
                        self._handle_error_request(request_id=request_id)
                    elif status == RequestStatus.CANCELED.name:
                        self._handle_canceled_request(request_id=request_id)
            except Exception:
                print_exc()  # for debugging in the test
                raise
            time.sleep(0.5)

    def _handle_fulfilled_request(self, request_id: ObjectId):
        """Resolve the waiting future of a fulfilled request with its assigned resources."""
        entry = self.get_request(request_id)
        # Re-check on the full document: the status may have changed since the poll.
        if entry["status"] != RequestStatus.FULFILLED.name:  # type: ignore
            return

        assigned_devices: dict[str, dict[str, str | bool]] = entry["assigned_devices"]  # type: ignore
        assigned_sample_positions: dict[str, list[dict[str, Any]]] = entry["assigned_sample_positions"]  # type: ignore

        request: dict[str, Any] = self._waiting.pop(request_id)
        f: Future = request["f"]
        device_str_to_request: dict[str, type[BaseDevice] | str | None] = request[
            "device_str_to_request"
        ]

        f.set_result(
            {
                # Key the devices by what the caller originally asked for.
                "devices": {
                    device_str_to_request[device_str]: device_dict["name"]
                    for device_str, device_dict in assigned_devices.items()
                },
                "sample_positions": {
                    name: [
                        sample_position["name"]
                        for sample_position in sample_positions_list
                    ]
                    for name, sample_positions_list in assigned_sample_positions.items()
                },
                "request_id": request_id,
            }
        )

    def _handle_error_request(self, request_id: ObjectId):
        """Fail the waiting future of an errored request with the stored exception."""
        entry = self.get_request(request_id)
        if entry["status"] != RequestStatus.ERROR.name:  # type: ignore
            return

        # NOTE(review): ``dill.loads`` executes arbitrary code during
        # deserialization; this is only safe as long as everything that can
        # write to the requests collection is trusted.
        error: Exception = dill.loads(entry["error"])  # type: ignore
        request: dict[str, Any] = self._waiting.pop(request_id)
        f: Future = request["f"]
        f.set_exception(error)

    def _handle_canceled_request(self, request_id: ObjectId):
        """Fail the waiting future of a canceled request with ``RequestCanceledError``."""
        entry = self.get_request(request_id)
        if entry["status"] != RequestStatus.CANCELED.name:  # type: ignore
            return

        request: dict[str, Any] = self._waiting.pop(request_id)
        f: Future = request["f"]
        # for the canceled request, we will return an empty result
        # and wait for the abort to be handled by the task actor
        f.set_exception(RequestCanceledError("Abort signal received."))

    @staticmethod
    def _post_process_requested_resource(
        devices: dict[type[BaseDevice] | str, str],
        sample_positions: dict[str, list[str]],
        resource_request: dict[str | type[BaseDevice] | None, dict[str, int]],
    ) -> dict[str, Any]:
        """
        Regroup the assigned sample positions by the device keys of the original request.

        Args:
            devices: requested device key -> name of the assigned device.
            sample_positions: position prefix as stored in the reply -> names
                of the assigned positions.
            resource_request: the raw request passed to ``request_resources``.

        Returns
        -------
            ``{"devices": devices, "sample_positions": {device key: {requested
            prefix: [position names]}}}``. Device keys that requested no
            positions are omitted from ``"sample_positions"``.
        """
        processed_sample_positions: dict[
            type[BaseDevice] | str | None, dict[str, list[str]]
        ] = {}

        for device_request, sample_position_dict in resource_request.items():
            if not sample_position_dict:
                continue

            if device_request is None:  # no device name to prepend
                device_prefix = None
            else:
                device_prefix = f"{devices[device_request]}{SamplePosition.SEPARATOR}"

            positions_for_device: dict[str, list[str]] = {}
            for prefix in sample_position_dict:
                reply_prefix = prefix
                # dont extra prepend for nested requests
                if device_prefix is not None and not prefix.startswith(device_prefix):
                    reply_prefix = device_prefix + prefix
                positions_for_device[prefix] = sample_positions[reply_prefix]
            processed_sample_positions[device_request] = positions_for_device

        return {
            "devices": devices,
            "sample_positions": processed_sample_positions,
        }