Skip to content
Snippets Groups Projects
Commit becb7f82 authored by François LAURENT's avatar François LAURENT
Browse files

seems to fix larvatagger.jl#76

parent 8e0ef41f
No related branches found
No related tags found
No related merge requests found
...@@ -90,6 +90,9 @@ class MaggotModule(nn.Module): ...@@ -90,6 +90,9 @@ class MaggotModule(nn.Module):
def parameters(self, recurse=True): def parameters(self, recurse=True):
return self.model.parameters(recurse) return self.model.parameters(recurse)
def to(self, device):
self.model.to(device)
""" """
Initialize a model's weights and bias (if any). Initialize a model's weights and bias (if any).
...@@ -305,6 +308,9 @@ class DeepLinear(nn.Module): ...@@ -305,6 +308,9 @@ class DeepLinear(nn.Module):
torch.save(self.state_dict(), path) torch.save(self.state_dict(), path)
check_permissions(path) check_permissions(path)
def to(self, device):
self.layers.to(device)
class MaggotClassifier(MaggotModule): class MaggotClassifier(MaggotModule):
def __init__(self, path, behavior_labels=[], n_latent_features=None, def __init__(self, path, behavior_labels=[], n_latent_features=None,
n_layers=1, cfgfile=None, ptfile="trained_classifier.pt"): n_layers=1, cfgfile=None, ptfile="trained_classifier.pt"):
...@@ -385,6 +391,10 @@ class SupervisedMaggot(nn.Module): ...@@ -385,6 +391,10 @@ class SupervisedMaggot(nn.Module):
self.clf.model # force parameter loading or initialization self.clf.model # force parameter loading or initialization
return super().parameters(self) return super().parameters(self)
def to(self, device):
self.encoder.to(device)
self.clf.to(device)
class MultiscaleSupervisedMaggot(nn.Module): class MultiscaleSupervisedMaggot(nn.Module):
def __init__(self, cfgfilepath, behaviors=[], n_layers=1): def __init__(self, cfgfilepath, behaviors=[], n_layers=1):
super().__init__() super().__init__()
......
...@@ -231,7 +231,7 @@ class MaggotTrainer: ...@@ -231,7 +231,7 @@ class MaggotTrainer:
self.model.save() self.model.save()
def new_generator(seed=None): def new_generator(seed=None):
generator = torch.Generator(device) generator = torch.Generator('cpu')
if seed == 'random': return generator if seed == 'random': return generator
if seed is None: seed = 0b11010111001001101001110 if seed is None: seed = 0b11010111001001101001110
return generator.manual_seed(seed) return generator.manual_seed(seed)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment