diff --git a/requirements-celery.txt b/requirements-celery.txt index fe25e7b7a64ed761da189fa5fb2c3b48f930bb35..324358028bc1135cdfff9eb0a86648978770ac92 100644 --- a/requirements-celery.txt +++ b/requirements-celery.txt @@ -1,2 +1,2 @@ celery>=5.3.5 -alyx_connector>=0.0.10 \ No newline at end of file +alyx_connector>=2.1.5 \ No newline at end of file diff --git a/src/pypelines/__init__.py b/src/pypelines/__init__.py index c53d483d48294d5a4236327a108106d62a8ff82b..93e24e38b05bf69e957e2f5ec6229110d316af54 100644 --- a/src/pypelines/__init__.py +++ b/src/pypelines/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.44" +__version__ = "0.0.45" from . import loggs from .pipes import * diff --git a/src/pypelines/celery_tasks.py b/src/pypelines/celery_tasks.py index 4bcd1e55cfa47bd70dff3e840d836262be7aa422..bd5059ecb356e74fd9be6a2a9394b559ba9781bd 100644 --- a/src/pypelines/celery_tasks.py +++ b/src/pypelines/celery_tasks.py @@ -6,7 +6,6 @@ from traceback import format_exc as format_traceback_exc import logging import coloredlogs from logging import getLogger -from functools import wraps from platform import node from pandas import Series @@ -20,56 +19,6 @@ if TYPE_CHECKING: APPLICATIONS_STORE = {} -def get_runner(task_name: str): - from celery import Task - - class CeleryRunner(Task): - name = task_name - - def run(self, task_id, extra=None): - - task = CeleryTaskRecord(task_id) - - try: - session = task.get_session() - application = task.get_application() - arguments = task.arguments - - with LogTask(task) as log_object: - logger = log_object.logger - task["log"] = log_object.filename - task["status"] = "Started" - task.partial_update() - - try: - step: "BaseStep" = ( - application.pipelines[task.pipeline_name].pipes[task.pipe_name].steps[task.step_name] - ) - if arguments.get("refresh", False) or arguments.get("refresh_requirements", []): - skip = False - else: - skip = True - arguments.pop(skip) - - step.generate(session, extra=extra, skip=skip, check_requirements=True, 
**task.arguments) - task.status_from_logs(log_object) - except Exception as e: - traceback_msg = format_traceback_exc() - logger.critical(f"Fatal Error : {e}") - logger.critical("Traceback :\n" + traceback_msg) - task["status"] = "Failed" - - except Exception as e: - # if it fails outside of the nested try statement, we can't store logs files, - # and we mention the failure through alyx directly. - task["status"] = "Uncatched_Fail" - task["log"] = str(e) - - task.partial_update() - - return CeleryRunner - - class CeleryAlyxTaskManager(BaseStepTaskManager): backend: "CeleryTaskBackend" @@ -78,7 +27,7 @@ class CeleryAlyxTaskManager(BaseStepTaskManager): def register_step(self): if self.backend: # self.backend.app.task(CeleryRunner, name=self.step.complete_name) - self.backend.app.register_task(get_runner(self.step.complete_name)) + self.backend.app.register_task(self.get_runner()) def start(self, session, extra=None, **kwargs): @@ -89,6 +38,55 @@ class CeleryAlyxTaskManager(BaseStepTaskManager): return CeleryTaskRecord.create(self, session, extra, **kwargs) + def get_runner(superself): # type: ignore + from celery import Task + + class CeleryRunner(Task): + name = superself.step.complete_name + + def run(self, task_id, extra=None): + + task = CeleryTaskRecord(task_id) + + try: + session = task.get_session() + application = task.get_application() + arguments = task.arguments + + with LogTask(task) as log_object: + logger = log_object.logger + task["log"] = log_object.filename + task["status"] = "Started" + task.partial_update() + + try: + step: "BaseStep" = ( + application.pipelines[task.pipeline_name].pipes[task.pipe_name].steps[task.step_name] + ) + if arguments.get("refresh", False) or arguments.get("refresh_requirements", []): + skip = False + else: + skip = True + arguments.pop(skip) + + step.generate(session, extra=extra, skip=skip, check_requirements=True, **task.arguments) + task.status_from_logs(log_object) + except Exception as e: + traceback_msg = 
format_traceback_exc() + logger.critical(f"Fatal Error : {e}") + logger.critical("Traceback :\n" + traceback_msg) + task["status"] = "Failed" + + except Exception as e: + # if it fails outside of the nested try statement, we can't store logs files, + # and we mention the failure through alyx directly. + task["status"] = "Uncatched_Fail" + task["log"] = str(e) + + task.partial_update() + + return CeleryRunner + class CeleryTaskRecord(dict): session: Series @@ -196,6 +194,28 @@ class CeleryTaskRecord(dict): task_dict["id"], task_infos_dict=task_dict, response_handle=response_handle, session=session ) + @staticmethod + def create_from_task_name(app: "Celery", task_name: str, pipeline_name: str, session, extra=None, **kwargs): + from one import ONE + + connector = ONE(mode="remote", data_access_mode="remote") + + data = { + "session": session.name if isinstance(session, Series) else session, + "name": task_name, + "arguments": kwargs, + "status": "Waiting", + "executable": pipeline_name, + } + + task_dict = connector.alyx.rest("tasks", "create", data=data) + + response_handle = app.send_task(name=task_name, kwargs={"task_id": task_dict["id"], "extra": extra}) + + return CeleryTaskRecord( + task_dict["id"], task_infos_dict=task_dict, response_handle=response_handle, session=session + ) + class CeleryTaskBackend(BaseTaskBackend): app: "Celery" @@ -226,18 +246,6 @@ class CeleryPipeline(Pipeline): runner_backend_class = CeleryTaskBackend -def get_setting_files_path(conf_path, app_name) -> List[Path]: - conf_path = Path(conf_path) - if conf_path.is_file(): - conf_path = conf_path.parent - files = [] - for prefix, suffix in zip(["", "."], ["", "_secrets"]): - file_loc = conf_path / f"{prefix}celery_{app_name}{suffix}.toml" - if file_loc.is_file(): - files.append(file_loc) - return files - - class LogTask: def __init__(self, task_record: CeleryTaskRecord, username=None, level="LOAD"): self.path = Path(task_record.session_path) / "logs" @@ -288,6 +296,119 @@ class LogTask: def 
create_celery_app(conf_path, app_name="pypelines", v_host=None) -> "Celery | None": + from types import MethodType + from celery import Task + + def get_setting_files_path(conf_path) -> List[Path]: + conf_path = Path(conf_path) + if conf_path.is_file(): + conf_path = conf_path.parent + files = [] + for prefix, suffix in zip(["", "."], ["", "_secrets"]): + file_loc = conf_path / f"{prefix}celery_{app_name}{suffix}.toml" + if file_loc.is_file(): + files.append(file_loc) + return files + + def get_signature_string(signature): + params = [ + param_value for param_name, param_value in signature.parameters.items() if param_name not in ["session"] + ] + return str(signature.replace(parameters=params))[1:-1].replace(" *,", "") + + class Handshake(Task): + name = f"{app_name}.handshake" + + def run(self): + return f"{node()} is happy to shake your hand and says hello !" + + class TasksInfos(Task): + name = f"{app_name}.tasks_infos" + + def run(self, app_name, selfish=False): + app = APPLICATIONS_STORE[app_name] + tasks_dynamic_data = {} + pipelines = getattr(app, "pipelines", {}) + if len(pipelines) == 0: + logger.warning( + "No pipeline is registered on this app instance. " + "Are you trying to read tasks infos from a non worker app ? 
(web server side ?)" + ) + return {} + for pipeline in pipelines.values(): + pipeline.resolve() + for pipe in pipeline.pipes.values(): + for step in pipe.steps.values(): + if step.complete_name in app.tasks.keys(): + sig = get_signature_string(step.generate.__signature__) + doc = step.generate.__doc__ + task_data = { + "signature": sig, + "docstring": doc, + "step_name": step.step_name, + "pipe_name": step.pipe_name, + "pipeline_name": step.pipeline_name, + "requires": [item.complete_name for item in step.requires], + "step_level_in_pipe": step.get_level(selfish=selfish), + } + tasks_dynamic_data[step.complete_name] = task_data + return tasks_dynamic_data + + def get_remote_tasks(self): + registered_tasks = self.control.inspect().registered_tasks() + workers = [] + task_names = [] + for worker, tasks in registered_tasks.items(): + workers.append(worker) + for task in tasks: + task_names.append(task) + + def get_celery_app_tasks(self, refresh=False): + + from datetime import datetime, timedelta + + auto_refresh_time = timedelta(0, (20 * 1)) # 20 seconds as written (timedelta(days=0, seconds=20)), not a full day; confirm the intended interval + failed_refresh_retry_time = timedelta(0, (30 * 1)) # retry a failed refresh after 30 seconds as written, not an hour; confirm the intended delay + + app_task_data = getattr(self, "task_data", None) + + if app_task_data is None: + try: + task_data = self.tasks[f"{app_name}.tasks_infos"].delay(app_name).get(timeout=2) + app_task_data = {"task_data": task_data, "refresh_time": datetime.now() + auto_refresh_time} + setattr(self, "task_data", app_task_data) + logger.warning("Got tasks data for the first time since django server relaunched") + except Exception as e: + logger.warning(f"Could not get tasks from app. 
{e}") + # logger.warning(f"Remote tasks are : {self.get_remote_tasks()}") + # logger.warning(f"Local tasks are : {self.tasks}") + + else: + if datetime.now() > app_task_data["refresh_time"]: # we refresh if refresh time is elapsed + refresh = True + + if refresh: + try: + task_data = self.tasks[f"{app_name}.tasks_infos"].delay(app_name).get(timeout=2) + app_task_data = {"task_data": task_data, "refresh_time": datetime.now() + auto_refresh_time} + logger.warning("Refreshed celery tasks data sucessfully") + except Exception as e: + logger.warning( + "Could not refresh tasks data from remote celery worker. All workers are is probably running. " + f"{e}" + ) + app_task_data["refresh_time"] = datetime.now() + failed_refresh_retry_time + setattr(self, "task_data", app_task_data) + else: + delta = (app_task_data["refresh_time"] - datetime.now()).total_seconds() + logger.warning(f"Returned cached task_data. Next refresh will happen in at least {delta} seconds") + return app_task_data["task_data"] if app_task_data is not None else None + + def launch_named_task_remotely(self, session_id, task_name, extra=None, kwargs={}): + task_record = CeleryTaskRecord.create_from_task_name( + self, task_name, app_name, session_id, extra=extra, **kwargs + ) + return task_record failure_message = ( f"Celery app : {app_name} failed to be created." @@ -301,7 +422,7 @@ def create_celery_app(conf_path, app_name="pypelines", v_host=None) -> "Celery | logger.warning(f"Tried to create a celery app named {app_name}, but it already exists. 
Returning it instead.") return APPLICATIONS_STORE[app_name] - settings_files = get_setting_files_path(conf_path, app_name) + settings_files = get_setting_files_path(conf_path) if len(settings_files) == 0: logger.warning(f"{failure_message} Could not find celery toml config files.") @@ -355,51 +476,15 @@ def create_celery_app(conf_path, app_name="pypelines", v_host=None) -> "Celery | logger.warning(f"{failure_message} Could assign extra attribute {key} to celery app. {e}") return None - APPLICATIONS_STORE[app_name] = app - - from celery import Task - - class handshake(Task): - name = f"{app_name}.handshake" - - def run(self): - return f"{node()} is happy to shake your hand and says hello !" - - def get_signature_string(signature): - params = [ - param_value for param_name, param_value in signature.parameters.items() if param_name not in ["session"] - ] - return str(signature.replace(parameters=params))[1:-1].replace(" *,", "") + app.register_task(Handshake) + app.register_task(TasksInfos) - class tasks_infos(Task): - name = f"{app_name}.tasks_infos" - - def run(self, app_name, selfish=False): - app = APPLICATIONS_STORE[app_name] - tasks_dynamic_data = {} - pipelines = getattr(app, "pipelines", {}) - for pipeline in pipelines.values(): - pipeline.resolve() - for pipe in pipeline.pipes.values(): - for step in pipe.steps.values(): - if step.complete_name in app.tasks.keys(): - sig = get_signature_string(step.generate.__signature__) - doc = step.generate.__doc__ - task_data = { - "signature": sig, - "docstring": doc, - "step_name": step.step_name, - "pipe_name": step.pipe_name, - "pipeline_name": step.pipeline_name, - "requires": [item.complete_name for item in step.requires], - "step_level_in_pipe": step.get_level(selfish=selfish), - } - tasks_dynamic_data[step.complete_name] = task_data - return tasks_dynamic_data - - app.register_task(handshake) - app.register_task(tasks_infos) + app.get_remote_tasks = MethodType(get_remote_tasks, app) # type: ignore + 
app.get_celery_app_tasks = MethodType(get_celery_app_tasks, app) # type: ignore + app.launch_named_task_remotely = MethodType(launch_named_task_remotely, app) # type: ignore logger.info(f"The celery app {app_name} was created successfully.") + APPLICATIONS_STORE[app_name] = app + return app