diff --git a/src/one/__init__.py b/src/one/__init__.py
index fe8951b7b397daa70ec7738af4df8fc7e23fd661..62bffa1d3ad0ba5cae25570e3103f652ef73fba5 100644
--- a/src/one/__init__.py
+++ b/src/one/__init__.py
@@ -1,5 +1,5 @@
"""The Open Neurophysiology Environment (ONE) API"""
-__version__ = "2.1.13"
+__version__ = "2.1.14"
from . import api
from .api import ONE
diff --git a/src/one/api.py b/src/one/api.py
index 8899e3b641d5aea3d5634eddaeaa11988ebe813f..e77d0d9022b425f33070d9f7c5765a4b525f8e13 100644
--- a/src/one/api.py
+++ b/src/one/api.py
@@ -37,6 +37,8 @@ from .registration import RegistrationClient
from one.converters import ConversionMixin
import one.util as util
+from .files import FileTransferManager
+
# _logger = logging.getLogger(__name__)
"""int: The number of download threads"""
@@ -1007,6 +1009,10 @@ class One(ConversionMixin):
make_parquet_db(cache_dir, **kwargs)
return One(cache_dir, mode="local")
+ @property
+ def files_manager(self):
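+        """Expose the :class:`one.files.FileTransferManager` class, to be instantiated with a sessions table."""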
+ return FileTransferManager
+
@lru_cache(maxsize=1)
def ONE(*, mode="auto", data_access_mode="remote", wildcards=True, **kwargs):
@@ -1048,9 +1054,13 @@ def ONE(*, mode="auto", data_access_mode="remote", wildcards=True, **kwargs):
"""
_logger = logging.getLogger("ONE")
- if any(x in kwargs for x in ("base_url", "username", "password")) or not kwargs.get("cache_dir", False):
+ if (
+ any(x in kwargs for x in ("base_url", "username", "password"))
+ or not kwargs.get("cache_dir", False)
+ or mode == "local"
+ ):
return OneAlyx(mode=mode, data_access_mode=data_access_mode, wildcards=wildcards, **kwargs)
 
# If cache dir was provided and corresponds to one configured with an Alyx client, use OneAlyx
try:
one.params.check_cache_conflict(kwargs.get("cache_dir"))
@@ -2619,4 +2629,4 @@ class Session(pd.Series):
if "alias" not in series.index:
series["alias"] = series.pipeline.alias(separator=separator, zfill=zfill, date_format=date_format)
- return series
\ No newline at end of file
+ return series
diff --git a/src/one/files.py b/src/one/files.py
new file mode 100644
index 0000000000000000000000000000000000000000..e42dba09dc654ab750971f31f49b6bc542930e35
--- /dev/null
+++ b/src/one/files.py
@@ -0,0 +1,479 @@
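+"""File transfer helpers between local and remote session folders.
+
+This module provides :class:`FileTransferManager`, which compares the files of
+a set of sessions between a source and a destination, records a per-file
+decision (transfer, overwrite, ignore or conflict) and performs the copies
+with a progress display.
+"""
+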
+# built-in imports
+import shutil
+from pathlib import Path
+from concurrent.futures import ThreadPoolExecutor, as_completed
+
+# third-party imports
+from pint import Quantity
+from rich import print
+from rich.text import Text
+from rich.console import Console, Group
+from rich.progress import Progress, BarColumn, TextColumn
+from rich.panel import Panel
+from rich.padding import Padding
+from rich.live import Live
+from rich.prompt import Prompt
+
+from pandas import DataFrame, Series
+
+
+class FileTransferManager:
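+    """Check, resolve and execute file transfers for a set of sessions.
+
+    A minimal sketch of the intended round-trip, assuming ``sessions`` is a
+    DataFrame with ``local_path``, ``remote_path``, ``rel_path`` and
+    ``u_alias`` columns:
+
+    >>> manager = FileTransferManager(sessions)
+    >>> checked = manager.fetch()    # compare remote files against local ones
+    >>> checked = checked.resolve()  # interactively settle any conflict
+    >>> report = checked.pull()     # copy the files and report each outcome
+    """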
+
+ results: DataFrame
+ direction: str
+ sessions: DataFrame
+
+ def __init__(self, sessions, results=None, direction=None):
+
+ if isinstance(sessions, Series):
+ sessions = sessions.to_frame().transpose()
+
+ self.sessions = sessions
+ self.results = results # type: ignore
+ self.direction = direction # type: ignore
+
+ @staticmethod
+ def from_transfer(transfer):
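+        """Build a new manager sharing the sessions, results and direction of ``transfer``."""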
+ return FileTransferManager(transfer.sessions, transfer.results, transfer.direction)
+
+ def _assert_file_checked(self, function_name):
+ if self.results is None or self.direction is None:
+ raise ValueError(
+                f"Cannot run {function_name} before a result has been "
+                "obtained through a fetch or a push_request command"
+ )
+
+ def resolve(self):
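+        """Interactively settle every file whose decision is ``conflict``.
+
+        Returns a new :class:`FileTransferManager` carrying the updated
+        decisions; the current instance is left unchanged.
+        """
+        self._assert_file_checked("resolve")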
+
+ results = self.results.copy()
+ conflicts = results[results["decision"] == "conflict"]
+ accepted_decisions = ["transfer", "overwrite", "ignore", "conflict"]
+
+        def conflict_panel(conflict=None):
+
+ if conflict is None:
+                return Panel("", title="❗ Handling Conflict:", border_style="dark_blue")
+
+ line = Text.assemble(
+ ("📄"),
+ (f"{conflict.source_filepath}\n", "steel_blue1 reverse"),
+ ("⏬\n"),
+ (f"{conflict.destination_filepath}\n", "steel_blue1 reverse"),
+                ("Information 🚩: ", "dark_blue underline"),
+                (f"{conflict.warnings}\n\n", "dark_blue"),
+                (f"Type in your decision (one of {accepted_decisions})."),
+ )
+
+            return Panel(line, title="❗ Handling Conflict:", border_style="dark_blue")
+
+ if len(conflicts):
+ console = Console()
+            with Live(conflict_panel(), console=console, refresh_per_second=5) as live:
+                for index, conflicted_file in conflicts.iterrows():
+                    live.update(conflict_panel(conflicted_file))
+
+                    decision = Prompt.ask(choices=accepted_decisions)
+                    if decision not in accepted_decisions:
+                        raise ValueError(f"Must be one of {accepted_decisions}, got {decision}")
+                    results.loc[index, "decision"] = decision
+
+ conflicts = results[results["decision"] == "conflict"]
+ if len(conflicts):
+ print(
+ Text(
+                        "🚫 There still seem to be unresolved conflicts. "
+                        "Please run resolve again to fix the remaining ones 🚫",
+ style="bright_red bold",
+ )
+ )
+ else:
+                print(Text("All conflicts resolved! 🎉", style="spring_green3 bold"))
+
+ else:
+ print(Text("No conflicts to resolve. ✅", style="spring_green3 bold"))
+
+ return FileTransferManager(self.sessions, results=results, direction=self.direction)
+
+ def status(self, show_message=True, return_messages=False):
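+        """Print a pre-transfer report and return whether the transfer can proceed.
+
+        When ``return_messages`` is True, the report's renderables are returned
+        instead of the boolean status.
+        """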
+ self._assert_file_checked("status")
+
+ action = self.direction
+
+ is_status_ok = True
+ messages = []
+
+        n_ignored_files = len(self.results[self.results["decision"] == "ignore"])
+        n_transferred_files = len(self.results[self.results["decision"] == "transfer"])
+        n_overwritten_files = len(self.results[self.results["decision"] == "overwrite"])
+        n_conflicts = len(self.results[self.results["decision"] == "conflict"])
+
+ messages.append(Text("❗ Conflicting files:", style="dark_blue underline bold"))
+ if n_conflicts:
+ is_status_ok = False
+ messages.append(
+ Text(
+                    f"\t🚫 {n_conflicts} conflicting files were found.\n"
+                    "\tPlease sort out the conflicts with the `.resolve()` command.\n",
+ style="bold bright_red",
+ )
+ )
+ else:
+            messages.append(Text("\tNo conflicts were found! 🎉\n", style="bold spring_green3"))
+
+ metrics = self.transfer_metrics()
+ messages.append(Text("📈 Transfer metrics:", style="dark_blue underline bold"))
+
+        messages.append(Text(f"\t• {n_transferred_files} files will be newly transferred.", style="spring_green3"))
+ messages.append(
+ Text(
+                f"\t• {n_overwritten_files} files will overwrite the equivalent file at the destination.",
+ style="dark_orange",
+ )
+ )
+ messages.append(
+ Text(
+                f"\t• {n_ignored_files} files will be ignored (not transferred from source to destination).\n",
+ style="dark_orange",
+ )
+ )
+
+ transfer_destinations = metrics.groupby("destination")
+
+ messages.append(
+ Text(
+ f"💽 {action.capitalize()}ing to {len(transfer_destinations)} locations:",
+ style="dark_blue underline bold",
+ )
+ )
+
+ for destination, metric in transfer_destinations:
+
+ destination = str(destination)
+ sources = ", ".join(metric.source.astype(str)) # type: ignore
+ transfer_space = metric.transfer_space.sum()
+ free_space = metric.free_space.iloc[0]
+
+ enough_space = free_space > transfer_space
+ color = "spring_green3" if enough_space else "bright_red"
+            summary = "✅ Enough free space ✅" if enough_space else "❌ Not enough space! ❌"
+
+ transfer_space_text = f"{transfer_space.to_compact():.2f~P} "
+ free_space_text = f"{free_space.to_compact():.2f~P}"
+
+ min_length = 11
+
+ sources = sources.rjust(min_length)
+ fill_size_source = "." * (len(sources) - (len(transfer_space_text)))
+ fill_description_source = "." * (len(sources) - len("Source "))
+
+ line = Text.assemble(
+ (f"• {summary}\n", color),
+ (" "),
+ ("Source", f"{color} underline"),
+ (" "),
+ (fill_description_source, "grey62"),
+ (" 🆚 ", color),
+ ("Destination\n", f"{color} underline"),
+ (f" {transfer_space_text}", f"{color} bold"),
+ (fill_size_source, "grey62"),
+ (" ⏩ ", color),
+ (f"{free_space_text}\n ", f"{color} bold"),
+ (sources, f"{color} reverse"),
+ (" ⏩ ", color),
+ (f"{destination}\n", f"{color} reverse"),
+ )
+
+ messages.append(Padding(Panel(line, border_style=color), (0, 0, 0, 6)))
+
+ if not enough_space:
+ is_status_ok = False
+
+ messages.append(Text("🌠 Conclusion:", style="dark_blue underline bold"))
+
+ if is_status_ok:
+ messages.append(Text(f"\t• ✅ Able to {action} the data.", style="spring_green3"))
+ messages.append(
+ Text(
+                    f"\t• 📊 In total, {metrics.transfer_space.sum().to_compact():.2f~P} of data will be {action}ed.",
+ style="spring_green3",
+ )
+ )
+ else:
+ messages.append(
+ Text(
+                    f"\t• ❌ Cannot {action} the data. Please sort out the issues mentioned above.", style="bright_red"
+ )
+ )
+
+ if show_message:
+ group = Group(*messages)
+ panel = Panel(
+ group,
+ title=f"📟 Pre-{action.capitalize()}ing Status Report",
+ width=100,
+ border_style="light_steel_blue3",
+ )
+ print(panel)
+
+ if return_messages:
+ return messages
+
+ return is_status_ok
+
+ def transfer_metrics(self):
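+        """Summarise, for each source/destination volume pair, the free space,
+        the space to transfer and the number of sessions and files involved."""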
+        self._assert_file_checked("transfer_metrics")
+
+ transfers_infos = []
+ for (source, destination), transfers in self.results.groupby(["source_volume", "destination_volume"]):
+ free_space = shutil.disk_usage(destination).free * Quantity("bytes")
+ transfer_space = transfers["source_filesize"].sum()
+ session_nb = len(transfers.session.unique())
+ files_nb = len(transfers)
+
+ transfers_infos.append(
+ dict(
+ source=source,
+ destination=destination,
+ free_space=free_space,
+ transfer_space=transfer_space,
+ session_nb=session_nb,
+ files_nb=files_nb,
+ )
+ )
+
+ return DataFrame(transfers_infos)
+
+ def _source_repositories(self):
+ return self._repositories("source")
+
+ def _destination_repositories(self):
+ return self._repositories("destination")
+
+ def _repositories(self, localisation="source"):
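+        """Return the unique session root directories on the ``source`` or ``destination`` side."""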
+        self._assert_file_checked("_repositories")
+ if localisation not in ["source", "destination"]:
+ raise ValueError("_repositories localisation must be 'source' or 'destination'")
+
+        def replace_element(root_path, rel_path):
+            return str(root_path).replace(str(rel_path), "")
+
+        if self.direction == "push":
+            # pushing goes from the local copy to the remote one
+            repo_key = "local_path" if localisation == "source" else "remote_path"
+        else:
+            # pulling goes from the remote copy to the local one
+            repo_key = "remote_path" if localisation == "source" else "local_path"
+
+ root_paths = self.sessions.apply(lambda row: replace_element(row[repo_key], row["rel_path"]), axis=1)
+ return root_paths.unique().tolist()
+
+ def push(self):
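+        """Execute a transfer prepared with ``push_request``."""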
+ self._assert_file_checked("push")
+ if self.direction != "push":
+            raise ValueError("Cannot push files without first doing a push_request check")
+ return self.transfer()
+
+ def pull(self):
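+        """Execute a transfer prepared with ``fetch``."""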
+ self._assert_file_checked("pull")
+ if self.direction != "pull":
+            raise ValueError("Cannot pull files without first doing a fetch check")
+ return self.transfer()
+
+ def transfer(self):
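+        """Copy every file marked ``transfer`` or ``overwrite``, after a final status check."""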
+ status = self.status(show_message=False, return_messages=False)
+ if not status:
+ raise ValueError(f"Cannot {self.direction} the data. Please check issues with `.status()`")
+
+ transfers = self.results[(self.results["decision"] == "transfer") | (self.results["decision"] == "overwrite")]
+ transfer_results = self.copy_files_with_progress(
+ transfers.source_filepath, transfers.destination_filepath, transfers.decision
+ )
+        # TODO: report transfer errors here, using the transfer_results.success column
+        print(Text(f"🎉 Finished {self.direction}ing successfully! 🎉", style="spring_green3"))
+ return transfer_results
+
+ @staticmethod
+ def copy_files_with_progress(src_paths, dst_paths, decisions, max_workers=4):
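+        """Copy files in parallel while displaying a progress bar.
+
+        Returns a DataFrame with one row per file, holding a ``success`` flag
+        and a human-readable ``message``.
+        """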
+
+        def copy_file(src, dst, decision, progress, task_id):
+
+            src, dst = Path(src), Path(dst)
+
+            message = f"Failed to copy {src} to {dst}."
+            success = False
+
+            try:
+                if not src.exists():
+                    return src, dst, success, f"{message} Source doesn't exist"
+
+                if dst.exists() and decision != "overwrite":
+                    return (
+                        src,
+                        dst,
+                        success,
+                        (
+                            f"{message} "
+                            f"Decision was set to {decision} but the destination already exists. "
+                            "Decision should have been set to overwrite"
+                        ),
+                    )
+
+                dst.parent.mkdir(exist_ok=True, parents=True)
+
+                try:
+                    shutil.copy2(src, dst)
+                    message = f"Copied {src} to {dst}"
+                    success = True
+                except Exception as e:
+                    message = f"{message} Unexpected error: {e}"
+
+                return src, dst, success, message
+            finally:
+                # always advance the progress bar, even on early returns
+                progress.advance(task_id)
+
+ total_files = len(src_paths)
+
+ with Progress(
+ TextColumn("[progress.description]{task.description}"),
+ BarColumn(),
+ TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
+ TextColumn("({task.completed}/{task.total})"),
+ ) as progress:
+ task_id = progress.add_task("Copying files...", total=total_files)
+
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
+ copy_results = []
+ futures = [
+ executor.submit(copy_file, src, dst, dec, progress, task_id)
+ for src, dst, dec in zip(src_paths, dst_paths, decisions)
+ ]
+ for future in as_completed(futures):
+ src, dst, success, message = future.result()
+ copy_results.append(
+ dict(source_filepath=src, destination_filepath=dst, success=success, message=message)
+ )
+ return DataFrame(copy_results)
+
+ def fetch(self, policies=None, show_status=True):
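+        """Check remote files against their local counterparts (pull direction)."""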
+        results = self.check_files(source="remote_path", destination="local_path", policies=policies)
+        new_file_manager = FileTransferManager(self.sessions, results, direction="pull")
+        if show_status:
+            new_file_manager.status()
+        return new_file_manager
+
+ def push_request(self, policies=None, show_status=True):
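+        """Check local files against their remote counterparts (push direction)."""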
+        results = self.check_files(source="local_path", destination="remote_path", policies=policies)
+        new_file_manager = FileTransferManager(self.sessions, results, direction="push")
+        if show_status:
+            new_file_manager.status()
+        return new_file_manager
+
+ def check_files(self, source="remote_path", destination="local_path", policies=None):
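+        """Compare each source file with its destination and record a per-file decision.
+
+        ``policies`` maps each comparison outcome (``no_file_exists``,
+        ``destination_older``, ``destination_younger``, ``close_dates``) to a
+        decision among ``transfer``, ``overwrite``, ``ignore`` and ``conflict``;
+        ``close_dates_threshold`` is the date-difference threshold in seconds.
+        Values passed here override the defaults below.
+        """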
+
+ console = Console()
+
+ default_policies = {
+ "close_dates": "conflict",
+ "destination_older": "overwrite",
+ "no_file_exists": "transfer",
+ "destination_younger": "ignore",
+ "close_dates_threshold": 10,
+ }
+
+ if policies is not None:
+ default_policies.update(policies)
+
+ policies = default_policies
+
+        def get_volume(full_path, rel_path):
+            # keep only the drive of the session's root directory
+            # (note: Path.drive is empty on POSIX, so this effectively assumes Windows paths)
+            full_path, rel_path = str(Path(full_path)), str(Path(rel_path))
+            root_path = Path(full_path.replace(rel_path, ""))
+            return Path(root_path.drive)
+
+ results = []
+
+ with Progress(console=console) as progress:
+
+ task = progress.add_task(f"Checking files for {len(self.sessions)} sessions", total=len(self.sessions))
+
+ for _, session in self.sessions.iterrows():
+
+ source_path = Path(str(session[source]))
+ destination_path = Path(str(session[destination]))
+
+ source_volume = get_volume(source_path, session["rel_path"])
+ destination_volume = get_volume(destination_path, session["rel_path"])
+
+                # Path.walk requires Python >= 3.12; use os.walk on older versions
+                for root, dirs, files in source_path.walk():
+
+ for file in files:
+
+ source_filepath = root / file
+ source_filesize = source_filepath.stat().st_size * Quantity("bytes")
+ relative_filepath = source_filepath.relative_to(source_path)
+ destination_filepath = destination_path / relative_filepath
+
+ destination_exists = destination_filepath.exists()
+
+                        source_stat = source_filepath.stat()
+
+                        # st_birthtime is platform dependent (macOS/BSD, and Windows
+                        # from Python 3.12); on Linux this raises AttributeError
+                        source_creation_date = source_stat.st_birthtime
+                        source_modification_date = source_stat.st_mtime
+
+                        # a file is dated by the most recent of creation and modification
+                        source_date = max(source_creation_date, source_modification_date)
+
+ if destination_exists:
+
+ destination_stat = destination_filepath.stat()
+
+ destination_creation_date = destination_stat.st_birthtime
+ destination_modification_date = destination_stat.st_mtime
+
+ destination_date = max(destination_creation_date, destination_modification_date)
+
+                            if (difference := abs(source_date - destination_date)) < policies["close_dates_threshold"]:
+                                decision = policies["close_dates"]
+                                warnings = (
+                                    "The destination and source files' last modification dates "
+                                    f"are very close ({difference:.1f} s apart). "
+                                    "Double-checking the file contents before transfer might be safer."
+                                )
+
+                            elif destination_date > source_date:
+                                decision = policies["destination_younger"]
+                                warnings = (
+                                    "The destination file is more recent than the source file. "
+                                    "Double-checking the file contents before transferring is mandatory!"
+                                )
+
+                            else:
+                                decision = policies["destination_older"]
+                                warnings = (
+                                    "The destination file is older than the source. "
+                                    "Transferring will most likely be okay."
+                                )
+
+ else:
+ decision = policies["no_file_exists"]
+                            warnings = "The file does not exist at the destination; transferring is safe."
+ destination_date = None
+
+ record = dict(
+ source_filepath=source_filepath,
+ destination_filepath=destination_filepath,
+ source_filesize=source_filesize,
+ relative_filepath=relative_filepath,
+ destination_exists=destination_exists,
+ source_date=source_date,
+ destination_date=destination_date,
+ source_volume=source_volume,
+ destination_volume=destination_volume,
+ decision=decision,
+ warnings=warnings,
+ session=session.u_alias,
+ )
+
+ results.append(record)
+ progress.advance(task)
+
+ return DataFrame(results)