diff --git a/src/pypelines/__init__.py b/src/pypelines/__init__.py
index 3cfaf2f6b4ff80f0061f57adb85c7fc4cdd26883..35f12e94d144997df7feda4a2a5296684806d7bf 100644
--- a/src/pypelines/__init__.py
+++ b/src/pypelines/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.0.30"
+__version__ = "0.0.32"
 
 from . import loggs
 from .pipes import *
diff --git a/src/pypelines/celery_tasks.py b/src/pypelines/celery_tasks.py
index 6e034cf2cfd72b99435ffca526bce0b801f22a08..2459ef3fa899d07b12f7c8af07beeba0d0b12e0f 100644
--- a/src/pypelines/celery_tasks.py
+++ b/src/pypelines/celery_tasks.py
@@ -1,10 +1,14 @@
 from .tasks import BaseTaskBackend, BaseStepTaskManager
 from .pipelines import Pipeline
+from .loggs import FileFormatter
 from pathlib import Path
 from traceback import format_exc as format_traceback_exc
+import logging
+import coloredlogs
 from logging import getLogger
 from functools import wraps
-from .loggs import LogTask
+from platform import node
+from pandas import Series
 
 from typing import TYPE_CHECKING, List
 
@@ -13,74 +17,75 @@ if TYPE_CHECKING:
     from .steps import BaseStep
 
 
-class CeleryAlyxTaskManager(BaseStepTaskManager):
+APPLICATIONS_STORE = {}
 
-    backend: "CeleryTaskBackend"
-    step: "BaseStep"
 
-    def register_step(self):
-        if self.backend:
-            self.backend.app.task(self.runner, name=self.step.complete_name)
+def CeleryRunner(task_id, extra=None):
 
-    def runner(self, task_id, extra=None):
+    task = CeleryTaskRecord(task_id)
 
-        from one import ONE
+    try:
+        session = task.get_session()
+        application = task.get_application()
+
+        with LogTask(task) as log_object:
+            logger = log_object.logger
+            task["log"] = log_object.filename
+            task["status"] = "Started"
+            task.partial_update()
+
+            try:
+                step: "BaseStep" = application.pipelines[task.pipeline_name].pipes[task.pipe_name].steps[task.step_name]
+                step.generate(session, extra=extra, skip=True, check_requirements=True, **task.arguments)
+                task.status_from_logs(log_object)
+            except Exception as e:
+                traceback_msg = format_traceback_exc()
+                logger.critical(f"Fatal Error : {e}")
+                logger.critical("Traceback :\n" + traceback_msg)
+                task["status"] = "Failed"
 
-        connector = ONE(mode="remote", data_access_mode="remote")
-        task = CeleryTaskRecord(task_id)
+    except Exception as e:
+        # if it fails outside of the nested try statement, we can't store log files,
+        # so we report the failure through alyx directly.
+        task["status"] = "Uncatched_Fail"
+        task["log"] = str(e)
 
-        try:
-            session = connector.search(
-                id=task["session"], details=True, no_cache=True)
-
-            with LogTask(task) as log_object:
-                logger = log_object.logger
-                task["log"] = log_object.filename
-                task["status"] = "Started"
-                task.partial_update()
-
-                try:
-                    self.step.generate(
-                        session, extra=extra, skip=True, check_requirements=True, **task.arguments)
-                    task.status_from_logs(log_object)
-                except Exception as e:
-                    traceback_msg = format_traceback_exc()
-                    logger.critical(f"Fatal Error : {e}")
-                    logger.critical("Traceback :\n" + traceback_msg)
-                    task["status"] = "Failed"
+    task.partial_update()
 
-        except Exception as e:
-            # if it fails outside of the nested try statement, we can't store logs files,
-            # and we mention the failure through alyx directly.
- task["status"] = "Uncatched_Fail" - task["log"] = str(e) - task.partial_update() +class CeleryAlyxTaskManager(BaseStepTaskManager): + + backend: "CeleryTaskBackend" + step: "BaseStep" + + def register_step(self): + if self.backend: + self.backend.app.task(CeleryRunner, name=self.step.complete_name) def start(self, session, extra=None, **kwargs): if not self.backend: raise NotImplementedError( - "Cannot start a task on a celery cluster as this pipeline " - "doesn't have a working celery backend" + "Cannot start a task on a celery cluster as this pipeline " "doesn't have a working celery backend" ) return CeleryTaskRecord.create(self, session, extra, **kwargs) class CeleryTaskRecord(dict): + session: Series + # a class to make dictionnary keys accessible with attribute syntax - def __init__(self, task_id, task_infos_dict={}, response_handle=None): - if task_infos_dict: - super().__init__(task_infos_dict) - else: + def __init__(self, task_id, task_infos_dict={}, response_handle=None, session=None): + + if not task_infos_dict: from one import ONE connector = ONE(mode="remote", data_access_mode="remote") task_infos_dict = connector.alyx.rest("tasks", "read", id=task_id) - super().__init__(task_infos_dict) - self.session = + super().__init__(task_infos_dict) + self.session = session # type: ignore self.response = response_handle def status_from_logs(self, log_object): @@ -102,19 +107,50 @@ class CeleryTaskRecord(dict): def partial_update(self): from one import ONE + connector = ONE(mode="remote", data_access_mode="remote") connector.alyx.rest("tasks", "partial_update", **self.export()) + def get_session(self): + if self.session is None: + from one import ONE + + connector = ONE(mode="remote", data_access_mode="remote") + session = connector.search(id=self["session"], no_cache=True, details=True) + self.session = session # type: ignore + + return self.session + + def get_application(self): + try: + return APPLICATIONS_STORE[self["executable"]] + except KeyError: + raise KeyError(f"Unable to retrieve the application {self['executable']}") + + @property + def pipeline_name(self): + return self["name"].split(".")[0] + + @property + def pipe_name(self): + return self["name"].split(".")[1] + + @property + def step_name(self): + return self["name"].split(".")[2] + @property def arguments(self): args = self.get("arguments", {}) return args if args else {} @property - def session_path(self): - from one import ONE - connector = ONE(mode="remote", data_access_mode="remote") - connector.alyx.rest(sess) + def session_path(self) -> str: + return self.session["path"] + + @property + def task_id(self): + return self["id"] def export(self): return {"id": self["id"], "data": {k: v for k, v in self.items() if k not in ["id", "session_path"]}} @@ -130,7 +166,7 @@ class CeleryTaskRecord(dict): "name": task_manager.step.complete_name, "arguments": kwargs, "status": "Waiting", - "executable": task_manager.backend.app_name, + "executable": str(task_manager.backend.app.main), } task_dict = connector.alyx.rest("tasks", "create", data=data) @@ -138,7 +174,9 @@ class CeleryTaskRecord(dict): worker = task_manager.backend.app.tasks[task_manager.step.complete_name] response_handle = worker.delay(task_dict["id"], extra=extra) - return TaskRecord(task_dict["id"], task_dict, response_handle) + return CeleryTaskRecord( + task_dict["id"], task_infos_dict=task_dict, response_handle=response_handle, session=session + ) class CeleryTaskBackend(BaseTaskBackend): @@ -153,6 +191,10 @@ class CeleryTaskBackend(BaseTaskBackend): 
             self.success = True
             self.app = app
 
+            pipelines = getattr(self.app, "pipelines", {})
+            pipelines[parent.pipeline_name] = parent
+            self.app.pipelines = pipelines
+
     def start(self):
         self.app.start()
 
@@ -179,16 +221,16 @@ def get_setting_files_path(conf_path, app_name) -> List[Path]:
 
 
 class LogTask:
-    def __init__(self, task_record, username="", level="LOAD"):
-        self.path = os.path.normpath(task_record.session_path)
-        self.username = username
-        os.makedirs(self.path, exist_ok=True)
-        self.worker_pk = task_record.id
-        self.task_name = task_record.name
+    def __init__(self, task_record: CeleryTaskRecord, username=None, level="LOAD"):
+        self.path = Path(task_record.session_path) / "logs"
+        self.username = username if username is not None else (node() if node() else "unknown")
+        self.worker_pk = task_record.task_id
+        self.task_name = task_record["name"]
         self.level = getattr(logging, level.upper())
 
     def __enter__(self):
-        self.logger = logging.getLogger()
+        self.path.mkdir(exist_ok=True)
+        self.logger = getLogger()
         self.set_handler()
         return self
 
@@ -196,9 +238,8 @@ class LogTask:
         self.remove_handler()
 
     def set_handler(self):
-        self.filename = os.path.join(
-            "logs", f"task_log.{self.task_name}.{self.worker_pk}.log")
-        self.fullpath = os.path.join(self.path, self.filename)
+        self.filename = f"task_log.{self.task_name}.{self.worker_pk}.log"
+        self.fullpath = self.path / self.filename
         fh = logging.FileHandler(self.fullpath)
         f_formater = FileFormatter()
         coloredlogs.HostNameFilter.install(
@@ -221,14 +262,14 @@
         )
 
         fh.setLevel(self.level)
-        fh.setFormatter()
+        fh.setFormatter(f_formater)
         self.logger.addHandler(fh)
 
     def remove_handler(self):
         self.logger.removeHandler(self.logger.handlers[-1])
 
 
-def create_celery_app(conf_path, app_name="pypelines"):
+def create_celery_app(conf_path, app_name="pypelines") -> "Celery":
 
     failure_message = (
         f"Celery app : {app_name} failed to be created."
@@ -237,25 +278,27 @@
     )
 
     logger = getLogger("pypelines.create_celery_app")
+
+    if app_name in APPLICATIONS_STORE.keys():
+        logger.warning(f"Tried to create a celery app named {app_name}, but it already exists. Returning it instead.")
+        return APPLICATIONS_STORE[app_name]
+
     settings_files = get_setting_files_path(conf_path, app_name)
 
     if len(settings_files) == 0:
-        logger.warning(
-            f"{failure_message} Could not find celery toml config files.")
+        logger.warning(f"{failure_message} Could not find celery toml config files.")
         return None
 
     try:
         from dynaconf import Dynaconf
     except ImportError:
-        logger.warning(
-            f"{failure_message} Could not import dynaconf. Maybe it is not istalled ?")
+        logger.warning(f"{failure_message} Could not import dynaconf. Maybe it is not installed?")
         return None
 
     try:
         settings = Dynaconf(settings_files=settings_files)
     except Exception as e:
-        logger.warning(
-            f"{failure_message} Could not create dynaconf object. {e}")
+        logger.warning(f"{failure_message} Could not create dynaconf object. {e}")
         return None
 
     try:
@@ -273,27 +316,41 @@
     try:
         from celery import Celery
     except ImportError:
-        logger.warning(
-            f"{failure_message} Could not import celery. Maybe is is not installed ?")
+        logger.warning(f"{failure_message} Could not import celery. Maybe it is not installed?")
         return None
 
     try:
         app = Celery(
             app_display_name,
-            broker=(f"{broker_type}://" f"{account}:{password}@{address}//"),
+            broker=f"{broker_type}://{account}:{password}@{address}/{app_name}",
             backend=f"{backend}://",
         )
     except Exception as e:
-        logger.warning(
-            f"{failure_message} Could not create app. Maybe rabbitmq server @{address} is not running ? {e}")
+        logger.warning(f"{failure_message} Could not create app. Maybe rabbitmq server @{address} is not running? {e}")
         return None
 
     for key, value in conf_data.items():
         try:
             setattr(app.conf, key, value)
         except Exception as e:
-            logger.warning(
-                f"{failure_message} Could assign extra attribute {key} to celery app. {e}")
+            logger.warning(f"{failure_message} Could not assign extra attribute {key} to celery app. {e}")
             return None
 
+    APPLICATIONS_STORE[app_name] = app
+
+    return app
+
+
+def create_worker_for_app(app_name):
+    from celery.bin.worker import worker as celery_worker
+
+    def start_worker(app):
+        worker = celery_worker(app=app)
+        options = {
+            "loglevel": "INFO",
+            "traceback": True,
+        }
+        worker.run(**options)
+
+    # run the worker for the app registered under app_name
+    start_worker(APPLICATIONS_STORE[app_name])
diff --git a/src/pypelines/loggs.py b/src/pypelines/loggs.py
index bfe3fadb937cd7f23ca643fce42ba3f79a5dabcc..b93c187f97bfd6858353af7554e7ab573b214ba1 100644
--- a/src/pypelines/loggs.py
+++ b/src/pypelines/loggs.py
@@ -129,12 +129,11 @@ class DynamicColoredFormatter(coloredlogs.ColoredFormatter):
         """
         pattern = r"%\((?P<part_name>\w+)\)-?(?P<length>\d+)?[sd]?"
         result = re.findall(pattern, fmt)
-        padding_dict = {
-            name: int(padding) if padding else 0 for name, padding in result}
+        padding_dict = {name: int(padding) if padding else 0 for name, padding in result}
         return padding_dict
 
-    def format(self, record):
+    def format(self, record: logging.LogRecord):
         """_summary_
 
         Args:
@@ -159,13 +158,11 @@ class DynamicColoredFormatter(coloredlogs.ColoredFormatter):
             missing_length = 0 if missing_length < 0 else missing_length
             if part_name in self.dynamic_levels.keys():
                 dyn_keys = self.dynamic_levels[part_name]
-                dynamic_style = {k: v for k, v in style.items(
-                ) if k in dyn_keys or dyn_keys == "all"}
-                part = coloredlogs.ansi_wrap(
-                    coloredlogs.coerce_string(part), **dynamic_style)
+                dynamic_style = {k: v for k, v in style.items() if k in dyn_keys or dyn_keys == "all"}
+                part = coloredlogs.ansi_wrap(coloredlogs.coerce_string(part), **dynamic_style)
             part = part + (" " * missing_length)
             setattr(copy, part_name, part)
-        record = copy
+        record = copy  # type: ignore
         s = self.formatMessage(record)
 
         if record.exc_info:
@@ -304,8 +301,7 @@ class LogContext:
         for handler in self.root_logger.handlers:
             for filter in handler.filters:
                 if getattr(filter, "context_msg", "") == self.context_msg:
-                    self.root_logger.debug(
-                        f"Filter already added to handler {handler}")
+                    self.root_logger.debug(f"Filter already added to handler {handler}")
                     found = True
                     break
 
@@ -317,8 +313,7 @@
                 context_filter = ContextFilter(self.context_msg)
                 handler.addFilter(context_filter)
                 self.context_filters[handler] = context_filter
-                self.root_logger.debug(
-                    f"Added filter {context_filter} to handler {handler}")
+                self.root_logger.debug(f"Added filter {context_filter} to handler {handler}")
                 break
 
     def __exit__(self, exc_type, exc_val, exc_tb):
@@ -334,8 +329,7 @@
             if filer_to_remove is None:
                 continue
             else:
-                self.root_logger.debug(
-                    f"Removing filter {filer_to_remove} from handler {handler} in this context")
+                self.root_logger.debug(f"Removing filter {filer_to_remove} from handler {handler} in this context")
                 handler.removeFilter(filer_to_remove)
 
 
@@ -412,14 +406,11 @@ def addLoggingLevel(levelName, levelNum, methodName=None, if_exists="raise"):
         if if_exists == "keep":
             return
     if hasattr(logging, levelName):
-        raise AttributeError(
-            "{} already defined in logging module".format(levelName))
+        raise AttributeError("{} already defined in logging module".format(levelName))
     if hasattr(logging, methodName):
-        raise AttributeError(
-            "{} already defined in logging module".format(methodName))
+        raise AttributeError("{} already defined in logging module".format(methodName))
     if hasattr(logging.getLoggerClass(), methodName):
-        raise AttributeError(
-            "{} already defined in logger class".format(methodName))
+        raise AttributeError("{} already defined in logger class".format(methodName))
 
     # This method was inspired by the answers to Stack Overflow post
     # http://stackoverflow.com/q/2183233/2988730, especially
diff --git a/src/pypelines/tasks.py b/src/pypelines/tasks.py
index 90c5d3e9a84f54d87f60f520f2d9cc72fcef9a1f..c9fe587ebe122ea1273f83139ad53c6239da7a57 100644
--- a/src/pypelines/tasks.py
+++ b/src/pypelines/tasks.py
@@ -32,17 +32,3 @@ class BaseTaskBackend:
 
     def create_task_manager(self, step) -> "BaseStepTaskManager":
         return self.task_manager_class(step, self)
-
-    # wrapped_step = getattr(step, "queue", None)
-    # if wrapped_step is None:
-    #     # do not register
-    #     pass
-    #     # registration code here
-
-    # def wrap_step(self, step):
-
-    #     @wraps(step.generate)
-    #     def wrapper(*args, **kwargs):
-    #         return step.generate(*args, **kwargs)
-
-    #     return wrapper
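
Usage notes (examples only, not part of the patch):

The new `create_celery_app` returns None (with a logged warning) whenever the
toml settings, dynaconf, celery, or the broker are unavailable, and it caches
successful apps in `APPLICATIONS_STORE`. A minimal sketch, assuming a
hypothetical conf path:

```python
from pypelines.celery_tasks import APPLICATIONS_STORE, create_celery_app

# Hypothetical conf path: if no celery toml config files are found there,
# this logs a warning and returns None, so callers must handle the None case.
app = create_celery_app("/etc/pypelines/conf", app_name="pypelines")

if app is not None:
    # A second call with the same app_name returns the cached app
    # instead of building a new one.
    assert create_celery_app("/etc/pypelines/conf", app_name="pypelines") is app
    assert APPLICATIONS_STORE["pypelines"] is app
```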
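
On the worker side, `CeleryRunner` resolves the step to run from the task's
dotted name via the new `pipeline_name`/`pipe_name`/`step_name` properties.
Passing a non-empty `task_infos_dict` skips the ONE REST read, so the sketch
below runs without an alyx connection; all field values are hypothetical:

```python
from pypelines.celery_tasks import CeleryTaskRecord

record = CeleryTaskRecord(
    "hypothetical-uuid",
    task_infos_dict={
        "id": "hypothetical-uuid",
        "name": "my_pipeline.my_pipe.my_step",  # "<pipeline>.<pipe>.<step>"
        "status": "Waiting",
        "arguments": None,
        "executable": "pypelines",
    },
)

# CeleryRunner looks the step up as:
#   application.pipelines[record.pipeline_name].pipes[record.pipe_name].steps[record.step_name]
assert (record.pipeline_name, record.pipe_name, record.step_name) == ("my_pipeline", "my_pipe", "my_step")
assert record.arguments == {}  # a null "arguments" field is normalized to {}

# export() shapes the alyx partial_update payload: "id" on top, the rest under "data".
assert record.export()["id"] == "hypothetical-uuid"
```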
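
`LogTask` still resolves its default level with `getattr(logging, level.upper())`,
so the custom LOAD level must exist on the `logging` module before a task log is
opened. A sketch using `addLoggingLevel` from loggs.py; the numeric value 25 and
the generated `.load()` method name are assumptions, not taken from this diff:

```python
import logging

from pypelines.loggs import addLoggingLevel

# if_exists="keep" makes this a no-op if pypelines already registered the
# level at import time; 25 sits between INFO (20) and WARNING (30).
addLoggingLevel("LOAD", 25, if_exists="keep")

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("demo").load("message at the LOAD level")  # type: ignore[attr-defined]
```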