diff --git a/src/pypelines/__init__.py b/src/pypelines/__init__.py
index d343c0706aa6e83f05396bdbfb387583f29b71f2..409b92b40c0d36fd53e54b5c2fe90b1358807fec 100644
--- a/src/pypelines/__init__.py
+++ b/src/pypelines/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.0.81"
+__version__ = "0.0.82"
 
 from . import loggs
 from .pipes import *
diff --git a/src/pypelines/disk.py b/src/pypelines/disk.py
index 310483b919b54a4db0c5eafbbc1822bfb0554ca9..644df4a6fef1af52320caa860a262bb6eccfe7ea 100644
--- a/src/pypelines/disk.py
+++ b/src/pypelines/disk.py
@@ -2,7 +2,7 @@ import os, re
 
 from .sessions import Session
 import pickle
-from typing import Callable, Type, Iterable, Literal, Protocol, TYPE_CHECKING
+from typing import Callable, Type, Iterable, Literal, Protocol, TYPE_CHECKING, List, cast
 from abc import ABCMeta, abstractmethod
 from functools import wraps
@@ -118,7 +118,7 @@ class BaseDiskObject(metaclass=ABCMeta):
 
     @staticmethod
     def multisession_unpacker(sessions, datas):
-        """Unpacks data from multiple sessions.
+        """Unpacks data from multiple sessions, in order to store them to "disk".
 
         Args:
             sessions (list): A list of session identifiers.
@@ -131,8 +131,14 @@ class BaseDiskObject(metaclass=ABCMeta):
 
     def disk_step_instance(self) -> "BaseStep | None":
         """Returns an instance of the step that corresponds to the file on disk."""
+        from .steps import BaseStep
+
         if self.disk_step is not None:
-            return self.step.pipe.steps[self.disk_step]
+            if isinstance(self.disk_step, BaseStep):
+                return self.disk_step
+            elif isinstance(self.disk_step, str):
+                return self.step.pipe.steps[self.disk_step]
+            raise TypeError(f"Type must be BaseStep or str, found {type(self.disk_step)}")
         return None
 
     def is_matching(self):
@@ -200,6 +206,11 @@ class BaseDiskObject(metaclass=ABCMeta):
 
 
 class NullDiskObject(BaseDiskObject):
+    """Class representing a Null Disk Object, which simulates a disk object whose methods always indicate
+    version deprecation and a False check-disk status, but does not perform any actual disk operation to check that.
+    It will always trigger a run when calling generate on the step that uses it.
+    """
+
     def version_deprecated(self) -> bool:
         """Indicates that the version of the function is deprecated.
 
@@ -276,23 +287,24 @@ class CachedDiskObject(BaseDiskObject):
 
         Returns:
             dict: A dictionary containing the cached storage for the current step, session, and extra data.
         """
-        if self.step.pipe not in self.storage:
-            self.storage[self.step.pipe] = {}
-
-        if self.session.name not in self.storage[self.step.pipe].keys():
-            self.storage[self.step.pipe][self.session.name] = {}
-
-        if str(self.extra) not in self.storage[self.step.pipe][self.session.name].keys():
-            stored_dict = self.save(None)
-        else:
-            stored_dict = self.storage[self.step.pipe][self.session.name][str(self.extra)]
-
-        return stored_dict
+        step_dedicated_storage = self.storage.setdefault(self.step.complete_name, {}).setdefault(self.session.name, {})
+        dedicated_key = f"extra#{self.extra}"
+        if dedicated_key not in step_dedicated_storage.keys():
+            return self.wrap_up_data(None)
+        return step_dedicated_storage[dedicated_key]
 
     def load(self):
         """Load the content from the cached storage."""
         return self.get_cached_storage()["content"]
 
+    def wrap_up_data(self, data):
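+        """Wraps the given data into the dict structure used by the cached storage.
+
+        Args:
+            data: The content to wrap.
+
+        Returns:
+            dict: A dictionary containing the step version, the content, and the step name.
+        """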
+        stored_dict = {
+            "version": self.step.version,
+            "content": data,
+            "step": self.step.step_name,
+        }
+        return stored_dict
+
     def save(self, data):
         """Save the data into the storage dictionary.
@@ -302,12 +314,10 @@ class CachedDiskObject(BaseDiskObject):
         Returns:
             dict: A dictionary containing the version, content, and step name of the saved data.
         """
-        stored_dict = {
-            "version": self.step.version,
-            "content": data,
-            "step": self.step.step_name,
-        }
-        self.storage[self.step.pipe][self.session.name][str(self.extra)] = stored_dict
+        step_dedicated_storage = self.storage.setdefault(self.step.complete_name, {}).setdefault(self.session.name, {})
+        dedicated_key = f"extra#{self.extra}"
+        stored_dict = self.wrap_up_data(data)
+        step_dedicated_storage[dedicated_key] = stored_dict
         return stored_dict
 
     def check_disk(self):
@@ -342,7 +352,7 @@ class CachedDiskObject(BaseDiskObject):
 
         # we compare levels with the currently called step
        # if disk step level < current called step level, we return True, else we return False.
-        if disk_step.get_level(selfish=True) < self.step.get_level(selfish=True):
+        if disk_step < self.step:
             return True
         return False
 
@@ -350,3 +360,147 @@
         """Clears the cache by removing all items stored in the cache."""
         for pipe in list(self.storage.keys()):
             self.storage.pop(pipe)
+
+
+class FlaggedDiskObject(BaseDiskObject):
+    """A disk object that doesn't serve to actually load things, but acts as an indicator that they are
+    available elsewhere, by setting a flag once the generate / save methods have been executed.
+    If the flag file exists, the load method will be allowed to trigger (returning None) and keep the flow of
+    the pipeline running, without executing the runner for the flagged step and the ones below.
+    """
+
+    collection = ["preprocessing_saves"]
+    file_prefix: str
+    extension = "flag"
+    flaggable_steps: str | List[str] = "##highest"
+    supports_version = False
+
+    def parse_extra(self):
+        return f".{self.extra}" if self.extra else ""
+
+    def get_file_name(self, step: "BaseStep"):
+        file_prefix = self.file_prefix if hasattr(self, "file_prefix") else "runned"
+        version_str = f".{self.step.version}" if self.supports_version else ""
+        return (
+            f"{file_prefix}.{self.step.pipe_name}."
+            f"{step.step_name}{self.parse_extra()}{version_str}"
+            f".{self.extension}"
+        )
+
+    def get_flag_path(self, step: "BaseStep"):
+        return os.path.join(self.session.path, os.path.sep.join(self.collection), self.get_file_name(step))
+
+    def get_flaggable_steps(self) -> "List[BaseStep]":
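+        """Resolves the `flaggable_steps` attribute into a list of BaseStep instances.
+
+        The special strings "##highest" and "##lowest" select a single step from the pipe's
+        ordered steps; a plain step name selects that single step; a list of step names is
+        resolved and sorted by descending level. The result is cached in `_flaggable_steps`.
+        """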
+ f"{step.step_name}{self.parse_extra()}{version_str}" + f".{self.extension}" + ) + + def get_flag_path(self, step: "BaseStep"): + return os.path.join(self.session.path, os.path.sep.join(self.collection), self.get_file_name(step)) + + def get_flaggable_steps(self) -> "List[BaseStep]": + + def internal_getter(): + if isinstance(self.flaggable_steps, str): + if self.flaggable_steps.startswith("##"): + which = self.flaggable_steps.lstrip("#") + if which not in ["highest", "lowest"]: + raise ValueError("Must be 'highest' or 'lowest'") + return [self.step.pipe.ordered_steps(first=cast(Literal["highest", "lowest"], which))[0]] + return [self.step.pipe.steps[self.flaggable_steps]] + return sorted( + [self.step.pipe.steps[step_name] for step_name in self.flaggable_steps], + key=lambda step: step.get_level(selfish=True), + reverse=True, + ) + + if not hasattr(self, "_flaggable_steps"): + self._flaggable_steps = internal_getter() + return self._flaggable_steps + + def save(self, data): + if self.step_supports_flagging(): + flagpath = self.get_flag_path(self.step) + with open(flagpath, "w"): + return + + def check_disk(self): + for flagged_step in self.get_flaggable_steps(): + if flagged_step >= self.step: + if os.path.isfile(self.get_flag_path(flagged_step)): + self.disk_step = flagged_step + self.disk_version = flagged_step.version + return True + return False + + def step_supports_flagging(self): + return self.step in self.get_flaggable_steps() + + def load(self): + return f"FLAG FOR : {self.get_file_name(self.step)}" + + def step_level_too_low(self) -> bool: + """Check if the level of the disk step is lower than the current step. + + Returns: + bool: True if the level of the disk step is lower than the current step, False otherwise. + """ + # we get the step instance that corresponds to the one on the disk + disk_step = self.disk_step_instance() + + # we compare levels with the currently called step + # if disk step level < current called step level, we return True, else we return False. + if disk_step < self.step: + return True + return False + + def version_deprecated(self): + """Doesn't support versionning yet. Returning always False indicating non deprecation.""" + return False + + +class CachedFlaggedDiskObject(CachedDiskObject, FlaggedDiskObject): + """ + Behaves like a CachedDiskObject, but also supports flagging. + - If cache is available, loads from cache (priority). + - If not, but a flag is found for a step >= current, loads flag (skips running). + - If neither, triggers computation. + """ + + def check_disk(self): + self.cache_found = False + # 1. Check cache first (priority) + exists = CachedDiskObject.check_disk(self) + self.cache_found = exists + + # 2. 
+        for flagged_step in self.get_flaggable_steps():
+            if flagged_step >= self.step:
+                if os.path.isfile(self.get_flag_path(flagged_step)):
+                    self.disk_step = flagged_step
+                    self.disk_version = flagged_step.version
+                    return True
+        return False
+
+    def step_supports_flagging(self):
+        return self.step in self.get_flaggable_steps()
+
+    def load(self):
+        return f"FLAG FOR : {self.get_file_name(self.step)}"
+
+    def step_level_too_low(self) -> bool:
+        """Check if the level of the disk step is lower than the current step.
+
+        Returns:
+            bool: True if the level of the disk step is lower than the current step, False otherwise.
+        """
+        # we get the step instance that corresponds to the one on the disk
+        disk_step = self.disk_step_instance()
+
+        # we compare levels with the currently called step
+        # if disk step level < current called step level, we return True, else we return False.
+        if disk_step < self.step:
+            return True
+        return False
+
+    def version_deprecated(self):
+        """Doesn't support versioning yet. Always returns False, indicating non-deprecation."""
+        return False
+
+
+class CachedFlaggedDiskObject(CachedDiskObject, FlaggedDiskObject):
+    """
+    Behaves like a CachedDiskObject, but also supports flagging.
+    - If cache is available, loads from cache (priority).
+    - If not, but a flag is found for a step >= current, loads flag (skips running).
+    - If neither, triggers computation.
+    """
+
+    def check_disk(self):
+        self.cache_found = False
+        # 1. Check cache first (priority)
+        exists = CachedDiskObject.check_disk(self)
+        self.cache_found = exists
+
+        # 2. If not in cache, check for flag
+        if not exists:
+            exists = FlaggedDiskObject.check_disk(self)
+
+        return exists
+
+    def load(self):
+        # If cache is available, load from cache
+        if self.cache_found:
+            return self.get_cached_storage()["content"]
+
+        # If flag is found, return flag info (or None, or raise, as desired)
+        return FlaggedDiskObject.load(self)
+
+    def save(self, data=None):
+        CachedDiskObject.save(self, data)
+        FlaggedDiskObject.save(self, data)
+
+    def version_deprecated(self):
+        # If loaded from cache, check version
+        if self.cache_found:
+            return CachedDiskObject.version_deprecated(self)
+        # If loaded from flag, treat as not deprecated (or implement logic as needed)
+        return FlaggedDiskObject.version_deprecated(self)
+
+    def step_level_too_low(self) -> bool:
+        # If loaded from cache, use cache logic
+        if self.cache_found:
+            return CachedDiskObject.step_level_too_low(self)
+        # If loaded from flag, use flag logic
+        return FlaggedDiskObject.step_level_too_low(self)
+
+    def get_found_disk_object_description(self) -> str:
+        if self.cache_found:
+            return f"Cache found with step name {self.disk_step}"
+        # if flag is still found
+        if self.loadable:
+            return f"Flag found with step name {self.disk_step.step_name}"
+        return f"Neither Cache nor Flag found for step {self.step.step_name}"
diff --git a/src/pypelines/pipes.py b/src/pypelines/pipes.py
index 9f922fc29101145d1fffacbbbae8ed67e5e9e5b5..1e396159ad74c9b4d6d7806c67b8dbe74817b040 100644
--- a/src/pypelines/pipes.py
+++ b/src/pypelines/pipes.py
@@ -287,3 +287,11 @@ class BasePipe(BasePipeType, metaclass=ABCMeta):
             return highest_step.load(session, extra)
 
         raise ValueError(f"Could not find a {self} object to load for the session {session.alias} with extra {extra}")
+
+    def __eq__(self, other_pipe: "BasePipe"):
+        if hash(self) == hash(other_pipe):
+            return True
+        return False
+
+    def __hash__(self):
+        return hash(f"{self.pipeline.pipeline_name}.{self.pipe_name}")
diff --git a/src/pypelines/steps.py b/src/pypelines/steps.py
index cb15087d63d1c2625180e1fe4a43861060bc608d..3899b1c8795945ec4d263b545c7f6fef06d6d2b1 100644
--- a/src/pypelines/steps.py
+++ b/src/pypelines/steps.py
@@ -775,6 +775,34 @@ class BaseStep:
 
     def logger(self) -> PypelineLogger:
         return getLogger(self.step_name[:NAMELENGTH])
 
+    def __eq__(self, other_step: "BaseStep"):
+        if self.complete_name == other_step.complete_name:
+            return True
+        return False
+
+    def __lt__(self, other_step: "BaseStep"):  # Less than (<)
+        if self.pipe != other_step.pipe:
+            raise ArithmeticError("Cannot compare two steps of different pipes with <")
+        return self.get_level(selfish=True) < other_step.get_level(selfish=True)
+
+    def __le__(self, other_step: "BaseStep"):  # Less than or equal (<=)
+        if self.pipe != other_step.pipe:
+            raise ArithmeticError("Cannot compare two steps of different pipes with <=")
+        return self.get_level(selfish=True) <= other_step.get_level(selfish=True)
+
+    def __gt__(self, other_step: "BaseStep"):  # Greater than (>)
+        if self.pipe != other_step.pipe:
+            raise ArithmeticError("Cannot compare two steps of different pipes with >")
+        return self.get_level(selfish=True) > other_step.get_level(selfish=True)
+
+    def __ge__(self, other_step: "BaseStep"):  # Greater than or equal (>=)
+        if self.pipe != other_step.pipe:
+            raise ArithmeticError("Cannot compare two steps of different pipes with >=")
+        return self.get_level(selfish=True) >= other_step.get_level(selfish=True)
+
+    def __hash__(self) -> int:
+        return hash(self.complete_name)
+
 
 @dataclass
 class StepLevel: