modules.py

    import logging
    import os
    from pathlib import Path
    import torch
    from torch import nn
    import json
    import functools
    from behavior_model.models.neural_nets import Encoder
    from taggingbackends.explorer import check_permissions
    
    class MaggotModule(nn.Module):
        def __init__(self, path, cfgfile=None, ptfile=None):
            super().__init__()
            self.path = path if isinstance(path, Path) else Path(path)
            if cfgfile is None:
                cfgfile = self.path.name
                self.path = self.path.parent
            elif not self.path.is_dir():
                raise ValueError("\"str(self.path)\" is not a directory")
            self.cfgfile = cfgfile
            if ptfile is not None and Path(ptfile).parent == self.path:
                ptfile = Path(ptfile).name
            self.ptfile = ptfile
            self._config = None
            self._model = None
    
        @classmethod
        def load_config(cls, path):
            with open(path, "r") as f:
                return json.load(f)
    
        @property
        def cfgfilepath(self):
            return self.path / self.cfgfile
    
        @property
        def ptfilepath(self):
            return self.path / self.ptfile
    
        @property
        def config(self):
            if self._config is None:
                self._config = self.load_config(self.cfgfilepath)
            return self._config
    
        @config.setter
        def config(self, cfg):
            self._config = cfg
    
        @classmethod
        def load_model(cls, config, path):
            raise NotImplementedError
    
        @property
        def model(self):
            if self._model is None:
                try:
                    self._model = self.load_model(self.config, self.ptfilepath)
                except Exception as e:
                    logging.error(e)
                    logging.error('could not load or initialize the model; check the load_model class method')
            return self._model
    
        @model.setter
        def model(self, model):
            self._model = model
    
        def forward(self, x):
            return self.model(x)
    
        def save_config(self, cfgfile=None):
            if cfgfile is None: cfgfile = self.cfgfile
            path = self.path / cfgfile
            with open(path, "w") as f:
                json.dump(self.config, f, indent=2)
            check_permissions(path)
            return path
    
        def save_model(self, ptfile=None):
            if ptfile is None: ptfile = self.ptfile
            path = self.path / ptfile
            torch.save(self.model.state_dict(), path)
            check_permissions(path)
            return path
    
        def save(self):
            self.save_model()
            self.save_config()
    
        def parameters(self, recurse=True):
            return self.model.parameters(recurse)
    
        def to(self, device):
            self.model.to(device)
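
    # Subclassing note (illustrative, not part of the original file): concrete
    # modules override the `load_model` (and optionally `load_config`) class
    # methods; configuration and weights are then loaded lazily, on first access
    # to `self.config` or `self.model`.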
    
    
    """
    Initialize a model's weights and bias (if any).
    
    Adapted from `behavior_model.models.neural_nets.AutoEncoder._init_weights`
    
    Passing `None` for `weight_init` or `has_bias` selects the corresponding
    default value.
    """
    def init_weights(model, weight_init='xavier', has_bias=False):
        if has_bias:
            nn.init.constant_(model.bias, 0)
        if weight_init is None:
            weight_init = 'xavier'
        _init = dict(
                kaiming='kaiming_uniform',
                xavier='xavier_uniform',
                ).get(weight_init, weight_init)
        if _init == 'orthogonal':
            nn.init.orthogonal_(model.weight)
        elif _init == 'xavier_uniform':
            nn.init.xavier_uniform_(model.weight)
        elif _init == 'kaiming_uniform':
            nn.init.kaiming_uniform_(model.weight, nonlinearity='relu')
        else:
            raise ValueError(f"initialization method not supported: {weight_init}")
    
    
    """
    Note: by default, MaggotEncoder represents a retrained encoder
          (retrained = trained for a behavior-tagging task);
          see PretrainedMaggotEncoder for encoders that were only pretrained
          (pretrained = trained in a self-supervised task)
    """
    class MaggotEncoder(MaggotModule):
        def __init__(self, path,
                cfgfile=None,
                #cfgfile="autoencoder_config.json",
                ptfile="retrained_encoder.pt"):
            super().__init__(path, cfgfile, ptfile)
    
        @classmethod
        def load_config(cls, path):
            config = super().load_config(path)
            config["config"] = str(path)
            return config
    
        @classmethod
        def load_model(cls, config, path):
            encoder = Encoder(**config)
            _reason = None
            if config.get('load_state', True):
                try:
                    encoder.load_state_dict(torch.load(path))
                except Exception as e:
                    _reason = e
                    config['load_state'] = False # for `was_pretrained` to properly work
            else:
                _reason = '"load_state" is set to false'
            # if state file not found or config option "load_state" is False,
            # (re-)initialize the model's weights
            if _reason:
                logging.debug(f"initializing the encoder ({_reason})")
                _init, _bias = config.get('init', None), config.get('bias', None)
                for child in encoder.children():
                    if isinstance(child,
                                  (nn.Linear, nn.Conv2d, nn.Conv1d,
                                   nn.ConvTranspose1d, nn.ConvTranspose2d)):
                        init_weights(child, _init, _bias)
            return encoder
    
        @functools.lru_cache(maxsize=1)
        def mask(self, size):
            selfsize = self.config["len_traj"]
            assert selfsize <= size
            if size == selfsize:
                return
            mask = torch.zeros(size, dtype=torch.bool)
            crop = (size - selfsize) // 2
            if size % 2 == selfsize % 2:
                mask[crop:-crop] = True
                # the actual midpoint is assumed to be:
                # * the segment's midpoint in odd-sized segments
                # * the first point in the second half of even-sized segments
            elif selfsize % 2 == 0:
                # size is odd;
                # let us assume size - selfsize == 3; crop == 1 and
                # we want to crop 1 elem from the start and 2 elems from the end
                mask[crop:-crop-1] = True
            else:
                # selfsize is odd and size is even;
                # again, let us assume size - selfsize == 3,
                # we want to crop 2 elems from the start and 1 elem from the end
                # slice with size-crop (rather than -crop) so that crop == 0 also works
                mask[crop+1:size-crop] = True
            mask, = mask.nonzero(as_tuple=True) # index_select expects integer indices, not a boolean mask
            return mask
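
        # Worked example (illustrative): with len_traj == 5 and an input window of
        # length 8, mask(8) keeps time indices 2..6 (two points cropped at the
        # start, one at the end), so that `mask_forward` below feeds a centred,
        # fixed-length window to the encoder.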
    
        def mask_forward(self, batch, dim=3):
            mask = self.mask(batch.shape[dim])
            if mask is not None:
                batch = torch.index_select(batch, dim, mask)
            return self.forward(batch)
    
        """
        Determine whether the encoder was pretrained as part of a MaggotUBA
        autoencoder, or only initialized in the context of testing the benefit of
        using a pretrained encoder.
    
        This is to be distinguished from the MaggotEncoder/PretrainedMaggotEncoder
        classes that instead represent the different states *after* and *before*
        retraining.
    
        The purpose of this method is to determine at retraining time whether to
        pretrain the classifier and then fine-tune the full encoder+classifier, or
        instead train both the classifier and encoder in a single-stage training
        process. Indeed, the two-stage retraining process only makes sense if the
        encoder was pretrained.
    
        See `trainers.MaggotTrainer.train`.
        """
        def was_pretrained(self):
            return self.config.get('load_state', True)
    
    class PretrainedMaggotEncoder(MaggotEncoder):
        def __init__(self, path,
                cfgfile=None,
                #cfgfile="autoencoder_config.json",
                ptfile="best_validated_encoder.pt"):
            super().__init__(path, cfgfile, ptfile)
    
        def save(self, ptfile="retrained_encoder.pt"):
            self.ptfile = ptfile
            # "load_state" was introduced in json config file as a mechanism to load
            # untrained encoders; once trained, this key must be removed:
            self.config.pop('load_state', None)
            return super().save()
    
    class MaggotEncoders(nn.Module):
        def __init__(self, paths, cls=MaggotEncoder, **kwargs):
            super().__init__()
            self._pattern = None
            if isinstance(paths, (str, Path)):
                self._pattern = paths
                import glob
                paths = glob.glob(str(paths))
            try:
                ptfiles = kwargs.pop("ptfile")
            except KeyError:
                self.encoders = [cls(path, **kwargs) for path in paths]
            else:
                if isinstance(ptfiles, list):
                    self.encoders = [cls(path, ptfile=ptfile, **kwargs)
                            for path, ptfile in zip(paths, ptfiles)]
                else:
                    kwargs["ptfile"] = ptfiles
                    self.encoders = [cls(path, **kwargs) for path in paths]
    
        def __iter__(self):
            return iter(self.encoders)
    
        def forward(self, x):
            return torch.cat([encoder.mask_forward(x) for encoder in self.encoders], dim=1)
    
        @property
        def cfgfilepaths(self):
            return [enc.cfgfilepath for enc in self.encoders]
    
        @property
        def ptfilepaths(self):
            return [enc.ptfilepath for enc in self.encoders]
    
        def save_config(self, cfgfile=None):
            for encoder in self.encoders:
                encoder.save_config(cfgfile)
    
        def save_model(self, ptfile=None):
            for encoder in self.encoders:
                encoder.save_model(ptfile)
    
        def save(self):
            for encoder in self.encoders:
                encoder.save()
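
    # Usage sketch (illustrative; the glob pattern below is an assumption and not
    # a layout this project necessarily ships with):
    #
    #     encoders = MaggotEncoders("pretrained_models/*/autoencoder_config.json",
    #                               cls=PretrainedMaggotEncoder)
    #     latent = encoders(batch)  # per-encoder latent features, concatenated along dim=1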
    
    class DeepLinear(nn.Module):
        def __init__(self, n_input, n_output, n_hidden=[], batch_norm=False,
                weight_init="xavier"):
            super().__init__()
            self.batch_norm = batch_norm
            self.weight_init = weight_init
            layers = []
            for n in list(n_hidden):
                if n is None: n = n_input
                layers.append(nn.Linear(n_input, n))
                layers.append(nn.ReLU())
                if batch_norm:
                    layers.append(nn.BatchNorm1d(n))
                n_input = n
            layers.append(nn.Linear(n_input, n_output))
            self.layers = nn.Sequential(*layers)
    
        def init_layers(self):
            for layer in self.layers:
                if isinstance(layer, nn.Linear):
                    init_weights(layer, self.weight_init, True)
    
        def forward(self, x):
            return self.layers(x)
    
        def load(self, path):
            self.load_state_dict(torch.load(path))
    
        def save(self, path):
            torch.save(self.state_dict(), path)
            check_permissions(path)
    
        def to(self, device):
            self.layers.to(device)
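
    # Construction example (illustrative): `None` entries in `n_hidden` default to
    # the input width, so a classifier with one hidden layer over 100 latent
    # features and 6 output behaviors could be built as:
    #
    #     clf = DeepLinear(n_input=100, n_output=6, n_hidden=[None])
    #     # -> Linear(100, 100), ReLU, Linear(100, 6)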
    
    class MaggotClassifier(MaggotModule):
        def __init__(self, path, behavior_labels=[], n_latent_features=None,
                n_layers=1, n_iterations=None, cfgfile=None,
                ptfile="trained_classifier.pt"):
            super().__init__(path, cfgfile, ptfile)
            try: # try to load the config file, if any
                self.config
            except Exception:
                assert bool(behavior_labels)
                assert bool(n_latent_features)
                self.config = dict(
                    clf_path=str(self.ptfilepath),
                    dim_latent=n_latent_features,
                    behavior_labels=behavior_labels,
                    clf_depth=0 if n_layers is None else n_layers - 1,
                    batch_norm=False,
                    weight_init="xavier",
                    loss="cross-entropy",
                    optimizer="adam")
                if n_iterations is not None:
                    if isinstance(n_iterations, str):
                        n_iterations = map(int, n_iterations.split(','))
                    if isinstance(n_iterations, int):
                        n_pretraining_iter = n_iterations // 2
                        n_finetuning_iter = n_iterations // 2
                    else:
                        n_pretraining_iter, n_finetuning_iter = n_iterations
                    self.config['pretraining_iter'] = n_pretraining_iter
                    self.config['finetuning_iter'] = n_finetuning_iter
    
        @classmethod
        def load_model(cls, config, path):
            model = DeepLinear(
                    n_input=config["dim_latent"],
                    n_output=len(config["behavior_labels"]),
                    n_hidden=config["clf_depth"]*[None],
                    batch_norm=config["batch_norm"],
                    weight_init=config["weight_init"],
                    )
            try:
                model.load(path)
            except Exception:
                model.init_layers()
            return model
    
        @property
        def behavior_labels(self):
            return self.config["behavior_labels"]
    
        @behavior_labels.setter
        def behavior_labels(self, labels):
            self.config["behavior_labels"] = labels
    
        @property
        def n_latent_features(self):
            return self.config["dim_latent"]
    
        @property
        def n_behaviors(self):
            return len(self.behavior_labels)
    
        @property
        def n_layers(self):
            return self.config["clf_depth"] + 1
    
        @property
        def n_pretraining_iter(self):
            return self.config.get('pretraining_iter', None)
    
        @property
        def n_finetuning_iter(self):
            return self.config.get('finetuning_iter', None)
    
    class SupervisedMaggot(nn.Module):
        def __init__(self, cfgfilepath, behaviors=[], n_layers=1, n_epochs=None):
            super().__init__()
            if behaviors: # the model is only pre-trained
                self.encoder = PretrainedMaggotEncoder(cfgfilepath)
                self.clf = MaggotClassifier(self.encoder.path / "clf_config.json",
                        behaviors, self.encoder.config["dim_latent"], n_layers,
                        n_epochs)
            else: # the model has been retrained
                self.clf = MaggotClassifier(cfgfilepath)
                self.encoder = MaggotEncoder(self.clf.config["autoencoder_config"],
                        ptfile=self.clf.config["enc_path"])
    
        def forward(self, x):
            return self.clf(self.encoder(x))
    
        def mask_forward(self, x):
            return self.clf(self.encoder.mask_forward(x))
    
        def save(self):
            enc, clf = self.encoder, self.clf
            enc.save()
            clf.config["autoencoder_config"] = str(enc.cfgfilepath)
            clf.config["enc_path"] = str(enc.ptfilepath)
            clf.save()
    
        def parameters(self):
            self.clf.model # force parameter loading or initialization
            return super().parameters()
    
        def to(self, device):
            self.encoder.to(device)
            self.clf.to(device)
    
        @property
        def n_pretraining_iter(self):
            n = self.clf.n_pretraining_iter
            if n is None:
                enc = self.encoder
                n = enc.config['optim_iter']
                if enc.was_pretrained():
                    n = n // 2
            return n
    
        @property
        def n_finetuning_iter(self):
            n = self.clf.n_finetuning_iter
            if n is None:
                enc = self.encoder
                n = enc.config['optim_iter']
                if enc.was_pretrained():
                    n = n // 2
            return n
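
        # Note (illustrative numbers): when the classifier config does not record
        # iteration counts, both properties above fall back on the encoder's
        # 'optim_iter' setting; e.g. with optim_iter = 1000 and a pretrained
        # encoder, 500 iterations go to classifier pre-training and 500 to
        # fine-tuning, whereas a non-pretrained encoder keeps the full 1000 for
        # single-stage training (see `trainers.MaggotTrainer.train`).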
    
    class MultiscaleSupervisedMaggot(nn.Module):
        def __init__(self, cfgfilepath, behaviors=[], n_layers=1, n_iterations=None):
            super().__init__()
            if behaviors: # the model is only pre-trained
                self.encoders = MaggotEncoders(cfgfilepath, cls=PretrainedMaggotEncoder)
                path = next(iter(self.encoders)).path.parent
                n_latent_features = sum(enc.config["dim_latent"] for enc in self.encoders)
                self.clf = MaggotClassifier(path / "clf_config.json",
                        behaviors, n_latent_features, n_layers, n_iterations)
            else: # the model has been retrained
                self.clf = MaggotClassifier(cfgfilepath)
                self.encoders = MaggotEncoders(self.clf.config["autoencoder_config"],
                        ptfile=self.clf.config["enc_path"])
    
        def forward(self, x):
            return self.clf(self.encoders(x))
    
        def save(self):
            enc, clf = self.encoders, self.clf
            enc.save()
            clf.config["autoencoder_config"] = [str(p) for p in enc.cfgfilepaths]
            clf.config["enc_path"] = [str(p) for p in enc.ptfilepaths]
            clf.save()
    
        def parameters(self):
            self.clf.model # force parameter loading or initialization
            return super().parameters()
    
        @property
        def n_pretraining_iter(self):
            n = self.clf.n_pretraining_iter
            if n is None:
                any_enc = next(iter(self.encoders))
                n = any_enc.config['optim_iter']
                if any_enc.was_pretrained():
                    n = n // 2
            return n
    
        @property
        def n_finetuning_iter(self):
            n = self.clf.n_finetuning_iter
            if n is None:
                any_enc = next(iter(self.encoders))
                n = any_enc.config['optim_iter']
                if any_enc.was_pretrained():
                    n = n // 2
            return n
    
    """
    Bagging for `SupervisedMaggot`.
    
    Taggers in the bag are individually trained.
    For now the bag itself cannot be trained and is used for prediction only.
    
    Bags of taggers are stored so that the models directory only contains
    subdirectories, each subdirectory specifying an individual tagger.
    """
    class MaggotBag(nn.Module):
        def __init__(self, paths, behaviors=[], n_layers=1, n_iterations=None,
                     cls=SupervisedMaggot):
            super().__init__()
            # register the individual taggers as submodules so that e.g. `to(device)` propagates
            self.maggots = nn.ModuleList([cls(path, behaviors, n_layers, n_iterations) for path in paths])
            self._lead_maggot = None
    
        def forward(self, x):
            #return torch.cat([encoder.mask_forward(x) for encoder in self.encoders], dim=1)
            return self.vote([maggot.mask_forward(x) for maggot in self.maggots])
    
        def vote(self, y):
            vote, _ = torch.mode(torch.stack(y, dim=len(y[0].shape)))
            return vote
    
        @property
        def encoder(self):
            return self.maggots[self.lead_maggot].encoder
    
        @property
        def clf(self):
            return self.maggots[self.lead_maggot].clf
    
        @property
        def lead_maggot(self):
            if self._lead_maggot is None:
                len_traj = 0
                for i, maggot in enumerate(self.maggots):
                    len_traj_ = maggot.encoder.config['len_traj']
                    if len_traj < len_traj_:
                        len_traj = len_traj_
                        self._lead_maggot = i
            return self._lead_maggot
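
    # Usage sketch (illustrative; the directory layout and variable names are
    # assumptions, not guaranteed by this module):
    #
    #     from pathlib import Path
    #     model_dir = Path("models/my_tagger")
    #     subdirs = sorted(p for p in model_dir.iterdir() if p.is_dir())
    #     bag = MaggotBag([p / "clf_config.json" for p in subdirs])
    #     labels = bag(batch)  # element-wise mode (majority vote) across taggers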