From cbce7365db29609f1727404543f6caad64c7f23f Mon Sep 17 00:00:00 2001
From: Timothe Jost <timothe.jost@wanadoo.fr>
Date: Tue, 13 May 2025 10:29:38 +0200
Subject: [PATCH] typing

---
 src/pypelines/__init__.py                |  4 ++-
 src/pypelines/extend_pandas/__init__.py  | 46 ++++++++++++++++++++++++
 src/pypelines/extend_pandas/__init__.pyi |  4 +++
 src/pypelines/extend_pandas/py.typed     |  0
 src/pypelines/extend_pandas/typing.py    |  8 +++++
 src/pypelines/pipes.py                   | 14 ++++----
 6 files changed, 67 insertions(+), 9 deletions(-)
 create mode 100644 src/pypelines/extend_pandas/__init__.py
 create mode 100644 src/pypelines/extend_pandas/__init__.pyi
 create mode 100644 src/pypelines/extend_pandas/py.typed
 create mode 100644 src/pypelines/extend_pandas/typing.py

diff --git a/src/pypelines/__init__.py b/src/pypelines/__init__.py
index b9cdce9..d343c07 100644
--- a/src/pypelines/__init__.py
+++ b/src/pypelines/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.0.80"
+__version__ = "0.0.81"
 
 from . import loggs
 from .pipes import *
@@ -7,6 +7,8 @@ from .steps import *
 from .disk import *
 from .sessions import *
 
+from .extend_pandas import extend_pandas
+
 # NOTE:
 # pypelines is enabling the logging system by default when importing it
 # (it comprises colored logging, session prefix-logging, and logging to a file located in downloads folder)
diff --git a/src/pypelines/extend_pandas/__init__.py b/src/pypelines/extend_pandas/__init__.py
new file mode 100644
index 0000000..2f64c71
--- /dev/null
+++ b/src/pypelines/extend_pandas/__init__.py
@@ -0,0 +1,74 @@
+import pandas as pd
+from ..pipelines import Pipeline
+
+from .typing import SessionPipelineAccessorProto
+
+# This is only for type checkers, has no runtime effect
+pd.DataFrame.pypeline: SessionPipelineAccessorProto
+
+
+def extend_pandas():
+    """Register the ``pypeline`` accessor on ``pandas.DataFrame``.
+
+    The first call registers the accessor, then sets the marker attribute
+    ``_pypelines_accessor_registered`` on ``pd.DataFrame``. Later calls find the marker and
+    return without registering the accessor again.
+    """
+    if hasattr(pd.DataFrame, "_pypelines_accessor_registered"):
+        return
+
+    @pd.api.extensions.register_dataframe_accessor("pypeline")
+    class SessionPipelineAccessor:
+        """Accessor reporting, per row of a DataFrame, whether a pipeline step output is loadable.
+
+        Call the accessor with a pipeline before using its other methods, e.g.
+        ``df.pypeline(pipeline).output_exists("pipe_name.step_name")``.
+        """
+
+        def __init__(self, pandas_obj: pd.DataFrame):
+            self._obj = pandas_obj
+
+        def __call__(self, pipeline: Pipeline):
+            """Store ``pipeline`` on the accessor and return the accessor, to allow chaining."""
+            self.pipeline = pipeline
+            return self
+
+        def output_exists(self, pipe_step_name: str):
+            """Return, for each row, the result of ``is_loadable()`` on the step's disk object.
+
+            Args:
+                pipe_step_name: Either ``"pipe_name.step_name"`` or ``"pipe_name"`` alone. With a
+                    pipe name alone, the first step of ``ordered_steps("highest")`` is used.
+
+            Returns:
+                The row-wise ``DataFrame.apply`` result, renamed to ``"pipe_name.step_name"``.
+
+            Raises:
+                ValueError: If ``pipe_step_name`` contains more than one dot.
+            """
+            names = pipe_step_name.split(".")
+            if len(names) > 2:
+                raise ValueError("pipe_step_name should be either a pipe_name.step_name or pipe_name")
+            pipe_name = names[0]
+            pipe = self.pipeline.pipes[pipe_name]
+            # a pipe name alone designates the first step of the "highest"-first ordering
+            step_name = names[1] if len(names) == 2 else pipe.ordered_steps("highest")[0].step_name
+            # resolve the step once, rather than once per row inside the lambda
+            step = pipe.steps[step_name]
+            return self._obj.apply(
+                lambda session: step.get_disk_object(session).is_loadable(),
+                axis=1,
+            ).rename(f"{pipe_name}.{step_name}")
+
+        # NOTE: the name is misspelled ("ouput"); it is kept unchanged for backward compatibility
+        def add_ouput(self, pipe_step_name: str):
+            """Return the frame plus a column named ``pipe_step_name`` holding ``output_exists``."""
+            return self._obj.assign(**{pipe_step_name: self.output_exists(pipe_step_name)})
+
+        def where_output(self, pipe_step_name: str, exists: bool):
+            """Return the rows whose ``pipe_step_name`` column, added by ``add_ouput``, equals ``exists``."""
+            with_output = self.add_ouput(pipe_step_name)
+            return with_output[with_output[pipe_step_name] == exists]
+
+    # mark the registration as done: the guard above turns any later call into a no-op
+    pd.DataFrame._pypelines_accessor_registered = True
diff --git a/src/pypelines/extend_pandas/__init__.pyi b/src/pypelines/extend_pandas/__init__.pyi
new file mode 100644
index 0000000..378bf19
--- /dev/null
+++ b/src/pypelines/extend_pandas/__init__.pyi
@@ -0,0 +1,7 @@
+import pandas as pd
+from .typing import SessionPipelineAccessorProto
+
+pd.DataFrame.pypeline: SessionPipelineAccessorProto  # type: ignore
+
+# Declared here because type checkers read this stub instead of __init__.py.
+def extend_pandas() -> None: ...
diff --git a/src/pypelines/extend_pandas/py.typed b/src/pypelines/extend_pandas/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/src/pypelines/extend_pandas/typing.py b/src/pypelines/extend_pandas/typing.py
new file mode 100644
index 0000000..b96cf3f
--- /dev/null
+++ b/src/pypelines/extend_pandas/typing.py
@@ -0,0 +1,12 @@
+from typing import Protocol
+import pandas as pd
+from ..pipelines import Pipeline
+
+
+class SessionPipelineAccessorProto(Protocol):
+    """Static type of the ``pypeline`` DataFrame accessor, for type checkers."""
+
+    def __call__(self, pipeline: Pipeline) -> "SessionPipelineAccessorProto": ...
+    def output_exists(self, pipe_step_name: str) -> pd.Series: ...
+    def add_ouput(self, pipe_step_name: str) -> pd.DataFrame: ...
+    def where_output(self, pipe_step_name: str, exists: bool) -> pd.DataFrame: ...
diff --git a/src/pypelines/pipes.py b/src/pypelines/pipes.py
index 6633b0f..9f922fc 100644
--- a/src/pypelines/pipes.py
+++ b/src/pypelines/pipes.py
@@ -245,6 +245,10 @@ class BasePipe(BasePipeType, metaclass=ABCMeta):
         # the dispatcher must be return a wrapped function
         return function
 
+    def ordered_steps(self, first: Literal["lowest", "highest"] = "lowest"):
+        """Return the steps sorted by ``get_level(selfish=True)``, ascending only if ``first`` is "lowest"."""
+        return sorted(self.steps.values(), key=lambda item: item.get_level(selfish=True), reverse=first != "lowest")
+
     def load(self, session, extra="", which: Literal["lowest", "highest"] = "highest"):
         """Load a step object for a session with optional extra data.
 
@@ -260,20 +264,14 @@ class BasePipe(BasePipeType, metaclass=ABCMeta):
         Raises:
             ValueError: If no matching step object is found for the session.
         """
-        if which == "lowest":
-            reverse = False
-        else:
-            reverse = True
 
-        ordered_steps = sorted(
-            list(self.steps.values()), key=lambda item: item.get_level(selfish=True), reverse=reverse
-        )
+        ordered_steps = self.ordered_steps(first=which)
 
         highest_step = None
 
         if isinstance(session, DataFrame):
             # if multisession, we assume we are trying to just load sessions
-            # that all have reached the same level of requirements. (otherwise, use generate)
+            # that all have reached the same level of requirements. (otherwise, use generate to make them match levels)
             # because of that, we use only the first session in the lot to search the highest loadable step
             search_on_session = session.iloc[0]
         else:
-- 
GitLab