modules.py (commit 68fa827e, François Laurent: "implements larvatagger.jl#58 at this project level")

    import functools
    import glob
    import json
    from pathlib import Path
    import torch
    from torch import nn
    from behavior_model.models.neural_nets import Encoder
    from taggingbackends.explorer import check_permissions
    
    class MaggotModule(nn.Module):
        """Base wrapper that couples a torch module with its on-disk JSON
        config file and .pt state-dict file, both loaded lazily."""

        def __init__(self, path, cfgfile=None, ptfile=None):
            super().__init__()
            self.path = path if isinstance(path, Path) else Path(path)
            if cfgfile is None:
                # `path` points to the config file itself
                cfgfile = self.path.name
                self.path = self.path.parent
            elif not self.path.is_dir():
                raise ValueError(f"\"{self.path}\" is not a directory")
            self.cfgfile = cfgfile
            # a ptfile given as a full path is stored relative to self.path
            if ptfile is not None and Path(ptfile).parent == self.path:
                ptfile = Path(ptfile).name
            self.ptfile = ptfile
            self._config = None
            self._model = None
    
        @classmethod
        def load_config(cls, path):
            with open(path, "r") as f:
                return json.load(f)
    
        @property
        def cfgfilepath(self):
            return self.path / self.cfgfile
    
        @property
        def ptfilepath(self):
            return self.path / self.ptfile
    
        @property
        def config(self):
            if self._config is None:
                self._config = self.load_config(self.cfgfilepath)
            return self._config
    
        @config.setter
        def config(self, cfg):
            self._config = cfg
    
        @classmethod
        def load_model(cls, config, path):
            raise NotImplementedError
    
        @property
        def model(self):
            if self._model is None:
                self._model = self.load_model(self.config, self.ptfilepath)
            return self._model
    
        @model.setter
        def model(self, model):
            self._model = model
    
        def forward(self, x):
            return self.model(x)
    
        def save_config(self, cfgfile=None):
            if cfgfile is None: cfgfile = self.cfgfile
            path = self.path / cfgfile
            with open(path, "w") as f:
                json.dump(self.config, f, indent=2)
            check_permissions(path)
            return path
    
        def save_model(self, ptfile=None):
            if ptfile is None: ptfile = self.ptfile
            path = self.path / ptfile
            torch.save(self.model.state_dict(), path)
            check_permissions(path)
            return path
    
        def save(self):
            self.save_model()
            self.save_config()
    
        def parameters(self, recurse=True):
            return self.model.parameters(recurse)
    
    
    class MaggotEncoder(MaggotModule):
        def __init__(self, path,
                cfgfile=None, # typically "autoencoder_config.json"
                ptfile="retrained_encoder.pt"):
            super().__init__(path, cfgfile, ptfile)
    
        @classmethod
        def load_config(cls, path):
            config = super().load_config(path)
            config["config"] = str(path) # record where the config came from
            return config
    
        @classmethod
        def load_model(cls, config, path):
            encoder = Encoder(**config)
            encoder.load_state_dict(torch.load(path))
            return encoder
    
        @functools.lru_cache(maxsize=1)
        def mask(self, size):
            selfsize = self.config["len_traj"]
            assert selfsize <= size
            if size == selfsize:
                return
            mask = torch.zeros(size, dtype=torch.bool)
            crop = (size - selfsize) // 2
            if size % 2 == selfsize % 2:
                mask[crop:-crop] = True
                # the actual midpoint is assumed to be:
                # * the segment's midpoint in odd-sized segments
                # * the first point in the second half of even-sized segments
            elif selfsize % 2 == 0:
                # size is odd;
                # let us assume size - selfsize == 3; crop == 1 and
                # we want to crop 1 elem from the start and 2 elems from the end
                mask[crop:-crop-1] = True
            else:
                # selfsize is odd and size is even;
                # again, let us assume size - selfsize == 3,
                # we want to crop 2 elems from the start and 1 elem from the end
                mask[crop+1:size-crop] = True # size-crop rather than -crop, so that crop == 0 also works
            # convert the boolean mask into indices; index_select expects an integer index tensor, not a BoolTensor
            mask, = mask.nonzero(as_tuple=True)
            return mask
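
        # Worked example for the cropping above (sizes hypothetical): with
        # len_traj == 20 and size == 25, crop == 2 and the parities differ
        # with selfsize even, so mask[2:-3] keeps time points 2..21, cropping
        # 2 elements from the start and 3 from the end; nonzero() then yields
        # the index tensor [2, 3, ..., 21] that mask_forward feeds to
        # torch.index_select.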
    
        def mask_forward(self, batch, dim=3):
            mask = self.mask(batch.shape[dim])
            if mask is not None:
                batch = torch.index_select(batch, dim, mask)
            return self.forward(batch)
    
    class PretrainedMaggotEncoder(MaggotEncoder):
        def __init__(self, path,
                cfgfile=None, # typically "autoencoder_config.json"
                ptfile="best_validated_encoder.pt"):
            super().__init__(path, cfgfile, ptfile)
    
        def save(self, ptfile="retrained_encoder.pt"):
            self.ptfile = ptfile
            return super().save()
    
    class MaggotEncoders(nn.Module):
        """Bank of MaggotEncoder modules applied in parallel; their latent
        codes are concatenated (see forward)."""
        def __init__(self, paths, cls=MaggotEncoder, **kwargs):
            super().__init__()
            self._pattern = None
            if isinstance(paths, (str, Path)):
                # a single str or Path is interpreted as a glob pattern
                self._pattern = paths
                paths = glob.glob(str(paths))
            try:
                ptfiles = kwargs.pop("ptfile")
            except KeyError:
                self.encoders = [cls(path, **kwargs) for path in paths]
            else:
                if isinstance(ptfiles, list):
                    self.encoders = [cls(path, ptfile=ptfile, **kwargs)
                            for path, ptfile in zip(paths, ptfiles)]
                else:
                    kwargs["ptfile"] = ptfiles
                    self.encoders = [cls(path, **kwargs) for path in paths]
    
        def __iter__(self):
            return iter(self.encoders)
    
        def forward(self, x):
            return torch.cat([encoder.mask_forward(x) for encoder in self.encoders], dim=1)
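
        # Shape sketch (layout assumed from mask_forward's default dim=3,
        # e.g. a (batch, channel, point, time) tensor): with two encoders
        # whose configs declare dim_latent == 100 each, every encoder maps
        # the batch to an (N, 100) latent code, and the concatenation along
        # dim=1 hands the classifier a single (N, 200) feature tensor.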
    
        @property
        def cfgfilepaths(self):
            return [enc.cfgfilepath for enc in self.encoders]
    
        @property
        def ptfilepaths(self):
            return [enc.ptfilepath for enc in self.encoders]
    
        def save_config(self, cfgfile=None):
            for encoder in self.encoders:
                encoder.save_config(cfgfile)
    
        def save_model(self, ptfile=None):
            for encoder in self.encoders:
                encoder.save_model(ptfile)
    
        def save(self):
            for encoder in self.encoders:
                encoder.save()
    
    class DeepLinear(nn.Module):
        """Fully-connected feed-forward network with ReLU activations and
        optional batch normalization between hidden layers."""
        def __init__(self, n_input, n_output, n_hidden=[], batch_norm=False,
                weight_init="xavier"):
            super().__init__()
            self.batch_norm = batch_norm
            self.weight_init = weight_init
            layers = []
            for n_units in n_hidden:
                # a None entry stands for a hidden layer as wide as its input
                if n_units is None: n_units = n_input
                layers.append(nn.Linear(n_input, n_units))
                layers.append(nn.ReLU())
                if batch_norm:
                    layers.append(nn.BatchNorm1d(n_units))
                n_input = n_units
            layers.append(nn.Linear(n_input, n_output))
            self.layers = nn.Sequential(*layers)
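
        # Construction example (sizes hypothetical): DeepLinear(10, 6,
        # n_hidden=[None, 32]) builds
        #     Linear(10, 10) -> ReLU -> Linear(10, 32) -> ReLU -> Linear(32, 6)
        # since the None entry stands for a hidden layer as wide as its input.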
    
        def init_layers(self):
            for layer in self.layers:
                if isinstance(layer, nn.Linear):
                    if self.weight_init == "xavier":
                        nn.init.xavier_uniform_(layer.weight)
                    elif self.weight_init == "kaiming":
                        nn.init.kaiming_normal_(layer.weight)
                    else:
                        raise NotImplementedError(self.weight_init)
                    nn.init.zeros_(layer.bias)
    
        def forward(self, x):
            return self.layers(x)
    
        def load(self, path):
            self.load_state_dict(torch.load(path))
    
        def save(self, path):
            torch.save(self.state_dict(), path)
            check_permissions(path)
    
    class MaggotClassifier(MaggotModule):
        def __init__(self, path, behavior_labels=[], n_latent_features=None,
                n_layers=1, cfgfile=None, ptfile="trained_classifier.pt"):
            super().__init__(path, cfgfile, ptfile)
            try: # try to load the config file, if any
                self.config
            except Exception: # otherwise, build the config from the arguments
                assert bool(behavior_labels)
                assert bool(n_latent_features)
                self.config = dict(
                    clf_path=str(self.ptfilepath),
                    dim_latent=n_latent_features,
                    behavior_labels=behavior_labels,
                    clf_depth=0 if n_layers is None else n_layers - 1,
                    batch_norm=False,
                    weight_init="xavier",
                    loss="cross-entropy",
                    optimizer="adam")
    
        @classmethod
        def load_model(cls, config, path):
            model = DeepLinear(
                    n_input=config["dim_latent"],
                    n_output=len(config["behavior_labels"]),
                    n_hidden=config["clf_depth"]*[None],
                    batch_norm=config["batch_norm"],
                    weight_init=config["weight_init"],
                    )
            try:
                model.load(path)
            except Exception:
                # no trained weights available; initialize fresh random weights
                model.init_layers()
            return model
    
        @property
        def behavior_labels(self):
            return self.config["behavior_labels"]
    
        @behavior_labels.setter
        def behavior_labels(self, labels):
            self.config["behavior_labels"] = labels
    
        @property
        def n_latent_features(self):
            return self.config["dim_latent"]
    
        @property
        def n_behaviors(self):
            return len(self.behavior_labels)
    
        @property
        def n_layers(self):
            return self.config["clf_depth"] + 1
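
    # For reference, a freshly created MaggotClassifier saves a config file
    # along these lines (labels and sizes are placeholders); retraining, via
    # SupervisedMaggot.save below, additionally records "autoencoder_config"
    # and "enc_path":
    #
    # {
    #   "clf_path": "/path/to/trained_classifier.pt",
    #   "dim_latent": 100,
    #   "behavior_labels": ["run", "cast", "stop"],
    #   "clf_depth": 0,
    #   "batch_norm": false,
    #   "weight_init": "xavier",
    #   "loss": "cross-entropy",
    #   "optimizer": "adam"
    # }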
    
    class SupervisedMaggot(nn.Module):
        def __init__(self, cfgfilepath, behaviors=[], n_layers=1):
            super().__init__()
            if behaviors: # the model is only pre-trained
                self.encoder = PretrainedMaggotEncoder(cfgfilepath)
                self.clf = MaggotClassifier(self.encoder.path / "clf_config.json",
                        behaviors, self.encoder.config["dim_latent"], n_layers)
            else: # the model has been retrained
                self.clf = MaggotClassifier(cfgfilepath)
                self.encoder = MaggotEncoder(self.clf.config["autoencoder_config"],
                        ptfile=self.clf.config["enc_path"])
    
        def forward(self, x):
            return self.clf(self.encoder(x))
    
        def save(self):
            enc, clf = self.encoder, self.clf
            enc.save()
            clf.config["autoencoder_config"] = str(enc.cfgfilepath)
            clf.config["enc_path"] = str(enc.ptfilepath)
            clf.save()
    
    class MultiscaleSupervisedMaggot(nn.Module):
        def __init__(self, cfgfilepath, behaviors=[], n_layers=1):
            super().__init__()
            if behaviors: # the model is only pre-trained
                self.encoders = MaggotEncoders(cfgfilepath, cls=PretrainedMaggotEncoder)
                path = next(iter(self.encoders)).path.parent
                n_latent_features = sum(enc.config["dim_latent"] for enc in self.encoders)
                self.clf = MaggotClassifier(path / "clf_config.json",
                        behaviors, n_latent_features, n_layers)
            else: # the model has been retrained
                self.clf = MaggotClassifier(cfgfilepath)
                self.encoders = MaggotEncoders(self.clf.config["autoencoder_config"],
                        ptfile=self.clf.config["enc_path"])
    
        def forward(self, x):
            return self.clf(self.encoders(x))
    
        def save(self):
            enc, clf = self.encoders, self.clf
            enc.save()
            clf.config["autoencoder_config"] = [str(p) for p in enc.cfgfilepaths]
            clf.config["enc_path"] = [str(p) for p in enc.ptfilepaths]
            clf.save()
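
    # Usage sketch (paths and labels hypothetical; the pretrained directory is
    # expected to contain the autoencoder config JSON next to
    # best_validated_encoder.pt):
    if __name__ == "__main__":
        # wrap a pretrained encoder with a fresh classifier head
        model = SupervisedMaggot("pretrained/autoencoder_config.json",
                behaviors=["run", "cast", "stop"], n_layers=1)
        # ... fine-tune `model` on labelled data here ...
        model.save() # writes retrained_encoder.pt, trained_classifier.pt and the config files

        # later, reload the retrained model from the classifier config alone
        model = SupervisedMaggot("pretrained/clf_config.json")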