diff --git a/src/pypelines/__init__.py b/src/pypelines/__init__.py
index c488c86fa35291fea2c19cf2eaafec35cc08c773..bec4fca0bd04549887208bd10b4502e93fe84525 100644
--- a/src/pypelines/__init__.py
+++ b/src/pypelines/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.0.71"
+__version__ = "0.0.72"
from . import loggs
from .pipes import *
diff --git a/src/pypelines/pipes.py b/src/pypelines/pipes.py
index 74ed279c5bf4fa9ac5ae73f2e6ebacc2d0de4ed6..86f8731a21cae0c4f913dee21604dc53ca0191b4 100644
--- a/src/pypelines/pipes.py
+++ b/src/pypelines/pipes.py
@@ -75,20 +75,20 @@ class BasePipe(BasePipeType, metaclass=ABCMeta):
requires_is_step_attr = True
for class_name, class_object in inspect.getmembers(self, predicate=inspect.isclass):
- print(class_name, class_object)
if class_name == "Steps":
- print("FOUND")
steps_members_scanner = inspect.getmembers(class_object(), predicate=inspect.ismethod)
requires_is_step_attr = False
break
# this loop populates self.steps dictionnary from the instanciated (bound) step methods.
for step_name, step in steps_members_scanner:
- print("step:", step_name)
if not requires_is_step_attr or getattr(step, "is_step", False):
step_name = to_snake_case(step_name)
_steps[step_name] = step
+ for step_name, step in inspect.getmembers(self, predicate=inspect.isclass):
+ if
+
if len(_steps) < 1:
raise ValueError(
f"You should register at least one step class with @stepmethod in {self.pipe_name} class. {_steps=}"
diff --git a/src/pypelines/steps.py b/src/pypelines/steps.py
index 5d6f0b69700033473cd8ae416a7cd40842e6e15c..1093e8f40b5d55b8f19ad98f142c5cd5db64ecf9 100644
--- a/src/pypelines/steps.py
+++ b/src/pypelines/steps.py
@@ -8,7 +8,7 @@ from pandas import DataFrame
from dataclasses import dataclass
from types import MethodType
-from typing import Callable, Type, Iterable, Protocol, List, TYPE_CHECKING, Any
+from typing import Callable, Type, Iterable, Protocol, List, TYPE_CHECKING, Any, Optional
if TYPE_CHECKING:
from .pipelines import Pipeline
@@ -78,7 +78,9 @@ class BaseStep:
pipe: "BasePipe"
pipeline: "Pipeline"
- def __init__(self, pipeline: "Pipeline", pipe: "BasePipe", worker: MethodType, step_name: str = ""):
+ def __init__(
+ self, pipeline: "Pipeline", pipe: "BasePipe", worker: Optional[MethodType] = None, step_name: str = ""
+ ):
"""Initialize a BaseStep object.
Args:
@@ -102,22 +104,37 @@ class BaseStep:
self.pipeline = pipeline
# save an instanciated access to the pipe parent
self.pipe = pipe
+
+ self.step_name = to_snake_case(self.get_attribute_or_default("step_name", step_name))
+
+ if not self.step_name:
+ raise ValueError(f"Step name in {self.pipe.pipe_name} cannot be blank nor None")
+
# save an instanciated access to the step function (undecorated)
- self.worker = worker
+ if not hasattr(self, "worker"):
+ if worker is None:
+ raise AttributeError(
+ f"For the step : {self.pipe.pipe_name}.{self.step_name}, a worker method must "
+ "be defined if created from a class"
+ )
+ needs_attachment = True
+ self.worker = worker
+ else:
+ needs_attachment = False
# we attach the values of the worker elements to BaseStep
# as they are get only (no setter) on worker (bound method)
- self.do_dispatch = getattr(self.worker, "do_dispatch", False)
- self.version = getattr(self.worker, "version", 0)
- self.requires = getattr(self.worker, "requires", [])
- self.step_name = to_snake_case(getattr(self.worker, "step_name", step_name))
- if not self.step_name:
- raise ValueError("Step name cannot be blank nor None")
+ getattr(self, "do_dispatch", getattr(self.worker, "do_dispatch", False))
- self.callbacks = getattr(self.worker, "callbacks", [])
+ self.do_dispatch = self.get_attribute_or_default("do_dispatch", False)
+ self.version = self.get_attribute_or_default("version", 0)
+ self.requires = self.get_attribute_or_default("requires", [])
- self.worker = MethodType(worker.__func__, self)
+ self.callbacks = self.get_attribute_or_default("callbacks", [])
+
+ if needs_attachment:
+ self.worker = MethodType(worker.__func__, self)
# self.make_wrapped_functions()
@@ -128,6 +145,10 @@ class BaseStep:
self.task = self.pipeline.runner_backend.create_task_manager(self)
+ def get_attribute_or_default(self, attribute_name: str, default: Any) -> Any:
+ # TODO : fix here , when calling get_attribute_or_default before worker s set, cannot work
+ return getattr(self, attribute_name, getattr(self.worker, attribute_name, default))
+
@property
def requirement_stack(self) -> Callable:
"""Return a partial function that calls the get_requirement_stack method of the pipeline