From 466aa6903d7e12fe92261a197c42fa2201edb0c7 Mon Sep 17 00:00:00 2001
From: Timothe Jost <timothe.jost@wanadoo.fr>
Date: Fri, 31 May 2024 18:20:46 +0200
Subject: [PATCH] PypelineLoggerProtocol typehint on getLogger

---
 src/pypelines/__init__.py  |  2 +-
 src/pypelines/loggs.py     | 18 ++++++++++++++++++
 src/pypelines/pipelines.py |  7 +------
 3 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/src/pypelines/__init__.py b/src/pypelines/__init__.py
index 5a803c9..3c6d080 100644
--- a/src/pypelines/__init__.py
+++ b/src/pypelines/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.0.61"
+__version__ = "0.0.62"
 
 from . import loggs
 from .pipes import *
diff --git a/src/pypelines/loggs.py b/src/pypelines/loggs.py
index c7ad61e..5f36e65 100644
--- a/src/pypelines/loggs.py
+++ b/src/pypelines/loggs.py
@@ -13,10 +13,28 @@ from coloredlogs import (
 )
 from pathlib import Path
 
+from typing import Protocol, Callable, cast
+
 NAMELENGTH = 33  # global variable for formatting the length of the padding dedicated to name part in a logging record
 LEVELLENGTH = 9  # global variable for formatting the length of the padding dedicated to levelname part in a record
 
 
+class PypelineLoggerProtocol(Protocol):
+    def save(self, msg, *args, **kwargs) -> None: ...
+    def load(self, msg, *args, **kwargs) -> None: ...
+    def note(self, msg, *args, **kwargs) -> None: ...
+    def start(self, msg, *args, **kwargs) -> None: ...
+    def end(self, msg, *args, **kwargs) -> None: ...
+    def header(self, msg, *args, **kwargs) -> None: ...
+
+
+class PypelineLogger(logging.Logger, PypelineLoggerProtocol):
+    pass
+
+
+getLogger = cast(Callable[[str], PypelineLogger], logging.getLogger)
+
+
 def enable_logging(
     filename: str | None = None,
     terminal_level: str = "NOTE",
diff --git a/src/pypelines/pipelines.py b/src/pypelines/pipelines.py
index 06e4533..a140325 100644
--- a/src/pypelines/pipelines.py
+++ b/src/pypelines/pipelines.py
@@ -10,12 +10,7 @@ if TYPE_CHECKING:
     from .graphs import PipelineGraph
 
 
-class PipelineType(Protocol):
-
-    def __getattr__(self, name: str) -> "BasePipe": ...
-
-
-class Pipeline(PipelineType):
+class Pipeline:
     pipes: Dict[str, "BasePipe"]
     runner_backend_class = BaseTaskBackend
 
-- 
GitLab