# modules.py (larvatagger.jl; authored by François LAURENT)
import functools
import json
from pathlib import Path

import torch
from torch import nn

from behavior_model.models.neural_nets import Encoder
from taggingbackends.explorer import check_permissions
class MaggotModule(nn.Module):
    """Base class that couples a torch model with a JSON config file.

    Both the config and the model are loaded lazily, on first access.
    Subclasses implement `load_model`.
    """
def __init__(self, path, cfgfile=None, ptfile=None):
super().__init__()
self.path = path if isinstance(path, Path) else Path(path)
if cfgfile is None:
cfgfile = self.path.name
self.path = self.path.parent
elif not self.path.is_dir():
raise ValueError("\"str(self.path)\" is not a directory")
self.cfgfile = cfgfile
        # store ptfile relative to the model directory whenever it lives there
        if ptfile is not None and Path(ptfile).parent == path:
ptfile = Path(ptfile).name
self.ptfile = ptfile
self._config = None
self._model = None
@classmethod
def load_config(cls, path):
with open(path, "r") as f:
return json.load(f)
@property
def cfgfilepath(self):
return self.path / self.cfgfile
@property
def ptfilepath(self):
return self.path / self.ptfile
@property
def config(self):
if self._config is None:
self._config = self.load_config(self.cfgfilepath)
return self._config
@config.setter
def config(self, cfg):
self._config = cfg
@classmethod
def load_model(cls, config, path):
raise NotImplementedError
@property
def model(self):
if self._model is None:
self._model = self.load_model(self.config, self.ptfilepath)
return self._model
@model.setter
def model(self, model):
self._model = model
def forward(self, x):
return self.model(x)
def save_config(self, cfgfile=None):
if cfgfile is None: cfgfile = self.cfgfile
path = self.path / cfgfile
with open(path, "w") as f:
json.dump(self.config, f, indent=2)
check_permissions(path)
return path
def save_model(self, ptfile=None):
if ptfile is None: ptfile = self.ptfile
path = self.path / ptfile
torch.save(self.model.state_dict(), path)
check_permissions(path)
return path
def save(self):
self.save_model()
self.save_config()
def parameters(self, recurse=True):
return self.model.parameters(recurse)
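
# Sketch of the subclassing contract (illustrative only; `n_in`/`n_out` are
# hypothetical config keys): a concrete module implements `load_model`, and
# MaggotModule materializes the config and the weights lazily around it.
class _DemoModule(MaggotModule):
    @classmethod
    def load_model(cls, config, path):
        model = nn.Linear(config["n_in"], config["n_out"])
        model.load_state_dict(torch.load(path))
        return model
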
class MaggotEncoder(MaggotModule):
def __init__(self, path,
cfgfile=None,
#cfgfile="autoencoder_config.json",
ptfile="retrained_encoder.pt"):
super().__init__(path, cfgfile, ptfile)
@classmethod
    def load_config(cls, path):
config = super().load_config(path)
config["config"] = str(path)
return config
@classmethod
def load_model(cls, config, path):
encoder = Encoder(**config)
encoder.load_state_dict(torch.load(path))
return encoder
    # cache the mask for the most recently requested size (note: lru_cache on
    # a bound method also keeps a reference to self in the cached entry)
    @functools.lru_cache(maxsize=1)
def mask(self, size):
selfsize = self.config["len_traj"]
assert selfsize <= size
        if size == selfsize:
            return  # None means "no cropping"; see mask_forward
mask = torch.zeros(size, dtype=torch.bool)
crop = (size - selfsize) // 2
if size % 2 == selfsize % 2:
mask[crop:-crop] = True
# the actual midpoint is assumed to be:
# * the segment's midpoint in odd-sized segments
# * the first point in the second half of even-sized segments
elif selfsize % 2 == 0:
# size is odd;
# let us assume size - selfsize == 3; crop == 1 and
# we want to crop 1 elem from the start and 2 elems from the end
mask[crop:-crop-1] = True
else:
# selfsize is odd and size is even;
# again, let us assume size - selfsize == 3,
# we want to crop 2 elems from the start and 1 elem from the end
mask[crop+1:-crop] = True
        mask, = mask.nonzero(as_tuple=True)  # index_select takes integer indices, not a boolean mask
return mask
def mask_forward(self, batch, dim=3):
mask = self.mask(batch.shape[dim])
if mask is not None:
batch = torch.index_select(batch, dim, mask)
return self.forward(batch)
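
# Worked example for `mask` (illustrative values): with len_traj == 20,
# * size == 24 (same parity): crop == 2, keep indices 2..21 (mask[2:-2]);
# * size == 25 (encoder even, input odd): crop == 2, keep 2..21 (mask[2:-3]),
#   i.e. the extra point is dropped at the end;
# with len_traj == 21 and size == 24: crop == 1, keep 2..22 (mask[2:-1]),
# i.e. the extra point is dropped at the start, preserving the midpoint
# convention stated above.
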
class PretrainedMaggotEncoder(MaggotEncoder):
def __init__(self, path,
cfgfile=None,
#cfgfile="autoencoder_config.json",
ptfile="best_validated_encoder.pt"):
super().__init__(path, cfgfile, ptfile)
def save(self, ptfile="retrained_encoder.pt"):
self.ptfile = ptfile
return super().save()
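
# Sketch (hypothetical paths): a pre-trained encoder is read from
# "best_validated_encoder.pt"; after fine-tuning, `save()` rewrites the
# weights as "retrained_encoder.pt", the default file name MaggotEncoder
# looks for when the model is reloaded.
def _demo_retrain_roundtrip():
    pretrained = PretrainedMaggotEncoder("models/pretrained/autoencoder_config.json")
    # ... fine-tune pretrained.model here ...
    pretrained.save()
    return MaggotEncoder(pretrained.cfgfilepath)  # picks up the retrained weights
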
class MaggotEncoders(nn.Module):
def __init__(self, paths, cls=MaggotEncoder, **kwargs):
super().__init__()
self._pattern = None
        if isinstance(paths, (str, Path)):
            # a single str/Path argument is interpreted as a glob pattern
            self._pattern = paths
            import glob
            paths = glob.glob(str(paths))
try:
ptfiles = kwargs.pop("ptfile")
except KeyError:
self.encoders = [cls(path, **kwargs) for path in paths]
else:
if isinstance(ptfiles, list):
self.encoders = [cls(path, ptfile=ptfile, **kwargs)
for path, ptfile in zip(paths, ptfiles)]
else:
kwargs["ptfile"] = ptfiles
self.encoders = [cls(path, **kwargs) for path in paths]
def __iter__(self):
return iter(self.encoders)
def forward(self, x):
return torch.cat([encoder.mask_forward(x) for encoder in self.encoders], dim=1)
@property
def cfgfilepaths(self):
return [enc.cfgfilepath for enc in self.encoders]
@property
def ptfilepaths(self):
return [enc.ptfilepath for enc in self.encoders]
def save_config(self, cfgfile=None):
for encoder in self.encoders:
encoder.save_config(cfgfile)
def save_model(self, ptfile=None):
for encoder in self.encoders:
encoder.save_model(ptfile)
def save(self):
for encoder in self.encoders:
encoder.save()
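
# Sketch (hypothetical layout and input shape): encoders trained at different
# time scales are collected with one glob pattern; `forward` concatenates
# their latent vectors along the feature dimension.
def _demo_encoder_bank():
    encoders = MaggotEncoders("models/scale_*/autoencoder_config.json",
                              cls=PretrainedMaggotEncoder)
    batch = torch.randn(16, 2, 5, 25)  # assumed (batch, channels, points, time)
    return encoders(batch)  # (16, sum of the encoders' dim_latent)
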
class DeepLinear(nn.Module):
def __init__(self, n_input, n_output, n_hidden=[], batch_norm=False,
weight_init="xavier"):
super().__init__()
self.batch_norm = batch_norm
self.weight_init = weight_init
        layers = []
        for n in list(n_hidden):
            if n is None:
                n = n_input  # None stands for "same width as the input"
            layers.append(nn.Linear(n_input, n))
            layers.append(nn.ReLU())
            if batch_norm:
                layers.append(nn.BatchNorm1d(n))
            n_input = n
        layers.append(nn.Linear(n_input, n_output))
self.layers = nn.Sequential(*layers)
def init_layers(self):
for layer in self.layers:
if isinstance(layer, nn.Linear):
if self.weight_init == "xavier":
nn.init.xavier_uniform_(layer.weight)
elif self.weight_init == "kaiming":
nn.init.kaiming_normal_(layer.weight)
else:
raise NotImplementedError(self.weight_init)
nn.init.zeros_(layer.bias)
def forward(self, x):
return self.layers(x)
def load(self, path):
self.load_state_dict(torch.load(path))
def save(self, path):
torch.save(self.state_dict(), path)
check_permissions(path)
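
# Sketch: DeepLinear(10, 6, n_hidden=[None, 8]) stacks
# Linear(10, 10) -> ReLU -> Linear(10, 8) -> ReLU -> Linear(8, 6);
# with batch_norm=True, a BatchNorm1d follows each ReLU.
def _demo_deep_linear():
    head = DeepLinear(n_input=10, n_output=6, n_hidden=[None, 8])
    head.init_layers()  # xavier-uniform weights, zero biases
    return head(torch.randn(4, 10))  # -> shape (4, 6)
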
class MaggotClassifier(MaggotModule):
def __init__(self, path, behavior_labels=[], n_latent_features=None,
n_layers=1, cfgfile=None, ptfile="trained_classifier.pt"):
super().__init__(path, cfgfile, ptfile)
        try:  # load the config file if it already exists...
            self.config
        except FileNotFoundError:  # ...otherwise build a default config
            assert bool(behavior_labels)
            assert bool(n_latent_features)
self.config = dict(
clf_path=str(self.ptfilepath),
dim_latent=n_latent_features,
behavior_labels=behavior_labels,
clf_depth=0 if n_layers is None else n_layers - 1,
batch_norm=False,
weight_init="xavier",
loss="cross-entropy",
optimizer="adam")
@classmethod
def load_model(cls, config, path):
model = DeepLinear(
n_input=config["dim_latent"],
n_output=len(config["behavior_labels"]),
n_hidden=config["clf_depth"]*[None],
batch_norm=config["batch_norm"],
weight_init=config["weight_init"],
)
        try:
            model.load(path)
        except Exception:
            # weights missing or incompatible: start from a fresh initialization
            model.init_layers()
return model
@property
def behavior_labels(self):
return self.config["behavior_labels"]
@behavior_labels.setter
def behavior_labels(self, labels):
self.config["behavior_labels"] = labels
@property
def n_latent_features(self):
return self.config["dim_latent"]
@property
def n_behaviors(self):
return len(self.behavior_labels)
@property
def n_layers(self):
return self.config["clf_depth"] + 1
class SupervisedMaggot(nn.Module):
def __init__(self, cfgfilepath, behaviors=[], n_layers=1):
super().__init__()
        if behaviors:  # pre-trained encoder only; attach a fresh classifier
self.encoder = PretrainedMaggotEncoder(cfgfilepath)
self.clf = MaggotClassifier(self.encoder.path / "clf_config.json",
behaviors, self.encoder.config["dim_latent"], n_layers)
        else:  # reload a model that has already been (re)trained
self.clf = MaggotClassifier(cfgfilepath)
self.encoder = MaggotEncoder(self.clf.config["autoencoder_config"],
ptfile=self.clf.config["enc_path"])
def forward(self, x):
return self.clf(self.encoder(x))
def save(self):
enc, clf = self.encoder, self.clf
enc.save()
clf.config["autoencoder_config"] = str(enc.cfgfilepath)
clf.config["enc_path"] = str(enc.ptfilepath)
clf.save()
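
# Sketch of the two entry points (hypothetical paths and labels): pass
# `behaviors` to start from a pre-trained encoder with a fresh classifier;
# omit it to reload a model whose classifier config already records the
# encoder location (see `save` above).
def _demo_supervised():
    new = SupervisedMaggot("models/pretrained/autoencoder_config.json",
                           behaviors=["run", "cast", "stop"])
    new.save()  # records autoencoder_config and enc_path in clf_config.json
    return SupervisedMaggot(new.clf.cfgfilepath)  # prediction-time reload
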
class MultiscaleSupervisedMaggot(nn.Module):
def __init__(self, cfgfilepath, behaviors=[], n_layers=1):
super().__init__()
        if behaviors:  # pre-trained encoders only; attach a fresh classifier
self.encoders = MaggotEncoders(cfgfilepath, cls=PretrainedMaggotEncoder)
path = next(iter(self.encoders)).path.parent
n_latent_features = sum(enc.config["dim_latent"] for enc in self.encoders)
self.clf = MaggotClassifier(path / "clf_config.json",
behaviors, n_latent_features, n_layers)
        else:  # reload a model that has already been (re)trained
self.clf = MaggotClassifier(cfgfilepath)
self.encoders = MaggotEncoders(self.clf.config["autoencoder_config"],
ptfile=self.clf.config["enc_path"])
def forward(self, x):
return self.clf(self.encoders(x))
def save(self):
enc, clf = self.encoders, self.clf
enc.save()
clf.config["autoencoder_config"] = [str(p) for p in enc.cfgfilepaths]
clf.config["enc_path"] = [str(p) for p in enc.ptfilepaths]
clf.save()
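
# Sketch (hypothetical pattern and labels): the multiscale variant mirrors
# SupervisedMaggot but takes a glob pattern over encoder configs; the
# classifier then consumes the concatenated latent features of all encoders.
def _demo_multiscale():
    tagger = MultiscaleSupervisedMaggot("models/scale_*/autoencoder_config.json",
                                        behaviors=["run", "cast", "stop"])
    return tagger.clf.n_latent_features  # sum of the encoders' dim_latent
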