diff --git a/src/maggotuba/models/denselayer.py b/src/maggotuba/models/denselayer.py index 82a6f7a6eea0f529419522d9e0ca27fab51ddc38..7196bcd3cbe76fc077ddbf176b6822ee8fd3a9c2 100644 --- a/src/maggotuba/models/denselayer.py +++ b/src/maggotuba/models/denselayer.py @@ -169,7 +169,7 @@ class DenseLayer: if train: return y else: - return y.numpy() + return y.cpu().numpy() def train(self, all_spines=None, tags=None): if all_spines is None or tags is None: diff --git a/src/maggotuba/models/predict_model.py b/src/maggotuba/models/predict_model.py index 681cf4b327c980c030cd10776c53fe314f047d77..4299dd2f381df7a3d2ccd758d6608f406bcf9027 100644 --- a/src/maggotuba/models/predict_model.py +++ b/src/maggotuba/models/predict_model.py @@ -25,13 +25,10 @@ def predict_model(backend): # we could go and pick files in `data/interim` as well: input_files += backend.list_interim_files() assert 0 < len(input_files) - metadata = None - metadata_file = [file for file in input_files if file.name == "metadata"] - if metadata_file: - metadata_file = metadata_file[0] - input_files.remove(metadata_file) - with open(metadata_file, "r") as f: - metadata = json.load(f) + # initialize output labels + input_files, labels = backend.prepare_labels(input_files) + assert 0 < len(input_files) + # for file in input_files: # load the input data (or features) if file.name.endswith(".spine"): @@ -40,7 +37,7 @@ def predict_model(backend): larvae = spine["larva_id"].values t = spine["time"].values data = spine.iloc[:,3:].values - elif file.name == "trx.mat": + elif file.name.endswith(".mat"): trx = TrxMat(file) t = trx["t"] data = trx["spine"] @@ -49,11 +46,14 @@ def predict_model(backend): run, data = next(iter(data.items())) t = t[run] elif file.name.endswith(".csv"): - print("assuming 30 fps") - t, data = fimtrack.read_spines(file, fps=30) + if labels.camera_framerate: + print(f"camera frame rate: {labels.camera_framerate}fps") + else: + print("assuming 30-fps camera frame rate") + labels.camera_framerate = 30 + t, data = fimtrack.read_spines(file, fps=labels.camera_framerate) run = "NA" else: - # TODO: support more file formats continue # downsample the skeleton if isinstance(data, dict): @@ -68,7 +68,6 @@ def predict_model(backend): config_file = [file for file in config_file if file.name.endswith("clf_config.json")] model = Clf(config_file[-1]) # assign labels - labels = Labels(tracking=input_files) if isinstance(data, dict): ref_length = np.mean(np.concatenate([ model.body_length(spines) for spines in data.values() @@ -93,10 +92,6 @@ def predict_model(backend): else: labels[run, larva] = dict(zip(t[mask], predictions)) # save the predicted labels to file - if metadata: - labels.metadata = metadata - else: - labels.metadata = {'filename': file.name} labels.labelspec = { "names": ["run", "bend", "stop", "hunch", "back", "roll"], "colors": ["#000000", "#ff0000", "#00ff00", "#0000ff",