diff --git a/src/pypelines/__init__.py b/src/pypelines/__init__.py index 3dd76d2f2f90f157e98a91273ba5866490f477aa..8adb3b921eddbd734839ce675008bc0d40be3217 100644 --- a/src/pypelines/__init__.py +++ b/src/pypelines/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.53" +__version__ = "0.0.54" from . import loggs from .pipes import * diff --git a/src/pypelines/celery_tasks.py b/src/pypelines/celery_tasks.py index 8f7c50d0a421fd9088dee33a2e704d68a70b1ec3..3f7b41e038cdc385f2ff01b9aa3671e365308905 100644 --- a/src/pypelines/celery_tasks.py +++ b/src/pypelines/celery_tasks.py @@ -236,6 +236,25 @@ class CeleryTaskRecord(dict): task_dict["id"], task_infos_dict=task_dict, response_handle=response_handle, session=session ) + @staticmethod + def create_from_model( + app: "Celery", task_model: type, task_name: str, pipeline_name: str, session: object, extra=None, **kwargs + ): + + new_task = task_model( + name=task_name, session=session, arguments=kwargs, status="Waiting", executable=pipeline_name + ) + new_task.save() + + task_dict = new_task.__dict__.copy() + task_dict.pop("_state", None) + + response_handle = app.send_task(name=task_name, kwargs={"task_id": task_dict["id"], "extra": extra}) + + return CeleryTaskRecord( + task_dict["id"], task_infos_dict=task_dict, response_handle=response_handle, session=session + ) + class CeleryTaskBackend(BaseTaskBackend): app: "Celery" @@ -502,10 +521,17 @@ def create_celery_app(conf_path, app_name="pypelines", v_host=None) -> "Celery | logger.warning(f"Returned cached task_data. Next refresh will happen in at least {delta} seconds") return app_task_data["task_data"] if app_task_data is not None else None - def launch_named_task_remotely(self, session_id, task_name, extra=None, kwargs={}): - task_record = CeleryTaskRecord.create_from_task_name( - self, task_name, app_name, session_id, extra=extra, **kwargs - ) + def launch_named_task_remotely(self, session_id, task_name, task_model=None, extra=None, kwargs={}): + + if task_model is None: + task_record = CeleryTaskRecord.create_from_task_name( + self, task_name, app_name, session_id, extra=extra, **kwargs + ) + else: + task_record = CeleryTaskRecord.create_from_model( + self, task_model, task_name, app_name, session_id, extra=extra, **kwargs + ) + return task_record def is_hand_shaken(self):