From 85d5112dc38ca0cd5603108b3ccea6198c1df5a7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Laurent?= <francois.laurent@posteo.net>
Date: Thu, 28 Dec 2023 16:19:21 +0100
Subject: [PATCH] preprocessing steps extracted from trainers module

---
 Project.toml                          | 1 +
 src/maggotuba/features/preprocess.py  | 5 +++--
 src/maggotuba/models/predict_model.py | 4 +++-
 src/maggotuba/models/train_model.py   | 6 +++++-
 src/maggotuba/models/trainers.py      | 3 +--
 5 files changed, 13 insertions(+), 6 deletions(-)

diff --git a/Project.toml b/Project.toml
index 5f5beef..8ec9cff 100644
--- a/Project.toml
+++ b/Project.toml
@@ -2,3 +2,4 @@
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
 PlanarLarvae = "c2615984-ef14-4d40-b148-916c85b43307"
+PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
diff --git a/src/maggotuba/features/preprocess.py b/src/maggotuba/features/preprocess.py
index 0b98c4d..52ca2cf 100644
--- a/src/maggotuba/features/preprocess.py
+++ b/src/maggotuba/features/preprocess.py
@@ -1,4 +1,5 @@
 import numpy as np
+from taggingbackends.features.skeleton import interpolate
 
 
 class Preprocessor:
@@ -91,8 +92,8 @@ class Preprocessor:
             ret = ret[:,:,::-1,:]
         return ret
 
-    def __callable__(self, *args):
-        return self.proprocess(*args)
+    def __call__(self, *args):
+        return self.preprocess(*args)
 
 
 # Julia functions
diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py
index 086e4b8..c4cb3ac 100644
--- a/src/maggotuba/models/predict_model.py
+++ b/src/maggotuba/models/predict_model.py
@@ -4,7 +4,6 @@ from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, Ma
 import numpy as np
 import logging
 
-
 def predict_model(backend, **kwargs):
     """
     This function generates predicted labels for all the input data.
@@ -18,6 +17,9 @@ def predict_model(backend, **kwargs):
     files, following the same directory structure as in `data/interim` or
     `data/raw`.
     """
+    if kwargs.pop('debug', False):
+        logging.root.setLevel(logging.DEBUG)
+
     # we pick files in `data/interim` if any, otherwise in `data/raw`
     input_files = backend.list_interim_files(group_by_directories=True)
     if not input_files:
diff --git a/src/maggotuba/models/train_model.py b/src/maggotuba/models/train_model.py
index c966538..ff3bbca 100644
--- a/src/maggotuba/models/train_model.py
+++ b/src/maggotuba/models/train_model.py
@@ -2,9 +2,13 @@ from taggingbackends.data.labels import Labels
 from taggingbackends.data.dataset import LarvaDataset
 from maggotuba.models.trainers import make_trainer, new_generator, enforce_reproducibility
 import glob
+import logging
 
 def train_model(backend, layers=1, pretrained_model_instance="default",
                 subsets=(1, 0, 0), rng_seed=None, iterations=1000, **kwargs):
+    if kwargs.pop('debug', False):
+        logging.root.setLevel(logging.DEBUG)
+
     # list training data files;
     # we actually expect a single larva_dataset file that make_dataset generated
     # or moved into data/interim/{instance}/
@@ -48,7 +52,7 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
     model.clf_config['post_filters'] = ['ABC->AAC']
 
     # save the model
-    print(f"saving model \"{backend.model_instance}\"")
+    logging.debug(f"saving model \"{backend.model_instance}\"")
 
     model.save()
 
diff --git a/src/maggotuba/models/trainers.py b/src/maggotuba/models/trainers.py
index 40f733b..95459c3 100644
--- a/src/maggotuba/models/trainers.py
+++ b/src/maggotuba/models/trainers.py
@@ -4,7 +4,6 @@ import torch.nn as nn
 from behavior_model.models.neural_nets import device
 from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot, MaggotBag
 from maggotuba.features.preprocess import Preprocessor
-from taggingbackends.features.skeleton import interpolate
 from taggingbackends.explorer import BackendExplorer, check_permissions
 import logging
 import json
@@ -12,7 +11,7 @@ import re
 import os.path
 
 """
-This model borrows the pre-trained MaggotUBA encoder, substitute a dense layer
+This model borrows the pre-trained MaggotUBA encoder, substitutes a dense layer
 for the decoder, and (re-)trains the entire model.
 
 Attribute `config` refers to MaggotUBA autoencoder.
-- 
GitLab