prepare.py 56 KB
Newer Older
amichaut's avatar
amichaut committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
##########################################################################
# Track Analyzer - Quantification and visualization of tracking data     #
# Authors: Arthur Michaut                                                #
# Copyright 2016-2019 Harvard Medical School and Brigham and             #
#                          Women's Hospital                              #
# Copyright 2019-2021 Institut Pasteur and CNRS–UMR3738                  #
# See the COPYRIGHT file for details                                     #
#                                                                        #
# This file is part of Track Analyzer package.                           #
#                                                                        #
# Track Analyzer is free software: you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by   #
# the Free Software Foundation, either version 3 of the License, or      #
# (at your option) any later version.                                    #
#                                                                        #
# Track Analyzer is distributed in the hope that it will be useful,      #
# but WITHOUT ANY WARRANTY; without even the implied warranty of         #
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the           #
# GNU General Public License for more details.                           #
#                                                                        #
# You should have received a copy of the GNU General Public License      #
# along with Track Analyzer (COPYING).                                   #
# If not, see <https://www.gnu.org/licenses/>.                           #
##########################################################################

amichaut's avatar
amichaut committed
26
27
28
29
import os
import os.path as osp
import csv

amichaut's avatar
amichaut committed
30
31
32
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
amichaut's avatar
amichaut committed
33
34
35
36
import pandas as pd
import pickle
import seaborn as sns
import napari
amichaut's avatar
amichaut committed
37
38
39
40
from skimage import io
from skimage.color import rgb2gray
from skimage.util import img_as_ubyte
import tifffile as tifff
amichaut's avatar
amichaut committed
41
42
43

from track_analyzer import plotting as tpl
from track_analyzer import calculate as tca
amichaut's avatar
amichaut committed
44
45

# Plotting parameters
amichaut's avatar
amichaut committed
46
47
48
49
# default Matplotlib color cycle, extended with 9 desaturated colors of the Seaborn "Set1" palette
color_list = ([prop['color'] for prop in plt.rcParams['axes.prop_cycle']]
              + sns.color_palette("Set1", n_colors=9, desat=.5))
# default plotting parameters (read e.g. by make_grid for figsize and dpi)
plot_param = dict(figsize=(5, 5), dpi=300, color_list=color_list, format='.png', despine=True, logx=False,
                  logy=False, invert_yaxis=True, export_data_pts=False)
amichaut's avatar
amichaut committed
50
51


amichaut's avatar
amichaut committed
52
def paper_style():
    """
    Set some Seaborn plotting parameters suited to paper figures.
    Matplotlib rcParams are first reset to their defaults so that previously applied settings do not leak into
    this style.
    :return: None
    """
    # `mpl` is not imported at module level (only matplotlib.pyplot and matplotlib.cm are), so bring the top-level
    # package into scope here; without it this function raised NameError.
    import matplotlib as mpl

    mpl.rcParams.update(mpl.rcParamsDefault)  # ensure the default params are active
    sns.set_style("ticks")
    sns.set_context("paper", font_scale=2., rc={"lines.linewidth": 2, "lines.markersize": 9})

amichaut's avatar
amichaut committed
61

62
def get_cmap_color(value, colormap='plasma', vmin=None, vmax=None):
    """
    Convert a value into a color of a Matplotlib colormap.
    The value is first normalized with plt.Normalize(vmin, vmax); a bound left to None is inferred by Matplotlib
    from the value itself (see matplotlib.colors.Normalize).
    :param value: value to be converted to color
    :type value: float
    :param colormap: Matplotlib colormap name
    :type colormap: str
    :param vmin: lower bound of the colormap, or None
    :type vmin: float or None
    :param vmax: upper bound of the colormap, or None
    :type vmax: float or None
    :return: color as returned by the colormap (an RGBA tuple for a scalar value)
    :rtype: tuple
    """
    cmap = plt.get_cmap(colormap)
    normalizer = plt.Normalize(vmin, vmax)
    scaled_value = normalizer(value)
    return cmap(scaled_value)

amichaut's avatar
amichaut committed
80
81

def listdir_nohidden(path):
    """
    List the entries of a directory, leaving out hidden ones (names starting with a dot).
    :param path: path to the directory to list
    :type path: str
    :return: names of the non-hidden entries, in the order given by os.listdir
    :rtype: list
    :raises Exception: if path is not a directory
    """
    if osp.isdir(path):
        return [name for name in os.listdir(path) if not name.startswith('.')]
    raise Exception('ERROR: {} is not a directory'.format(path))

amichaut's avatar
amichaut committed
98

99
100
101
102
103
104
105
106
107
108
109
def safe_mkdir(path):
    """
    Create a directory unless it already exists.
    Only the last path component is created (os.mkdir is not recursive).
    :param path: path of the directory to create
    :type path: str
    :return: the same path
    :rtype: str
    """
    if osp.isdir(path):
        return path
    os.mkdir(path)
    return path
amichaut's avatar
amichaut committed
110

amichaut's avatar
amichaut committed
111

amichaut's avatar
amichaut committed
112
def get_param_time_der(param):
    """
    Guess from a parameter name whether it is a time derivative.
    A name ending with '_dot' is a first time derivative; a name ending with '_ddot' is a second time derivative.
    :param param: parameter name
    :type param: str
    :return: [raw parameter name (suffix removed), derivative type: 'first', 'sec' or None]
    :rtype: list
    """
    # Strip the suffix by its length rather than with str.find(): find() returns the FIRST occurrence and so
    # truncated names that also contain the suffix earlier (e.g. 'a_dot_b_dot' became 'a').
    for suffix, der_type in (('_dot', 'first'), ('_ddot', 'sec')):
        if param.endswith(suffix):
            return [param[:-len(suffix)], der_type]
    return [param, None]
amichaut's avatar
amichaut committed
127

amichaut's avatar
amichaut committed
128
129

def make_unit_label(dimension='L', l_unit='um', t_unit='min'):
    """
    Make the Latex label of a unit from a dimension formula.
    Supported dimensions: L, LL, L/T, LL/T, L/TT, 1/L, 1/LL, T, 1/T, none
    Supported length units: um (micrometers), mm (millimeters), px (pixels), au (arbitrary unit), none
    Supported time units: min (minutes), s (seconds), frame, au (arbitrary unit), none
    An unsupported dimension prints a warning and gives an empty label. Only the unit(s) needed by the dimension
    are looked up; an unsupported one raises KeyError.
    :param dimension: dimension using L and T for length and time dimensions
    :type dimension: str
    :param l_unit: length unit
    :type l_unit: str
    :param t_unit: time unit
    :type t_unit: str
    :return: label as a Latex formatted string (without $ delimiters)
    :rtype: str
    """
    l_unit_dict = {'um': r'\mu m', 'mm': r'mm', 'px': 'px', 'none': '', 'au': ''}
    t_unit_dict = {'min': r'min', 's': r's', 'frame': 'frame', 'none': '', 'au': ''}

    # Builders are lazy so that each dimension only requires the unit(s) it uses
    # (e.g. a 'T' label never needs a valid l_unit).
    builders = {
        'L': lambda: l_unit_dict[l_unit],
        'LL': lambda: l_unit_dict[l_unit] + r'^2',
        'LL/T': lambda: l_unit_dict[l_unit] + r'^2.' + t_unit_dict[t_unit] + r'^{-1}',
        'L/TT': lambda: l_unit_dict[l_unit] + r'.' + t_unit_dict[t_unit] + r'^{-2}',
        'L/T': lambda: l_unit_dict[l_unit] + r'.' + t_unit_dict[t_unit] + r'^{-1}',
        '1/L': lambda: l_unit_dict[l_unit] + r'^{-1}',
        '1/LL': lambda: l_unit_dict[l_unit] + r'^{-2}',
        'T': lambda: t_unit_dict[t_unit],
        '1/T': lambda: t_unit_dict[t_unit] + r'^{-1}',
        'none': lambda: '',
    }

    build_label = builders.get(dimension)
    if build_label is None:
        print('Warning: this unit is not supported')
        return r''
    return build_label()

amichaut's avatar
amichaut committed
173
174

def make_param_label(param, l_unit='um', t_unit='min', time_der=None, mean=False, only_symbol=False, only_unit=False):
    """
    Make a Latex formatted label of a parameter.
    The first and second time derivative and the mean symbol can be used.
    Supported parameters: x,y,z,x_scaled,y_scaled,z_scaled,z_rel,vx,vy,vz,v,ax,ay,az,a,area,
    t,frame,D,curl,div,track_length,track
    A '_mean' suffix (e.g. 'v_mean') sets mean to True, and a '_dot'/'_ddot' suffix (e.g. 'x_dot') overrides
    time_der.
    Supported length units: um (micrometers), mm (millimeters), px (pixels), au (arbitrary unit)
    Supported time units: min (minutes), s (seconds), frame, au (arbitrary unit)
    :param param: parameter
    :type param: str
    :param l_unit: length unit
    :type l_unit: str
    :param t_unit: time unit
    :type t_unit: str
    :param time_der: time derivative: 'first' or 'sec' or None
    :type time_der: str or None
    :param mean: mean parameter: add angle brackets as a sign for a mean parameter
    :type mean: bool
    :param only_symbol: return only the parameter symbol
    :type only_symbol: bool
    :param only_unit: return only the parameter unit
    :type only_unit: bool
    :return: parameter label: symbol and/or unit
    :rtype: str
    :raises KeyError: if param (once its suffixes are removed) is not a supported parameter
    """
    # latex symbols of vector components
    symbol_dict = {'v': 'v', 'a': 'a', 'vx': 'v_x', 'vy': 'v_y', 'vz': 'v_z', 'ax': 'a_x', 'ay': 'a_y', 'az': 'a_z'}

    # features of each parameter: sym (symbol), dim (dimension formula), l_unit and t_unit (units used to build the
    # unit label) and latex (whether the label is written in Latex math mode)
    param_dict = {}
    for p in list('xyz'):
        param_dict[p] = {'sym': p, 'dim': 'L', 'l_unit': 'px', 't_unit': 'none', 'latex': True}
    for p in ['x_scaled', 'y_scaled', 'z_scaled', 'z_rel']:
        param_dict[p] = {'sym': p[0], 'dim': 'L', 'l_unit': l_unit, 't_unit': 'none', 'latex': True}
    for p in ['vx', 'vy', 'vz', 'v']:
        param_dict[p] = {'sym': symbol_dict[p], 'dim': 'L/T', 'l_unit': l_unit, 't_unit': t_unit, 'latex': True}
    for p in ['ax', 'ay', 'az', 'a']:
        param_dict[p] = {'sym': symbol_dict[p], 'dim': 'L/TT', 'l_unit': l_unit, 't_unit': t_unit, 'latex': True}
    param_dict['t'] = {'sym': 't', 'dim': 'T', 'l_unit': 'none', 't_unit': t_unit, 'latex': True}
    param_dict['frame'] = {'sym': 'frame', 'dim': 'T', 'l_unit': 'none', 't_unit': 'none', 'latex': True}
    param_dict['D'] = {'sym': 'D', 'dim': 'LL/T', 'l_unit': l_unit, 't_unit': t_unit, 'latex': True}
    param_dict['div'] = {'sym': 'div', 'dim': '1/T', 'l_unit': l_unit, 't_unit': t_unit, 'latex': True}
    param_dict['curl'] = {'sym': 'curl', 'dim': '1/T', 'l_unit': l_unit, 't_unit': t_unit, 'latex': True}
    param_dict['track_length'] = {'sym': 'track duration', 'dim': 'T', 'l_unit': l_unit, 't_unit': t_unit,
                                  'latex': False}
    param_dict['track'] = {'sym': 'track id', 'dim': 'none', 'l_unit': l_unit, 't_unit': t_unit, 'latex': False}
    param_dict['area'] = {'sym': 'area', 'dim': 'LL', 'l_unit': l_unit, 't_unit': t_unit, 'latex': True}

    # a trailing '_mean' flags a mean parameter; slice by length so that only the suffix is removed
    if param.endswith('_mean'):
        param = param[:-len('_mean')]
        mean = True

    # a trailing '_dot' or '_ddot' overrides time_der
    raw_param, time_der_ = get_param_time_der(param)
    if time_der_ is not None:
        time_der = time_der_
        param = raw_param

    features = param_dict[param]

    # make symbol adding derivative or mean symbols
    symbol = features['sym']
    dim_der = ''
    if time_der == 'first':
        symbol = r'\dot{' + symbol + r'}'
        dim_der = '/T'  # change dimension
    elif time_der == 'sec':
        symbol = r'\ddot{' + symbol + r'}'
        dim_der = '/TT'  # change dimension
    if mean:
        symbol = r'\langle ' + symbol + r' \rangle'

    # Merge the derivative's time dimension into the existing denominator (e.g. 'L/T' + '/T' -> 'L/TT'):
    # plain concatenation gave 'L/T/T', which make_unit_label does not support, so e.g. v_dot had no unit.
    dimension = features['dim'] + dim_der
    numerator, *denominators = dimension.split('/')
    if len(denominators) > 1:
        dimension = numerator + '/' + ''.join(denominators)
    unit_ = make_unit_label(dimension, l_unit=features['l_unit'], t_unit=features['t_unit'])

    # output
    if only_symbol and only_unit:
        print("Warning: only_symbol and only_unit can't be both True. Making them both False.")
        only_symbol = False
        only_unit = False

    # Latex labels are wrapped in math-mode dollar signs; other labels are used verbatim
    math = '$' if features['latex'] else ''
    symbol_label = math + symbol + math
    unit_label = math + unit_ + math
    if only_symbol:
        return symbol_label
    if only_unit:
        return unit_label
    if len(unit_) > 0:  # if unit exists
        return symbol_label + ' (' + unit_label + ')'
    return symbol_label

amichaut's avatar
amichaut committed
281
282

def write_dict(dicts, filename, dict_names=None):
    """
    Write a dict or a list of dicts into a csv file, with keys in the first column and values in the second one.
    Each dict is followed by an empty line.
    Optional: if dicts is a list, a list of names (one per dict) can be given; each name is written on its own row
    before the corresponding dict.
    Items of the list that are not dicts are skipped.
    :param dicts: a dict or a list of dicts
    :type dicts: dict or list
    :param filename: filename of the csv file (overwritten if it exists)
    :type filename: str
    :param dict_names: optional names of the dicts, same length as dicts; ignored with a warning otherwise
    :type dict_names: list or None
    """
    # isinstance rather than `type(...) is dict`, so that dict subclasses (e.g. OrderedDict) are written too
    if isinstance(dicts, dict):
        dicts = [dicts]

    if dict_names is not None and len(dicts) != len(dict_names):
        print("Warning: the name list doesn't match the dict list. Not printing names")
        dict_names = None

    # newline='' as required by the csv module, otherwise row terminators get doubled on Windows
    with open(filename, "w", newline='') as f:
        w = csv.writer(f)
        for i, d in enumerate(dicts):
            if not isinstance(d, dict):
                continue
            if dict_names is not None:
                f.write(dict_names[i] + '\n')
            for key, val in d.items():
                w.writerow([key, val])
            f.write('\n')

312

313
def load_dict(filename):
    """
    Read a csv file (e.g. written by write_dict) and return a dict, with the csv first column as keys and the
    second column as values.
    Values are converted to python objects with eval() when possible, otherwise kept as strings; empty values
    become None. Rows with fewer than two cells (blank lines, dict-name rows written by write_dict) are skipped.
    :param filename: filename of the csv file
    :type filename: str
    :return: converted dict
    :rtype: dict
    :raises Exception: if filename is not a .csv file or does not exist
    """
    if not filename.endswith('.csv'):
        raise Exception("ERROR: No csv file passed. Aborting...")
    if not osp.exists(filename):
        raise Exception("ERROR: File does not exist. Aborting...")

    mydict = {}
    with open(filename, mode='r', newline='') as infile:
        for row in csv.reader(infile):
            # single-cell rows (such as the names written by write_dict) used to raise IndexError on row[1]
            if len(row) < 2:
                continue
            key, value = row[0], row[1]
            if value == '':
                mydict[key] = None
                continue
            try:
                # SECURITY NOTE(review): eval() executes arbitrary code; only load trusted files.
                mydict[key] = eval(value)  # if needs conversion
            except Exception:  # not a valid python expression: keep it as a string
                mydict[key] = value
    return mydict

amichaut's avatar
amichaut committed
343

344
345
def make_grid(image_size, x_num=None, y_num=None, cell_size=None, scaled=False, lengthscale=1., origin=None,
              plot_grid=False, save_plot_fn=None):
    """
    Make a regular grid using numpy.meshgrid() over an image.
    Two grids are returned: node_grid with the vertices of the cells, and center_grid with the centers of the cells.
    The mesh size is defined either by the number of cells along one dimension (x_num or y_num), or by the size of a
    cell (in px, or in scaled unit if scaled=True). Cells are square.
    If several definitions are passed, only one is kept, with priority x_num > y_num > cell_size, so that there is
    no conflict between definitions.
    If the grid does not cover the whole image, the leftover margin is distributed according to origin: centered
    (origin None or 'center') or tethered to one of 8 positions ('left-bottom', 'center-bottom', 'right-bottom',
    'right-center', 'right-top', 'center-top', 'left-top', 'left-center'). An unknown origin prints a warning and
    behaves like 'left-bottom'.
    :param image_size: image size (width,height) in px
    :type image_size: list or tuple
    :param x_num: number of cells along the x dimension
    :type x_num: int
    :param y_num: number of cells along the y dimension
    :type y_num: int
    :param cell_size: size of cell (in pixel or unit if scaled is True)
    :type cell_size: int or float
    :param scaled: cell size is scaled if True
    :type scaled: bool
    :param lengthscale: pixel size
    :type lengthscale: float
    :param origin: anchorage of the grid
    :type origin: str
    :param plot_grid: to plot a representation of the generated grid
    :type plot_grid: bool
    :param save_plot_fn: filename for plot (the figure is saved then closed if given)
    :type save_plot_fn: str
    :return: node_grid and center_grid, each a list [X, Y] of coordinate arrays as returned by numpy.meshgrid();
        center_grid has one row and one column fewer than node_grid
    :rtype: tuple
    :raises Exception: if no mesh size definition is passed, if x_num or y_num is lower than 1, or if the cell size
        is not strictly positive or is larger than the image
    """
    width, height = image_size

    if x_num is None and y_num is None and cell_size is None:
        raise Exception("ERROR: cannot generate grids with no information. Aborting...")

    # compute the cell size (in px) from the definition with the highest priority (x_num, then y_num, then cell_size)
    if x_num is not None:
        x_num = int(x_num)
        if x_num < 1:
            raise Exception("ERROR: x_num needs to be at least 1. Aborting...")
        # x_num + 1 steps in the width: x_num cells plus one cell worth of margin, so x_num is the number of cells
        cell_size_ = float(width) / (x_num + 1)
    elif y_num is not None:
        y_num = int(y_num)
        if y_num < 1:
            raise Exception("ERROR: y_num needs to be at least 1. Aborting...")
        cell_size_ = float(height) / (y_num + 1)  # so y_num is the number of cells along the dimension
    else:
        cell_size_ = cell_size if not scaled else cell_size / lengthscale
        if cell_size_ <= 0:
            raise Exception("ERROR: cell size needs to be strictly positive. Aborting...")
        if cell_size_ > width or cell_size_ > height:
            raise Exception("ERROR: cell size larger than image size. Aborting...")

    def _node_positions(length):
        # Nodes at 0, cell_size_, 2 * cell_size_, ... strictly below length. The node count is computed explicitly
        # (with a small tolerance) because np.arange() with a float step can yield a node at (length - rounding
        # error), e.g. 49 * (1 / 49) < 1, which added a spurious cell.
        n_nodes = int(np.ceil(length / cell_size_ - 1e-9))
        return np.arange(n_nodes) * cell_size_

    x_array = _node_positions(width)
    y_array = _node_positions(height)
    x_edge = width - x_array.max()
    y_edge = height - y_array.max()

    # fraction of the leftover margin added along (x, y) for each anchorage
    margin_fractions = {'center': (0.5, 0.5), 'left-bottom': (0, 0), 'center-bottom': (0.5, 0),
                        'right-bottom': (1, 0), 'right-center': (1, 0.5), 'right-top': (1, 1),
                        'center-top': (0.5, 1), 'left-top': (0, 1), 'left-center': (0, 0.5)}
    if origin is None:
        origin = 'center'
    if origin not in margin_fractions:
        print("Warning: unknown origin '{}'. Using 'left-bottom'.".format(origin))
        origin = 'left-bottom'
    x_fraction, y_fraction = margin_fractions[origin]
    x_array = x_array + x_edge * x_fraction
    y_array = y_array + y_edge * y_fraction

    # the last node along each axis closes the last cell, so it has no cell center
    x_center = x_array + cell_size_ / 2
    y_center = y_array + cell_size_ / 2

    node_grid = np.meshgrid(x_array, y_array)
    center_grid = np.meshgrid(x_center[:-1], y_center[:-1])

    if plot_grid:
        X, Y = node_grid
        x, y = center_grid
        fig, ax = plt.subplots(1, 1, figsize=plot_param['figsize'])
        ax.set_aspect('equal')
        for i in range(len(x_array)):
            ax.plot([X[0, i], X[-1, i]], [Y[0, i], Y[-1, i]], color_list[0])  # plot vertical lines
        for i in range(len(y_array)):
            ax.plot([X[i, 0], X[i, -1]], [Y[i, 0], Y[i, -1]], color_list[0])  # plot horizontal lines
        ax.scatter(x, y, color=color_list[1])  # plot center of cells as dot
        ax.set_xlim(0, width)
        ax.set_ylim(0, height)
        if save_plot_fn is not None:
            fig.tight_layout()
            fig.savefig(save_plot_fn, dpi=plot_param['dpi'])
            plt.close(fig)

    return node_grid, center_grid


def pool_datasets(df_list, name_list):
    """
    Concatenate several DataFrames, adding a 'dataset' column identifying the dataset each row comes from.
    None entries of df_list are skipped (their name is ignored). The input DataFrames are not modified.
    :param df_list: list of DataFrames (or None)
    :type df_list: list
    :param name_list: names identifying each DataFrame, indexed like df_list
    :type name_list: list
    :return: DataFrame concatenating the DataFrames (original indexes kept), or an empty DataFrame if none
    :rtype: pandas.DataFrame
    """
    tagged = []
    for i, df in enumerate(df_list):
        if df is None:
            continue
        # work on a copy: the previous version added the 'dataset' column to the caller's DataFrame
        tagged_df = df.copy()
        tagged_df['dataset'] = name_list[i]
        tagged.append(tagged_df)

    if not tagged:
        return pd.DataFrame()
    # a single concat is linear, whereas growing the result inside the loop copied it at every step
    return pd.concat(tagged)

amichaut's avatar
amichaut committed
495

amichaut's avatar
amichaut committed
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
def group_consecutives(vals, step=1):
    """
    Split a sequence of integers into runs where each value equals the previous one plus step.
    An empty input gives [[]].
    :param vals: input list to be grouped
    :type vals: list
    :param step: expected gap between consecutive integers
    :type step: int
    :return: list of runs (each run being a list of consecutive integers)
    :rtype: list
    """
    groups = [[]]
    expected = None
    for value in vals:
        # start a new run as soon as the sequence breaks (never before the first value)
        if expected is not None and value != expected:
            groups.append([])
        groups[-1].append(value)
        expected = value + step
    return groups


amichaut's avatar
amichaut committed
519
def get_info(data_dir):
    """
    Get info about data given by an info.txt file in the data directory.
    Each info is written on its own line following the template: keyword : value
    Mandatory info: lengthscale (in um/px by default) and timescale (frame interval, in min by default).
    Alternative units can be given by length_unit and time_unit.
    Other optional info:
    - table_unit: unit used in data table
    - separator: separator used in data table
    - image_width (in px)
    - image_height (in px)
    - z_step
    A line is used only if it contains exactly one ':'; it is assigned to every keyword it contains. Values are kept
    verbatim apart from the newline characters (string values thus keep their surrounding spaces). Then image_width
    and image_height are converted to int, and lengthscale, timescale and z_step to float.
    A missing mandatory info only prints an error message.
    :param data_dir: path of the data directory
    :type data_dir: str
    :return: info in a dict
    :rtype: dict
    :raises FileNotFoundError: if info.txt does not exist in data_dir
    """
    filename = osp.join(data_dir, "info.txt")
    if not osp.exists(filename):
        # FileNotFoundError subclasses Exception, so callers catching Exception keep working
        raise FileNotFoundError("ERROR: info.txt doesn't exist or is not at the main data folder")

    # list of parameters to grab
    string_list = ['length_unit', 'time_unit', 'table_unit', 'separator']
    int_list = ['image_width', 'image_height']
    float_list = ['lengthscale', 'timescale', 'z_step']

    # get parameters
    info = {}
    with open(filename) as f:
        for line in f:
            tokens = line.split(':')
            if len(tokens) != 2:
                continue
            value = tokens[1].strip('\n')
            if len(value) == 0:
                continue
            for param in string_list + int_list + float_list:
                if param in line:
                    info[param] = value

    # convert parameters
    for k in info:
        if k in int_list:
            info[k] = int(info[k])
        elif k in float_list:
            info[k] = float(info[k])

    mandatory_info = ['timescale', 'lengthscale']
    for mand_info in mandatory_info:
        if mand_info not in info:
            print("ERROR: {} is not in info.txt".format(mand_info))
    return info

amichaut's avatar
amichaut committed
570

571
572
def get_data(data_dir, df=None, refresh=False, split_traj=False, set_origin_=False, image=None, reset_dim=None,
             invert_axes=None):
    """
    Main function to import data and perform the initial processing (scaling and computing of time derivatives).
    It saves the database as a pickle object (data_base.p in data_dir).
    If the database already exists and refresh is False, it just loads it from the pickle object.
    The data are either loaded from a file (positions.csv in data_dir) or passed as pandas.DataFrame.
    Column names of df or the positions file must be: 'x','y',('z'),'frame','track'
    The coordinates origin and orientation can be reset with set_origin_ and invert_axes
    :param data_dir: path of the data directory (must contain info.txt)
    :type data_dir: str
    :param df: raw positions Dataframe; if None, positions.csv is read. Its 'frame' and coordinate
        columns are converted in place.
    :type df: pandas.DataFrame or None
    :param refresh: recompute the database even if a pickle already exists
    :type refresh: bool
    :param split_traj: solve gaps in trajectories: interpolate missing data if False (default), or split in new tracks
    :type split_traj: bool
    :param set_origin_: False (default) keeps the coordinates origin; a dict {dimension: coordinate} is used
        directly as origin by set_origin(); any other value (e.g. True) selects the origin interactively on image
    :type set_origin_: bool or dict
    :param image: image dict returned by get_image(), needed for interactive origin selection
    :type image: dict or None
    :param reset_dim: list of dimensions to reset, as set_origin_ contains coordinates along all dimensions
        (default ['x', 'y'])
    :type reset_dim: list or None
    :param invert_axes: list of dimensions to invert (change sign), default none
    :type invert_axes: list or None
    :return: dict with keys 'df', 'lengthscale', 'timescale', 'dim', 'dimensions'
    :rtype: dict
    """
    # fresh lists per call instead of shared mutable defaults
    if reset_dim is None:
        reset_dim = ['x', 'y']
    if invert_axes is None:
        invert_axes = []

    pickle_fn = osp.join(data_dir, "data_base.p")

    # load existing database
    if osp.exists(pickle_fn) and not refresh:
        # NOTE: pickle.load can execute code embedded in the file; only load caches written by this package.
        with open(pickle_fn, "rb") as f:
            return pickle.load(f)

    # get info
    info = get_info(data_dir)
    lengthscale = info["lengthscale"]
    timescale = info["timescale"]
    table_unit = info.get('table_unit', 'px')  # by default positions are in px
    z_step = info.get('z_step')
    if z_step == 0:  # a zero step is treated as "not provided"
        z_step = None

    # if no dataframe is passed, read it from positions.csv
    if df is None:
        data_file = osp.join(data_dir, 'positions.csv')
        sep = info.get("separator", ',')  # by default comma separated
        if sep == 'tab':
            sep = '\t'
        df = pd.read_csv(data_file, sep=sep)  # columns must be ['x','y','z','frame','track']

    # check data type; np.int / np.float were removed in NumPy 1.24 and only aliased the builtins
    dimensions = ['x', 'y', 'z'] if 'z' in df.columns else ['x', 'y']
    dim = len(dimensions)
    df['frame'] = df['frame'].astype(int)
    for d in dimensions:
        df[d] = df[d].astype(float)

    # rename track id and deal with gaps
    df = tca.regularize_traj(df, dimensions, split_traj)

    # scale data
    tca.scale_dim(df, dimensions=dimensions, timescale=timescale, lengthscale=lengthscale, z_step=z_step,
                  unit=table_unit, invert_axes=invert_axes)

    # compute velocities and accelerations
    tca.compute_vel_acc(df, dimensions=dimensions, timescale=timescale)

    # reset coordinates origin
    if set_origin_ is not False:
        orig_coord_ = set_origin_ if isinstance(set_origin_, dict) else None
        df, _ = set_origin(df, image, reset_dim, lengthscale, orig_coord_)

    if 'z' in dimensions:  # relative z: centered around mean
        df['z_rel'] = df['z_scaled'] - df['z_scaled'].mean()

    # update pickle
    data = {'df': df, 'lengthscale': lengthscale, 'timescale': timescale, 'dim': dim, 'dimensions': dimensions}
    with open(pickle_fn, "wb") as f:
        pickle.dump(data, f)

    return data

amichaut's avatar
amichaut committed
656
657

def get_traj(track_groups, track, min_frame=None, max_frame=None):
    """
    Return the rows of one trajectory, optionally restricted to a frame window.
    Both frame bounds are inclusive; a None bound is not applied.
    :param track_groups: output of a pandas.groupby()
    :type track_groups: pandas.DataFrameGroupBy
    :param track: track id
    :type track: int
    :param min_frame: minimal frame
    :type min_frame: int or None
    :param max_frame: maximal frame
    :type max_frame: int or None
    :return: dataframe of trajectory, with a fresh 0..n-1 index
    :rtype: pandas.DataFrame
    """
    traj = track_groups.get_group(track)
    has_min = min_frame is not None
    has_max = max_frame is not None

    if has_min and has_max:
        traj = traj[traj['frame'].between(min_frame, max_frame)]
    elif has_min:
        traj = traj[traj['frame'] >= min_frame]
    elif has_max:
        traj = traj[traj['frame'] <= max_frame]

    return traj.reset_index(drop=True)

amichaut's avatar
amichaut committed
679
680

def filter_by_traj_len(df, min_traj_len=1, max_traj_len=None):
    """
    Filter trajectories by their minimal and/or maximal length.
    The length of a trajectory is its number of rows; both bounds are inclusive.
    Kept trajectories are concatenated in order of first appearance of their track id.
    :param df: dataframe of trajectories
    :type df: pandas.DataFrame
    :param min_traj_len: minimal trajectory length (None means 1)
    :type min_traj_len: int or None
    :param max_traj_len: maximal trajectory length (None means the frame span of df)
    :type max_traj_len: int or None
    :return: filtered dataframe of trajectories (empty, with the same columns, if no trajectory qualifies)
    :rtype: pandas.DataFrame
    """
    if max_traj_len is None:  # assign the longest possible track
        max_traj_len = df['frame'].max() - df['frame'].min() + 1
    if min_traj_len is None:  # assign 1, if not given
        min_traj_len = 1

    tracks = df.groupby('track')
    df_list = []
    for t in df['track'].unique():
        track = tracks.get_group(t)
        if min_traj_len <= track.shape[0] <= max_traj_len:
            df_list.append(track)

    # pd.concat raises on an empty list: return an empty table with the same columns instead
    if not df_list:
        print("WARNING: no trajectory within the length boundaries. Returning an empty database")
        return df.iloc[0:0].reset_index(drop=True)

    return pd.concat(df_list, ignore_index=True)
amichaut's avatar
amichaut committed
706

amichaut's avatar
amichaut committed
707
708

def filter_by_frame_subset(df, frame_subset=None):
    """
    Keep only the rows whose frame lies within [minimal_frame, maximal_frame] (inclusive).
    Either boundary may be None to leave that side open. If the subset contains no row,
    a warning is printed and the unfiltered dataframe is returned.
    :param df: dataframe of trajectories
    :type df: pandas.DataFrame
    :param frame_subset: frame boundaries of subset [minimal_frame,maximal_frame]
    :type frame_subset: list or None
    :return: filtered dataframe of trajectories (original index kept)
    :rtype: pandas.DataFrame
    """
    if frame_subset is None:
        return df

    lower, upper = frame_subset[0], frame_subset[1]
    if lower is None and upper is None:
        return df

    frames = df['frame']
    if lower is None:
        selected = df[frames <= upper]
    elif upper is None:
        selected = df[frames >= lower]
    else:
        selected = df[(frames >= lower) & (frames <= upper)]

    if len(selected) == 0:
        print("WARNING: no data for this frame subset. Returning unfiltered database")
        return df
    return selected

amichaut's avatar
amichaut committed
735

736
def filter_by_region(df, xlim=None, ylim=None, zlim=None):
    """
    Extract data within a box given by xlim, ylim and zlim, compared against the 'x', 'y' and 'z' columns.
    Each limit is [min, max] (inclusive); None, or a None bound, leaves that side unconstrained.
    If the box contains no row, a warning is printed and the unfiltered dataframe is returned.
    :param df: dataframe of trajectories
    :type df: pandas.DataFrame
    :param xlim: [xmin,xmax]
    :type xlim: list or None
    :param ylim: [ymin,ymax]
    :type ylim: list or None
    :param zlim: [zmin,zmax]
    :type zlim: list or None
    :return: filtered dataframe of trajectories (a copy of df when no limit applies)
    :rtype: pandas.DataFrame
    """
    df_ = df.copy()

    for dim, lim_ in zip('xyz', (xlim, ylim, zlim)):
        if lim_ is None:
            continue
        lower, upper = lim_[0], lim_[1]
        if lower is not None:
            df_ = df_[df_[dim] >= lower]
        if upper is not None:
            df_ = df_[df_[dim] <= upper]

    if df_.shape[0] == 0:
        print("WARNING: no data in this region. Returning unfiltered database")
        return df
    return df_

amichaut's avatar
amichaut committed
770

771
def get_coordinates(image, interactive=True, verbose=True):
    """
    Interactive selection of coordinates on an image by hand-drawing using a Napari viewer.
    Selection supported: points and rectangles. Only the first Shapes layer and the first
    Points layer found when ENTER is pressed are read.
    :param image: dict returned by get_image(); its 'image_fn', 't_dim' and 'z_dim' entries are used
    :type image: dict
    :param interactive: ask for confirmation after each selection and start over if the answer is 'n'
    :type interactive: bool
    :param verbose: print drawing instructions and a summary of the selection
    :type verbose: bool
    :return: dict of list of selected shapes: {'points':[coordinates1,...],'rectangle':[coordinates1,...]}
    :rtype: dict
    """
    image_fn, t_dim, z_dim = image['image_fn'], image['t_dim'], image['z_dim']
    im = io.imread(image_fn)

    done = False
    while not done:
        # lists are filled by the ENTER callback, so they must outlive the viewer context
        shapes_layers = []
        points_layers = []
        with napari.gui_qt():
            viewer = napari.view_image(im)
            if verbose:
                print("Draw points or rectangles, then press ENTER and close the image viewer")

            # retrieve coordinates on pressing Enter
            @viewer.bind_key('Enter')
            def get_coord(viewer):
                for layer in viewer.layers:
                    if type(layer) is napari.layers.shapes.shapes.Shapes:
                        shapes_layers.append(layer)
                    if type(layer) is napari.layers.points.points.Points:
                        points_layers.append(layer.data)

        # inspect selected layers: keep rectangles only, from the first Shapes layer
        rectangle_list = []
        if shapes_layers:
            first_shapes = shapes_layers[0]
            rectangle_list = [first_shapes.data[k]
                              for k, kind in enumerate(first_shapes.shape_type) if kind == 'rectangle']
        points = points_layers[0] if points_layers else np.array([])

        # interactive validation of selection
        if verbose:
            print('You have selected {} point(s) and {} rectangle(s)'.format(points.shape[0], len(rectangle_list)))
        done = (not interactive) or input('Is the selection correct? [y]/n: ') != 'n'

    # leading non-spatial axes (time and/or z) come first; the last two columns are y then x
    offset = (t_dim is not None) + (z_dim is not None)
    y_col, x_col = offset, offset + 1

    coord_dict = {'points': [], 'rectangle': []}

    # get rectangle coordinates
    for rect in rectangle_list:
        coord_dict['rectangle'].append({
            'frame': int(rect[0, t_dim]) if t_dim is not None else None,
            'z': int(rect[0, z_dim]) if z_dim is not None else None,
            'xlim': [rect[:, x_col].min(), rect[:, x_col].max()],
            'ylim': [rect[:, y_col].min(), rect[:, y_col].max()],
        })

    # get points coordinates
    for pt in points:
        frame = int(pt[t_dim]) if t_dim is not None else None
        z = int(pt[z_dim]) if z_dim is not None else None
        coord_dict['points'].append({'frame': frame, 'x': pt[x_col], 'y': pt[y_col], 'z': z})

    return coord_dict

amichaut's avatar
amichaut committed
874

875
876
877
878
879
880
881
882
883
884
def filter_by_traj_id(df, track_list=None):
    """
    Filter by trajectory id. A single id (Python or NumPy scalar) or a collection of ids can be given.
    Trajectories are concatenated in the order of track_list. Ids absent from df are skipped with a warning.
    :param df: dataframe of trajectories
    :type df: pandas.DataFrame
    :param track_list: trajectory id(s)
    :type track_list: list or numpy.ndarray or int or float or None
    :return: filtered dataframe of trajectories (empty, with the same columns, if no id matches)
    :rtype: pandas.DataFrame
    """
    if track_list is None:
        return df
    if np.isscalar(track_list):  # also catches numpy scalars, e.g. np.int64 from unique()
        track_list = [track_list]

    tracks = df.groupby('track')
    available = tracks.groups

    df_list = []
    missing = []
    for track in track_list:
        if track in available:
            df_list.append(tracks.get_group(track))
        else:
            missing.append(track)

    if missing:
        print("WARNING: track(s) {} not found in the database and ignored".format(missing))
    # pd.concat raises on an empty list (e.g. empty selection from select_traj_by_xyzt)
    if not df_list:
        return df.iloc[0:0].reset_index(drop=True)
    return pd.concat(df_list, ignore_index=True)
amichaut's avatar
amichaut committed
898
899


900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
def select_traj_by_xyzt(df, xlim=None, ylim=None, zlim=None, frame_lim=None):
    """
    Get ids of trajectories going through an xyzt box. The spatiotemporal box is defined by its boundaries.
    Boundaries are inclusive; a None boundary (or a None bound inside it) is not applied.
    An empty box yields an empty selection (the box is applied directly, without the
    "return unfiltered data" fallback of filter_by_frame_subset / filter_by_region).
    :param df: dataframe of trajectories
    :type df: pandas.DataFrame
    :param xlim: x boundaries [xmin, xmax]
    :type xlim: list or None
    :param ylim: y boundaries [ymin, ymax]
    :type ylim: list or None
    :param zlim: z boundaries [zmin, zmax]
    :type zlim: list or None
    :param frame_lim: frame boundaries [min_frame, max_frame] or unique frame
    :type frame_lim: int or float or list or tuple or None
    :return: array of unique trajectory ids
    :rtype: numpy.ndarray
    """
    keep = np.ones(len(df), dtype=bool)
    bounds = []

    # temporal restriction: a [min, max] range or a single frame
    if isinstance(frame_lim, (list, tuple)):
        bounds.append(('frame', frame_lim))
    elif isinstance(frame_lim, (int, float, np.integer, np.floating)) and not isinstance(frame_lim, bool):
        keep &= (df['frame'] == frame_lim).to_numpy()

    # spatial restriction, one [min, max] range per dimension
    for col, lim in (('x', xlim), ('y', ylim), ('z', zlim)):
        if lim is not None:
            bounds.append((col, lim))

    for col, lim in bounds:
        lower, upper = lim[0], lim[1]
        if lower is not None:
            keep &= (df[col] >= lower).to_numpy()
        if upper is not None:
            keep &= (df[col] <= upper).to_numpy()

    # get ids
    return df['track'][keep].unique()
amichaut's avatar
amichaut committed
928

amichaut's avatar
amichaut committed
929
930

def set_origin(df, image=None, reset_dim=None, lengthscale=1., orig_coord=None):
    """
    Set the origin of coordinates, either from a point selected through a viewer or from given coordinates.
    The '<dimension>_scaled' columns of df are shifted in place by the origin (multiplied by lengthscale).
    Only some dimensions can be reset by reset_dim, the other are left unchanged.
    If no image is provided, the origin coordinates can be manually passed by orig_coord.
    :param df: dataframe of trajectories with '<dimension>_scaled' columns
    :type df: pandas.DataFrame
    :param image: image dict returned by get_image(), used for interactive selection when orig_coord is None
    :type image: dict or None
    :param reset_dim: dimensions to reset with the selected point (default ['x', 'y']); ignored when
        orig_coord is given, whose keys then define the dimensions to reset
    :type reset_dim: list or None
    :param lengthscale: factor applied to the origin coordinates before shifting
    :type lengthscale: float
    :param orig_coord: {dimension: coordinate} of the origin; a None coordinate leaves that dimension unchanged
    :type orig_coord: dict or None
    :return: (df, origin) with origin = {dimension: scaled coordinate or None}
    :rtype: tuple
    :raises Exception: if neither image nor orig_coord is provided, or if the selection is not exactly one point
    """
    if reset_dim is None:
        reset_dim = ['x', 'y']

    if orig_coord is not None:
        origin = {d: (None if c is None else c * lengthscale) for d, c in orig_coord.items()}
    elif image is not None:
        # draw origin on image
        selection = get_coordinates(image)
        if len(selection['points']) != 1:
            raise Exception("ERROR: you need to select exactly one point to set the origin. Aborting...")
        point = selection['points'][0]
        origin = {}
        for d in reset_dim:
            coord = point[d]
            origin[d] = None if coord is None else coord * lengthscale  # scale coordinate
    else:
        raise Exception("ERROR: no image nor origin coordinates provided. Aborting...")

    # iterate the dict (not reset_dim) so a duplicated dimension is shifted only once
    for d, value in origin.items():
        if value is not None:
            df[d + '_scaled'] = df[d + '_scaled'] - value

    return df, origin
amichaut's avatar
amichaut committed
960
961


962
963
964
965
966
967
968
969
970
971
972
def select_sub_data(df, filters=[]):
    """
    Select subsets of data according a list of filters. Each element of the list will lead to a subset.
    Each subset is filtered by a dict:
    {'xlim','ylim','zlim','frame_subset','min_traj_len','max_traj_len','track_list','name'}
    :param df: dataframe of trajectories
    :type df: pandas.DataFrame
    :param filters: list of filters or single set of filters
    :type filters: list or dict or None
    :return: filtered dataframe of trajectories
    :rtype: pandas.DataFrame
amichaut's avatar
amichaut committed
973
974
    """

975
976
977
978
979
    # if no filter return input
    if filters is None:
        return df
    elif len(filters) == 0:
        return df
amichaut's avatar
amichaut committed
980

981
982
    # if only one set of filters
    if type(filters) is dict:  
amichaut's avatar
amichaut committed
983
        filters = [filters]
amichaut's avatar
amichaut committed
984

985
986
    # perform filtering
    df_list = []  # temp list of df 
amichaut's avatar
amichaut committed
987
    for filt in filters:
988
989
        # filter
        df_ = filter_by_region(df, xlim=filt['xlim'], ylim=filt['ylim'], zlim=filt['zlim'])
amichaut's avatar
amichaut committed
990
991
        df_ = filter_by_frame_subset(df_, frame_subset=filt['frame_subset'])
        df_ = filter_by_traj_len(df_, min_traj_len=filt['min_traj_len'], max_traj_len=filt['max_traj_len'])
992
993
994
995
996
997
998
999
        df_ = filter_by_traj_id(df_, filt['track_list'])
        if filt['track_ROI'] is not None:
            track_list = select_traj_by_xyzt(df_,xlim=filt['track_ROI']['xlim'], ylim=filt['track_ROI']['ylim'], 
                                            zlim=filt['track_ROI']['zlim'], frame_lim=filt['track_ROI']['frame_lim'])
            df_ = filter_by_traj_id(df_, track_list)

        # subset name
        df_['subset'] = filt['name']
amichaut's avatar
amichaut committed
1000
        df_list.append(df_)
For faster browsing, not all history is shown. View entire blame