diff --git a/src/pypelines/__init__.py b/src/pypelines/__init__.py
index 6af25681f0e4f1c5a5ceaba7006006010548d7cf..8886a4603c9ea9816fe148481e12f4906c4afbf4 100644
--- a/src/pypelines/__init__.py
+++ b/src/pypelines/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.0.28"
+__version__ = "0.0.29"
 
 from . import loggs
 from .pipes import *
diff --git a/src/pypelines/arguments.py b/src/pypelines/arguments.py
index ae76f1ec187a58ab105fb3839ca05d128c2af9c4..d6c2fa8f77b3fdc6f1a654e3d1363fbe64d6676c 100644
--- a/src/pypelines/arguments.py
+++ b/src/pypelines/arguments.py
@@ -51,7 +51,9 @@ def autoload_arguments(wrapped_function, step):
         config_kwargs = get_step_arguments(session, step)
         if config_kwargs:
             # new_kwargs is not empty
-            local_log.note(f"Using the arguments for the function {step.full_name} found in pipelines_arguments.json.")
+            local_log.note(
+                f"Using the arguments for the function {step.relative_name} found in pipelines_arguments.json."
+            )
             # this loop is just to show to log wich arguments have been overriden
             # from the json config by some arguments in the code
             overrides_names = []
@@ -62,7 +64,7 @@ def autoload_arguments(wrapped_function, step):
             if overrides_names:
                 local_log.note(
                     f"Values of pipelines_arguments.json arguments : {', '.join(overrides_names)}, are overrided by the"
-                    f" current call arguments to {step.full_name}"
+                    f" current call arguments to {step.relative_name}"
                 )
 
             config_kwargs.update(kwargs)
@@ -75,13 +77,13 @@ def get_step_arguments(session, step):
     local_log = getLogger("autoload_arguments")
 
     try:
-        config_args = read_session_arguments_file(session, step)["functions"][step.full_name]
+        config_args = read_session_arguments_file(session, step)["functions"][step.relative_name]
     except FileNotFoundError as e:
         local_log.debug(f"{type(e).__name__} : {e}. Skipping")
         return {}
     except KeyError:
         local_log.debug(
-            f"Could not find the `functions` key or the key `{step.full_name}` in pipelines_arguments.json file at"
+            f"Could not find the `functions` key or the key `{step.relative_name}` in pipelines_arguments.json file at"
             f" {session.path}. Skipping"
         )
         return {}
diff --git a/src/pypelines/celery_tasks.py b/src/pypelines/celery_tasks.py
new file mode 100644
index 0000000000000000000000000000000000000000..783d8a2e080cb874e772e0ca42b73a9b5f0641ed
--- /dev/null
+++ b/src/pypelines/celery_tasks.py
@@ -0,0 +1,192 @@
+from .tasks import BaseTaskBackend
+from .pipelines import Pipeline
+from pathlib import Path
+from traceback import format_exc as format_traceback_exc
+from logging import getLogger
+from functools import wraps
+from .loggs import LogTask
+
+from typing import TYPE_CHECKING, List
+
+if TYPE_CHECKING:
+    from celery import Celery
+    from .steps import BaseStep
+
+
+class CeleryTaskBackend(BaseTaskBackend):
+    app: "Celery"
+
+    def __init__(self, parent: Pipeline, app: "Celery | None" = None):
+        super().__init__(parent)
+        self.parent = parent
+
+        if app is not None:
+            self.success = True
+            self.app = app
+
+    def start(self):
+        self.app.start()
+
+    def register_step(self, step: "BaseStep"):
+        wrapped_step = getattr(step, "queue", None)
+        if wrapped_step is None:
+            return
+        if self:
+            self.app.task(wrapped_step, name=step.complete_name)
+
+    def wrap_step(self, step):
+
+        @wraps(step.generate)
+        def wrapper(task_id, extra=None):  # session, *args, extra=None, **kwargs):
+            from one import ONE
+
+            connector = ONE(mode="remote", data_access_mode="remote")
+            task = TaskRecord(connector.alyx.rest("tasks", "read", id=task_id))
+            kwargs = task.arguments if task.arguments else {}
+
+            try:
+                session = connector.search(id=task.session, details=True, no_cache=True)
+
+                with LogTask(task) as log_object:
+                    logger = log_object.logger
+                    task.log = log_object.filename
+                    task.status = "Started"
+                    task = TaskRecord(connector.alyx.rest("tasks", "partial_update", **task.export()))
+
+                    try:
+                        step.generate(session, extra=extra, skip=True, check_requirements=True, **kwargs)
+                        task.status = CeleryTaskBackend.status_from_logs(log_object)
+                    except Exception as e:
+                        traceback_msg = format_traceback_exc()
+                        logger.critical(f"Fatal Error : {e}")
+                        logger.critical("Traceback :\n" + traceback_msg)
+                        task.status = "Failed"
+
+            except Exception as e:
+                # if it fails outside of the nested try statement, we can't store log files,
+                # and we mention the failure through alyx directly.
+                task.status = "Uncatched_Fail"
+                task.log = str(e)
+
+            connector.alyx.rest("tasks", "partial_update", **task.export())
+
+        return wrapper
+
+    @staticmethod
+    def status_from_logs(log_object):
+        with open(log_object.fullpath, "r") as f:
+            content = f.read()
+
+        if len(content) == 0:
+            return "No_Info"
+        if "CRITICAL" in content:
+            return "Failed"
+        elif "ERROR" in content:
+            return "Errors"
+        elif "WARNING" in content:
+            return "Warnings"
+        else:
+            return "Complete"
+
+
+class CeleryPipeline(Pipeline):
+    runner_backend_class = CeleryTaskBackend
+
+
+def get_setting_files_path(conf_path, app_name) -> List[Path]:
+    conf_path = Path(conf_path)
+    if conf_path.is_file():
+        conf_path = conf_path.parent
+    files = []
+    for prefix, suffix in zip(["", "."], ["", "_secrets"]):
+        file_loc = conf_path / f"{prefix}celery_{app_name}{suffix}.toml"
+        if file_loc.is_file():
+            files.append(file_loc)
+    return files
+
+
+def create_celery_app(conf_path, app_name="pypelines"):
+
+    failure_message = (
+        f"Celery app : {app_name} failed to be created. "
+        "Don't worry about this alert, "
+        "this is not an issue if you didn't explicitly plan on using celery. Issue was : "
+    )
+
+    logger = getLogger("pypelines.create_celery_app")
+    settings_files = get_setting_files_path(conf_path, app_name)
+
+    if len(settings_files) == 0:
+        logger.warning(f"{failure_message} Could not find celery toml config files.")
+        return None
+
+    try:
+        from dynaconf import Dynaconf
+    except ImportError:
+        logger.warning(f"{failure_message} Could not import dynaconf. Maybe it is not installed ?")
+        return None
+
+    try:
+        settings = Dynaconf(settings_files=settings_files)
+    except Exception as e:
+        logger.warning(f"{failure_message} Could not create dynaconf object. {e}")
+        return None
+
+    try:
+        app_display_name = settings.get("app_display_name", app_name)
+        broker_type = settings.connexion.broker_type
+        account = settings.account
+        password = settings.password
+        address = settings.address
+        backend = settings.connexion.backend
+        conf_data = settings.conf
+    except (AttributeError, KeyError) as e:
+        logger.warning(f"{failure_message} {e}")
+        return None
+
+    try:
+        from celery import Celery
+    except ImportError:
+        logger.warning(f"{failure_message} Could not import celery. Maybe it is not installed ?")
+        return None
+
+    try:
+        app = Celery(
+            app_display_name,
+            broker=(f"{broker_type}://" f"{account}:{password}@{address}//"),
+            backend=f"{backend}://",
+        )
+    except Exception as e:
+        logger.warning(f"{failure_message} Could not create app. Maybe rabbitmq server @{address} is not running ? {e}")
+        return None
+
+    for key, value in conf_data.items():
+        try:
+            setattr(app.conf, key, value)
+        except Exception as e:
+            logger.warning(f"{failure_message} Could not assign extra attribute {key} to celery app. {e}")
+            return None
+
+    return app
+
+
+class TaskRecord(dict):
+    # a class to make dictionary keys accessible with attribute syntax
+    def __init__(self, *args, **kwargs):
+        from one import ONE
+
+        self.connector = ONE(mode="remote", data_access_mode="remote")
+        super().__init__(*args, **kwargs)
+        self.__dict__ = self
+
+    def export(self):
+        return {"id": self["id"], "data": {k: v for k, v in self.items() if k not in ["id", "session_path"]}}
+
+    @staticmethod
+    def create(step: "BaseStep", session):
+        from one import ONE
+
+        connector = ONE(mode="remote", data_access_mode="remote")
+
+        data = {"name": step.complete_name}
+        connector.alyx.rest("tasks", "create", data=data)
diff --git a/src/pypelines/disk.py b/src/pypelines/disk.py
index 0b0e176ffdccb209dda3826b4b1bd58ed3b27830..cd534f16539a50c668c7f0cfa15637f62cf0d63e 100644
--- a/src/pypelines/disk.py
+++ b/src/pypelines/disk.py
@@ -41,7 +41,7 @@ class BaseDiskObject(metaclass=ABCMeta):
 
     @property
     def object_name(self):
-        return f"{self.step.full_name}{'.'+self.extra if self.extra else ''}"
+        return f"{self.step.relative_name}{'.' + self.extra if self.extra else ''}"
 
     @abstractmethod
     def version_deprecated(self) -> bool:
@@ -88,7 +88,7 @@ class BaseDiskObject(metaclass=ABCMeta):
     def multisession_unpacker(sessions, datas):
         raise NotImplementedError
 
-    def disk_step_instance(self) -> "BaseStep":
+    def disk_step_instance(self) -> "BaseStep | None":
         """Returns an instance of the step that corresponds to the file on disk."""
         if self.disk_step is not None:
             return self.step.pipe.steps[self.disk_step]
@@ -108,7 +108,7 @@ class BaseDiskObject(metaclass=ABCMeta):
     def get_status_message(self):
         loadable_disk_message = "A disk object is loadable. " if self.is_loadable() else ""
         deprecated_disk_message = (
-            f"This object's version is { 'deprecated' if self.version_deprecated() else 'the current one' }. "
+            f"This object's version is {'deprecated' if self.version_deprecated() else 'the current one'}. "
         )
         step_level_disk_message = (
             "This object's step level is"
@@ -127,7 +127,7 @@ class BaseDiskObject(metaclass=ABCMeta):
             else ""
         )
         return (
-            f"{self.object_name} object has{ ' a' if self.is_matching() else ' no' } valid disk object found."
+            f"{self.object_name} object has {'a' if self.is_matching() else 'no'} valid disk object found."
             f" {found_disk_object_description}{loadable_disk_message}"
         )
diff --git a/src/pypelines/graphs.py b/src/pypelines/graphs.py
index e083d02d95922502712addc3dc19e20294ac3cb5..779b57d2fe657b227b2b18542a0f942acf24e3d4 100644
--- a/src/pypelines/graphs.py
+++ b/src/pypelines/graphs.py
@@ -1,10 +1,15 @@
 import numpy as np
 import matplotlib.pyplot as plt
 
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from networkx import DiGraph
+
 
 class PipelineGraph:
-    callable_graph = None
-    name_graph = None
+    callable_graph: "DiGraph"
+    name_graph: "DiGraph"
 
     def __init__(self, pipeline):
         from networkx import DiGraph, draw, spring_layout, draw_networkx_labels
@@ -20,17 +25,16 @@ class PipelineGraph:
         self.make_graphs()
 
     def make_graphs(self):
-        from networkx import DiGraph
-
-        callable_graph = DiGraph()
-        display_graph = DiGraph()
+        callable_graph = self.DiGraph()
+        display_graph = self.DiGraph()
 
         for pipe in self.pipeline.pipes.values():
             for step in pipe.steps.values():
                 callable_graph.add_node(step)
-                display_graph.add_node(step.full_name)
+                display_graph.add_node(step.relative_name)
                 for req in step.requires:
                     callable_graph.add_edge(req, step)
-                    display_graph.add_edge(req.full_name, step.full_name)
+                    display_graph.add_edge(req.relative_name, step.relative_name)
 
         self.callable_graph = callable_graph
         self.name_graph = display_graph
@@ -101,7 +105,7 @@ class PipelineGraph:
                 # if len([])
                 # TODO : add distinctions of fractions of y if multiple nodes of the same pipe have same level
                 x = pipe_x_indices[node.pipe]
                 y = node.get_level()
-                pos[node.full_name] = (x, -y)
+                pos[node.relative_name] = (x, -y)
         return pos
 
     def separate_crowded_levels(self, pos, max_spacing=0.35):
diff --git a/src/pypelines/pickle_backend.py b/src/pypelines/pickle_backend.py
index 39a5f9024ca53bcf9a32786d26a6826f052b55b5..76d40e878c03eb655762a59bfc41237c3a3e5855 100644
--- a/src/pypelines/pickle_backend.py
+++ b/src/pypelines/pickle_backend.py
@@ -52,10 +52,14 @@ class PickleDiskObject(BaseDiskObject):
         # we compare levels with the currently called step
         # if disk step level < current called step level, we return True, else we return False.
         if disk_step.get_level(selfish=True) < self.step.get_level(selfish=True):
-            logger.debug(f"Disk step {disk_step.full_name} was lower than {self.step.full_name}. Returning True")
+            logger.debug(
+                f"Disk step {disk_step.relative_name} was lower than {self.step.relative_name}. Returning True"
+            )
             return True
 
-        logger.debug(f"Disk step {disk_step.full_name} was higher or equal than {self.step.full_name}. Returning False")
+        logger.debug(
+            f"Disk step {disk_step.relative_name} was higher or equal than {self.step.relative_name}. Returning False"
+        )
         return False
 
     @property
@@ -157,7 +161,7 @@ class PickleDiskObject(BaseDiskObject):
                 return True
             else:
                 logger.load(
-                    f"More than one partial match was found for {self.step.full_name}. Cannot auto select. Expected :"
+                    f"More than one partial match was found for {self.step.relative_name}. Cannot auto select. Expected :"
                     f" {expected_values}, Found : {match_datas}"
                 )
                 return False
@@ -189,7 +193,7 @@ class PickleDiskObject(BaseDiskObject):
 
     def load(self):
         logger = logging.getLogger("PickleDiskObject.load")
-        logger.debug(f"Current disk file status : {self.current_disk_file = }")
+        logger.debug(f"Current disk file status : {self.current_disk_file=}")
         if self.current_disk_file is None:
             raise IOError(
                 "Could not find a file to load. Either no file was found on disk, or you forgot to run 'check_disk()'"
diff --git a/src/pypelines/pipelines.py b/src/pypelines/pipelines.py
index 84281485573d6bc8b314354f7ea5e8e30a7d50e7..fc1e1ff655fe2628379b5512bec3b14d6db83ab3 100644
--- a/src/pypelines/pipelines.py
+++ b/src/pypelines/pipelines.py
@@ -1,6 +1,8 @@
-from typing import Callable, Type, Dict, Iterable, Protocol, TYPE_CHECKING
 from logging import getLogger
 import os
+from .tasks import BaseTaskBackend
+
+from typing import Callable, Type, Dict, List, Iterable, Protocol, TYPE_CHECKING
 
 if TYPE_CHECKING:
     from .pipes import BasePipe
@@ -9,17 +11,17 @@ if TYPE_CHECKING:
 
 
 class Pipeline:
-    use_celery = False
     pipes: Dict[str, "BasePipe"]
+    runner_backend_class = BaseTaskBackend
 
-    def __init__(self, name: str, conf_path=None, use_celery=False):
+    def __init__(self, name: str, **runner_args):
         self.pipeline_name = name
         self.pipes = {}
         self.resolved = False
 
-        self.conf_path = os.path.dirname(conf_path) if conf_path is not None else None
-        if use_celery:
-            self.configure_celery()
+        # create a runner backend; if it fails, the runner_backend object evaluates to False as a boolean
+        # (to be checked and used throughout the pipeline wrappers creation)
+        self.runner_backend = self.runner_backend_class(self, **runner_args)
 
     def register_pipe(self, pipe_class: Type["BasePipe"]) -> Type["BasePipe"]:
         """Wrapper to instanciate and attache a a class inheriting from BasePipe it to the Pipeline instance.
@@ -76,7 +78,9 @@ class Pipeline:
 
         self.resolved = True
 
-    def get_requirement_stack(self, instance: "BaseStep", names: bool = False, max_recursion: int = 100) -> list:
+    def get_requirement_stack(
+        self, instance: "BaseStep", names: bool = False, max_recursion: int = 100
+    ) -> List["BaseStep"]:
         """Returns a list containing the ordered Steps that the "instance" Step object requires for being ran.
 
         Args:
@@ -93,11 +97,11 @@ class Pipeline:
         """
         self.resolve()  # ensure requires lists are containing instances and not strings
 
-        parents = []
+        parents: List["BaseStep"] = []
         required_steps = []
 
         def recurse_requirement_stack(
-            instance,
+            instance: "BaseStep",
         ):
             """
             _summary_
@@ -130,7 +134,7 @@ class Pipeline:
         recurse_requirement_stack(instance)
 
         if names:
-            required_steps = [req.full_name for req in required_steps]
+            required_steps = [req.relative_name for req in required_steps]
         return required_steps
 
     @property
@@ -138,29 +142,3 @@ class Pipeline:
         from .graphs import PipelineGraph
 
         return PipelineGraph(self)
-
-    def configure_celery(self) -> None:
-        try :
-            from .tasks import CeleryHandler
-        except ImportError:
-            getLogger().warning(
-                f"Celery is not installed. Cannot set it up for the pipeline {self.pipeline_name}"
-                "Don't worry, about this alert, "
-                "this is not be an issue if you didn't explicitely planned on using celery."
-            )
-            return
-
-        celery = CeleryHandler(self.conf_path, self.pipeline_name)
-        if celery.success:
-            self.celery = celery
-            self.use_celery = True
-        else:
-            getLogger().warning(
-                f"Could not initialize celery for the pipeline {self.pipeline_name}."
-                "Don't worry, about this alert, "
-                "this is not be an issue if you didn't explicitely planned on using celery."
-            )
-
-    def finalize(self):
-        if self.use_celery:
-            self.celery.app.start()  # pyright: ignore
diff --git a/src/pypelines/pipes.py b/src/pypelines/pipes.py
index 5abfa81b4c159a338130a606f558237b1d62d3e6..ebaeb7eb0c5b2c86ea8ebb3b39622acfd8107f21 100644
--- a/src/pypelines/pipes.py
+++ b/src/pypelines/pipes.py
@@ -40,8 +40,7 @@ class BasePipe(metaclass=ABCMeta):
 
         if len(_steps) < 1:
             raise ValueError(
-                f"You should register at least one step class with @stepmethod in {self.pipe_name} class."
-                f" { _steps = }"
+                f"You should register at least one step class with @stepmethod in {self.pipe_name} class. {_steps=}"
             )
 
         # if len(_steps) > 1 and self.single_step:
@@ -70,10 +69,6 @@ class BasePipe(metaclass=ABCMeta):
             # so that we attach the necessary components to it.
             setattr(self, step_name, step)
 
-            # if the pipeline has been created with celery settings, we attach the step generation to celery.
-            if self.pipeline.use_celery:
-                self.pipeline.celery.register_step(step)
-
             # below is just a syntaxic sugar to help in case the pipe is "single_step"
             # so that we can access any pipe instance in pipeline with simple iteration on
             # pipeline.pipes.pipe, whatever if the object in pipelines.pipes is a step or a pipe
diff --git a/src/pypelines/steps.py b/src/pypelines/steps.py
index a221241f5aa42d289d1d39161a986952b6400ef3..412b48effa1aaeb30f18090fbb5572d752e7d2e2 100644
--- a/src/pypelines/steps.py
+++ b/src/pypelines/steps.py
@@ -6,7 +6,7 @@ import logging, inspect
 from dataclasses import dataclass
 from types import MethodType
 
-from typing import Callable, Type, Iterable, Protocol, TYPE_CHECKING
+from typing import Callable, Type, Iterable, Protocol, List, TYPE_CHECKING
 
 if TYPE_CHECKING:
     from .pipelines import Pipeline
@@ -54,6 +54,17 @@ def stepmethod(requires=[], version=None, do_dispatch=True, on_save_callbacks=[]
 
 
 class BaseStep:
+    step_name: str
+
+    requires: List["BaseStep"]
+    version: str | int
+    do_dispatch: bool
+    callbacks: List[Callable]
+
+    worker: Callable
+    pipe: "BasePipe"
+    pipeline: "Pipeline"
+
     def __init__(
         self,
         pipeline: "Pipeline",
@@ -85,21 +96,29 @@ class BaseStep:
 
         self.multisession = self.pipe.multisession_class(self)
 
+        if self.pipeline.runner_backend:
+            queued_runner = self.pipeline.runner_backend.wrap_step(self)
+            setattr(self, "queue", queued_runner)
+
     @property
-    def requirement_stack(self):
+    def requirement_stack(self) -> Callable:
         return partial(self.pipeline.get_requirement_stack, instance=self)
 
     @property
-    def pipe_name(self):
+    def pipe_name(self) -> str:
         return self.pipe.pipe_name
 
     @property
-    def full_name(self):
+    def relative_name(self) -> str:
         return f"{self.pipe_name}.{self.step_name}"
 
-    # @property
-    # def single_step(self):
-    #     return self.pipe.single_step
+    @property
+    def pipeline_name(self) -> str:
+        return self.pipe.pipeline.pipeline_name
+
+    @property
+    def complete_name(self) -> str:
+        return f"{self.pipeline_name}.{self.relative_name}"
 
     def disk_step(self, session, extra=""):
         disk_object = self.get_disk_object(session, extra)
@@ -223,9 +242,9 @@ class BaseStep:
         )  # a flag to know if we are in requirement run or toplevel
 
         if in_requirement:
-            logger = logging.getLogger(f"╰─>req.{self.full_name}"[:NAMELENGTH])
+            logger = logging.getLogger(f"╰─>req.{self.relative_name}"[:NAMELENGTH])
         else:
-            logger = logging.getLogger(f"gen.{self.full_name}"[:NAMELENGTH])
+            logger = logging.getLogger(f"gen.{self.relative_name}"[:NAMELENGTH])
 
         if refresh and skip:
             raise ValueError(
@@ -263,8 +282,8 @@ class BaseStep:
 
         elif skip:
             logger.load(
-                f"File exists for {self.full_name}{'.' + extra if extra else ''}. Loading and processing"
-                " will be skipped"
+                f"File exists for {self.relative_name}{'.' + extra if extra else ''}."
+                " Loading and processing will be skipped"
             )
             if not check_requirements or refresh_requirements is not False:
                 return None
@@ -287,15 +306,18 @@ class BaseStep:
                     result = disk_object.load()
                 except IOError as e:
                     raise IOError(
-                        f"The DiskObject responsible for loading {self.full_name} has `is_loadable() == True`"
+                        f"The DiskObject responsible for loading {self.relative_name}"
+                        " has `is_loadable() == True`"
                         " but the loading procedure failed. Double check and test your DiskObject check_disk"
                        " and load implementation. Check the original error above."
                     ) from e
-                logger.load(f"Loaded {self.full_name}{'.' + extra if extra else ''} sucessfully.")
+                logger.load(f"Loaded {self.relative_name}{'.' + extra if extra else ''} successfully.")
                 return result
             else:
-                logger.load(f"Could not find or load {self.full_name}{'.' + extra if extra else ''} saved file.")
+                logger.load(
+                    f"Could not find or load {self.relative_name}{'.' + extra if extra else ''} saved file."
+                )
 
         else:
             logger.load("`refresh` was set to True, ignoring the state of disk files and running the function.")
@@ -322,7 +344,7 @@ class BaseStep:
                     if isinstance(refresh_requirements, list):
                         _refresh = (
                             True
-                            if step.pipe_name in refresh_requirements or step.full_name in refresh_requirements
+                            if step.pipe_name in refresh_requirements or step.relative_name in refresh_requirements
                             else False
                         )
 
@@ -343,16 +365,18 @@ class BaseStep:
                 return None
 
         if in_requirement:
-            logger.header(f"Performing the requirement {self.full_name}{'.' + extra if extra else ''}")
+            logger.header(f"Performing the requirement {self.relative_name}{'.' + extra if extra else ''}")
         else:
-            logger.header(f"Performing the computation to generate {self.full_name}{'.' + extra if extra else ''}")
+            logger.header(
+                f"Performing the computation to generate {self.relative_name}{'.' + extra if extra else ''}"
+            )
 
         kwargs.update({"extra": extra})
         if self.is_refresh_in_kwargs():
             kwargs.update({"refresh": refresh})
         result = self.pipe.pre_run_wrapper(self.worker(session, *args, **kwargs))
         if save_output:
-            logger.save(f"Saving the generated {self.full_name}{'.' + extra if extra else ''} output.")
+            logger.save(f"Saving the generated {self.relative_name}{'.' + extra if extra else ''} output.")
             disk_object.save(result)
 
         # AFTER the saving has been done, if there is some callback function that should be run, we execute them
@@ -414,7 +438,7 @@ class BaseStep:
         sig = inspect.signature(self.worker)
         param = sig.parameters.get("extra")
         if param is None:
-            raise ValueError(f"Parameter extra not found in function {self.full_name}")
+            raise ValueError(f"Parameter extra not found in function {self.relative_name}")
         if param.default is param.empty:
             raise ValueError("Parameter extra does not have a default value")
         return param.default
@@ -431,8 +455,8 @@ class BaseStep:
             req_step = [step for step in self.requirement_stack() if step.pipe_name == pipe_name][-1]
         except IndexError as e:
             raise IndexError(
-                f"Could not find a required step with the pipe_name {pipe_name} for the step {self.full_name}. Are you"
-                " sure it figures in the requirement stack ?"
+                f"Could not find a required step with the pipe_name {pipe_name} for the step {self.relative_name}. "
+                "Are you sure it figures in the requirement stack ?"
             ) from e
         return req_step.load(session, extra=extra)
@@ -443,21 +467,25 @@ class BaseStep:
             raise NotImplementedError
 
     def start_remotely(self, session, extra=None, **kwargs):
-        if not self.pipeline.use_celery:
+
+        queued_runner = getattr(self, "queue", None)
+
+        if queued_runner is None:
             raise NotImplementedError(
-                "Cannot use this feature with a pipeline that doesn't have a celery cluster access"
+                "Cannot use this feature with a pipeline that doesn't have an implemented and working runner backend"
             )
+
         from one import ONE
 
         connector = ONE(mode="remote", data_access_mode="remote")
 
-        worker = self.pipeline.celery.app.tasks[self.full_name]
+        worker = self.pipeline.celery.app.tasks[self.relative_name]
         task_dict = connector.alyx.rest(
             "tasks",
             "create",
             data={
                 "session": session.name,
-                "name": self.full_name,
+                "name": self.relative_name,
                 "arguments": kwargs,
                 "status": "Waiting",
                 "executable": self.pipeline.celery.app_name,
diff --git a/src/pypelines/tasks.py b/src/pypelines/tasks.py
index 6e651e4e4abfe68e00eded180f66c456a3bd5fe9..23ff6358590e4b1f5fd86f1347521b234e791935 100644
--- a/src/pypelines/tasks.py
+++ b/src/pypelines/tasks.py
@@ -1,133 +1,31 @@
 from functools import wraps
-from logging import getLogger
-import os, traceback
+from typing import TYPE_CHECKING
 
-from celery import Celery
-from dynaconf import Dynaconf
-from one import ONE
-from .loggs import LogTask
+if TYPE_CHECKING:
+    from .pipelines import Pipeline
 
 
-class CeleryHandler:
-    settings = None
-    app = None
-    app_name = None
-    success: bool = False
-
-    def __init__(self, conf_path, pipeline_name):
-        logger = getLogger()
-        settings_files = self.get_setting_files_path(conf_path, pipeline_name)
-
-        if any([not os.path.isfile(file) for file in settings_files]):
-            logger.warning(f"Some celery configuration files were missing for pipeline {pipeline_name}")
-            return
-
-        try:
-            self.settings = Dynaconf(settings_files=settings_files)
-
-            self.app_name = self.settings.get("app_name", pipeline_name)
-            broker_type = self.settings.connexion.broker_type
-            account = self.settings.account
-            password = self.settings.password
-            address = self.settings.address
-            backend = self.settings.connexion.backend
-            conf_data = self.settings.conf
-
-        except Exception as e:
-            logger.warning(
-                "Could not get all necessary information to configure celery when reading config files."
-                "Check their content."
-            )
-            return
+class BaseTaskBackend:
 
-        try:
-            self.app = Celery(
-                self.app_name,
-                broker=(f"{broker_type}://" f"{account}:{password}@{address}//"),
-                backend=f"{backend}://",
-            )
-        except Exception as e:
-            logger.warning("Instanciating celery app failed. Maybe rabbitmq is not running ?")
Maybe rabbitmq is not running ?") - - for key, value in conf_data.items(): - setattr(self.app.conf, key, value) - - try: - self.connector = ONE(data_access_mode="remote") - except Exception as e: - logger.warning("Instanciating One connector during celery configuration failed.") - return + success: bool = False - self.success = True + def __init__(self, parent: "Pipeline", **kwargs): + self.parent = parent - def get_setting_files_path(self, conf_path, pipeline_name): - files = [] - files.append(os.path.join(conf_path, f"celery_{pipeline_name}.toml")) - files.append(os.path.join(conf_path, f".celery_{pipeline_name}_secrets.toml")) - return files + def __bool__(self): + return self.success def register_step(self, step): - self.app.task(self.wrap_step(step), name=step.full_name) + wrapped_step = getattr(step, "queue", None) + if wrapped_step is None: + # do not register + pass + # registration code here def wrap_step(self, step): - @wraps(step.generate) - def wrapper(task_id, extra=None): # session, *args, extra=None, **kwargs): - from one import ONE - - connector = ONE(mode="remote", data_access_mode="remote") - task = TaskRecord(connector.alyx.rest("tasks", "read", id=task_id)) - kwargs = task.arguments if task.arguments else {} - try: - session = connector.search(id=task.session, details=True) - - with LogTask(task) as log_object: - logger = log_object.logger - task.log = log_object.filename - task.status = "Started" - task = TaskRecord(connector.alyx.rest("tasks", "partial_update", **task.export())) - - try: - step.generate(session, extra=extra, skip=True, check_requirements=True, **kwargs) - task.status = CeleryHandler.status_from_logs(log_object) - except Exception as e: - traceback_msg = traceback.format_exc() - logger.critical(f"Fatal Error : {e}") - logger.critical("Traceback :\n" + traceback_msg) - task.status = "Failed" - - except Exception as e: - # if it fails outside of the nested try statement, we can't store logs files, - # and we mention the failure through alyx directly. - task.status = "Uncatched_Fail" - task.log = str(e) - - connector.alyx.rest("tasks", "partial_update", **task.export()) + @wraps(step.generate) + def wrapper(*args, **kwargs): + return step.generate(*args, **kwargs) return wrapper - - @staticmethod - def status_from_logs(log_object): - with open(log_object.fullpath, "r") as f: - content = f.read() - - if len(content) == 0: - return "No_Info" - if "CRITICAL" in content: - return "Failed" - elif "ERROR" in content: - return "Errors" - elif "WARNING" in content: - return "Warnings" - else: - return "Complete" - - -class TaskRecord(dict): - # a class to make dictionnary keys accessible with attribute syntax - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.__dict__ = self - - def export(self): - return {"id": self.id, "data": {k: v for k, v in self.items() if k not in ["id", "session_path"]}}