Skip to content
Snippets Groups Projects
Commit 64c16796 authored by François  LAURENT's avatar François LAURENT
Browse files

preprocessing steps extracted from trainers module

parent 99c8a8a6
Branches
No related tags found
1 merge request!11Set of commits to be tagged v0.18
...@@ -2,3 +2,4 @@ ...@@ -2,3 +2,4 @@
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
PlanarLarvae = "c2615984-ef14-4d40-b148-916c85b43307" PlanarLarvae = "c2615984-ef14-4d40-b148-916c85b43307"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
import numpy as np import numpy as np
from taggingbackends.features.skeleton import interpolate
class Preprocessor: class Preprocessor:
...@@ -91,8 +92,8 @@ class Preprocessor: ...@@ -91,8 +92,8 @@ class Preprocessor:
ret = ret[:,:,::-1,:] ret = ret[:,:,::-1,:]
return ret return ret
def __callable__(self, *args): def __call__(self, *args):
return self.proprocess(*args) return self.preprocess(*args)
# Julia functions # Julia functions
......
...@@ -4,7 +4,6 @@ from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, Ma ...@@ -4,7 +4,6 @@ from maggotuba.models.trainers import MaggotTrainer, MultiscaleMaggotTrainer, Ma
import numpy as np import numpy as np
import logging import logging
def predict_model(backend, **kwargs): def predict_model(backend, **kwargs):
""" """
This function generates predicted labels for all the input data. This function generates predicted labels for all the input data.
...@@ -18,6 +17,9 @@ def predict_model(backend, **kwargs): ...@@ -18,6 +17,9 @@ def predict_model(backend, **kwargs):
files, following the same directory structure as in `data/interim` or files, following the same directory structure as in `data/interim` or
`data/raw`. `data/raw`.
""" """
if kwargs.pop('debug', False):
logging.root.setLevel(logging.DEBUG)
# we pick files in `data/interim` if any, otherwise in `data/raw` # we pick files in `data/interim` if any, otherwise in `data/raw`
input_files = backend.list_interim_files(group_by_directories=True) input_files = backend.list_interim_files(group_by_directories=True)
if not input_files: if not input_files:
......
...@@ -2,9 +2,13 @@ from taggingbackends.data.labels import Labels ...@@ -2,9 +2,13 @@ from taggingbackends.data.labels import Labels
from taggingbackends.data.dataset import LarvaDataset from taggingbackends.data.dataset import LarvaDataset
from maggotuba.models.trainers import make_trainer, new_generator, enforce_reproducibility from maggotuba.models.trainers import make_trainer, new_generator, enforce_reproducibility
import glob import glob
import logging
def train_model(backend, layers=1, pretrained_model_instance="default", def train_model(backend, layers=1, pretrained_model_instance="default",
subsets=(1, 0, 0), rng_seed=None, iterations=1000, **kwargs): subsets=(1, 0, 0), rng_seed=None, iterations=1000, **kwargs):
if kwargs.pop('debug', False):
logging.root.setLevel(logging.DEBUG)
# list training data files; # list training data files;
# we actually expect a single larva_dataset file that make_dataset generated # we actually expect a single larva_dataset file that make_dataset generated
# or moved into data/interim/{instance}/ # or moved into data/interim/{instance}/
...@@ -48,7 +52,7 @@ def train_model(backend, layers=1, pretrained_model_instance="default", ...@@ -48,7 +52,7 @@ def train_model(backend, layers=1, pretrained_model_instance="default",
model.clf_config['post_filters'] = ['ABC->AAC'] model.clf_config['post_filters'] = ['ABC->AAC']
# save the model # save the model
print(f"saving model \"{backend.model_instance}\"") logging.debug(f"saving model \"{backend.model_instance}\"")
model.save() model.save()
......
...@@ -4,7 +4,6 @@ import torch.nn as nn ...@@ -4,7 +4,6 @@ import torch.nn as nn
from behavior_model.models.neural_nets import device from behavior_model.models.neural_nets import device
from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot, MaggotBag from maggotuba.models.modules import SupervisedMaggot, MultiscaleSupervisedMaggot, MaggotBag
from maggotuba.features.preprocess import Preprocessor from maggotuba.features.preprocess import Preprocessor
from taggingbackends.features.skeleton import interpolate
from taggingbackends.explorer import BackendExplorer, check_permissions from taggingbackends.explorer import BackendExplorer, check_permissions
import logging import logging
import json import json
...@@ -12,7 +11,7 @@ import re ...@@ -12,7 +11,7 @@ import re
import os.path import os.path
""" """
This model borrows the pre-trained MaggotUBA encoder, substitute a dense layer This model borrows the pre-trained MaggotUBA encoder, substitutes a dense layer
for the decoder, and (re-)trains the entire model. for the decoder, and (re-)trains the entire model.
Attribute `config` refers to MaggotUBA autoencoder. Attribute `config` refers to MaggotUBA autoencoder.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment