From a6c8c424594b0ea03169212e1f205171997249ae Mon Sep 17 00:00:00 2001 From: Timothe Jost <timothe.jost@wanadoo.fr> Date: Thu, 4 Apr 2024 19:52:55 +0200 Subject: [PATCH] fixing transmission of programatically useable signatures (with typehints etc) --- src/pypelines/__init__.py | 2 +- src/pypelines/celery_tasks.py | 57 +++++++++++++++++++++++++++-------- 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/src/pypelines/__init__.py b/src/pypelines/__init__.py index 06f0afa..2f181b6 100644 --- a/src/pypelines/__init__.py +++ b/src/pypelines/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.0.47" +__version__ = "0.0.48" from . import loggs from .pipes import * diff --git a/src/pypelines/celery_tasks.py b/src/pypelines/celery_tasks.py index 8734f4f..4c2301b 100644 --- a/src/pypelines/celery_tasks.py +++ b/src/pypelines/celery_tasks.py @@ -336,17 +336,50 @@ def create_celery_app(conf_path, app_name="pypelines", v_host=None) -> "Celery | ] return str(signature.replace(parameters=params))[1:-1].replace(" *,", "") - def get_signature_as_dict(signature_string): - from re import compile as re_compile - from ast import literal_eval - - signature_pattern = re_compile(r" *(?P<key> *\w+) *= *(?P<value>.*?) *(?=(?:(?:, *\w+ *=)|(?:$)))") - patt = signature_pattern.findall(signature_string) - data = {} - for key, value in patt: - t_val = literal_eval(value) - data[key] = t_val - return data + def get_type_name(annotation): + from inspect import Parameter + from typing import get_args, get_origin + from types import UnionType + + if isinstance(annotation, str): + annotation = string_to_typehint(annotation, globals(), locals()) + + if isinstance(annotation, UnionType): + typ = get_args(annotation)[0] + elif hasattr(annotation, "__origin__"): # For types from 'typing' like List, Dict, etc. + typ = get_origin(annotation) + else: + typ = annotation + + if isinstance(typ, type): + if typ is Parameter.empty: + return "__unknown__" + else: + return typ.__name__ + return "__unknown__" + + def string_to_typehint(string_hint, globalns=None, localns=None): + from typing import ForwardRef, _eval_type + + try: + return _eval_type(ForwardRef(string_hint), globalns, localns) + except NameError: + return "__unknown__" + + def get_signature_as_dict(signature): + from inspect import Parameter + + parameters = signature.parameters + parsed_args = {} + for name, param in parameters.items(): + + parsed_args[name] = { + "typehint": get_type_name(param.annotation), + "default_value": param.default if param.default is not Parameter.empty else "__empty__", + "kind": param.kind.name, + } + + return parsed_args class Handshake(Task): name = f"{app_name}.handshake" @@ -373,7 +406,7 @@ def create_celery_app(conf_path, app_name="pypelines", v_host=None) -> "Celery | for step in pipe.steps.values(): if step.complete_name in app.tasks.keys(): str_sig = get_signature_as_string(step.generate.__signature__) - dict_sig = get_signature_as_dict(str_sig) + dict_sig = get_signature_as_dict(step.generate.__signature__) doc = step.generate.__doc__ task_data = { "signature": str_sig, -- GitLab