Commit df134f7e authored by amichaut's avatar amichaut
Browse files

added possibility to plot custom variable to trajectory module

parent 24a0d3aa
This diff is collapsed.
This diff is collapsed.
......@@ -166,7 +166,7 @@ def recompute_pos(df, dimensions=['x', 'y', 'z'], lengthscale=1.):
df[dim] = df[dim + '_scaled'] * lengthscale
def compute_track_prop(df, dimensions=['x', 'y', 'z']):
def compute_track_prop(df, dimensions=['x', 'y', 'z'],custom_var={}):
"""Compute track properties by averaging all data along a traj"""
groups = df.groupby('track')
......@@ -175,7 +175,7 @@ def compute_track_prop(df, dimensions=['x', 'y', 'z']):
scaled_dimensions = [dim + '_scaled' for dim in dimensions]
vel = ['v' + dim for dim in dimensions]
acc = ['a' + dim for dim in dimensions]
mean_col = ['t'] + dimensions + scaled_dimensions + vel + acc + ['v', 'a']
mean_col = ['t'] + dimensions + scaled_dimensions + vel + acc + ['v', 'a'] + list(custom_var.keys())
# all columns
columns = ['track', 'track_length'] + mean_col
......@@ -183,6 +183,10 @@ def compute_track_prop(df, dimensions=['x', 'y', 'z']):
columns.append('subset')
df_out = pd.DataFrame(columns=columns)
#check custom_var are numeric
for var in custom_var.keys():
df[var] = pd.to_numeric(df[var],errors='coerce')
# average
for i, track in enumerate(df['track'].unique()):
traj = groups.get_group(track)
......
......@@ -645,7 +645,7 @@ def plot_MSD(data_dir, track, track_groups=None, df=None, df_out=None, fit_model
def plot_param_vs_param(data_dir, x_param, y_param, df=None, hue=None, hue_order=None, set_axis_lim=None,
plot_config=None, x_bin_num=None, ci=None, fit_reg=False, scatter=True,
plot_dir=None, prefix='', suffix=''):
plot_dir=None, prefix='', suffix='', custom_var={}):
"""Plot a parameter of df (y_param) against another parameter (x_param). Optional: compare datasets with hue as datasets identifier."""
plot_config = make_plot_config() if plot_config is None else plot_config
......@@ -672,12 +672,18 @@ def plot_param_vs_param(data_dir, x_param, y_param, df=None, hue=None, hue_order
# make labels
info = tpr.get_info(data_dir)
x_lab = tpr.make_param_label(x_param, l_unit=info['length_unit'], t_unit=info['time_unit'])
y_lab = tpr.make_param_label(y_param, l_unit=info['length_unit'], t_unit=info['time_unit'])
if x_param in custom_var.keys():
x_lab = tpr.make_param_label(None, manual_symbol=custom_var[x_param]['name'], manual_unit=custom_var[x_param]['unit'])
else:
x_lab = tpr.make_param_label(x_param, l_unit=info['length_unit'], t_unit=info['time_unit'])
if y_param in custom_var.keys():
y_lab = tpr.make_param_label(None, manual_symbol=custom_var[y_param]['name'], manual_unit=custom_var[y_param]['unit'])
else:
y_lab = tpr.make_param_label(y_param, l_unit=info['length_unit'], t_unit=info['time_unit'])
# make sure data is float and finite
for p in [x_param,y_param]:
df[p] = df[p].astype(np.float)
df[p] = pd.to_numeric(df[p],errors='coerce')
df = df[np.isfinite(df[p])]
# make sure that sns.lmplot does not use the continuous colormap
......@@ -728,8 +734,7 @@ def plot_param_vs_param(data_dir, x_param, y_param, df=None, hue=None, hue_order
def plot_param_hist(data_dir, param, df=None, hue=None, hue_order=None, hist=True, kde=True,
plot_config=None,
plot_dir=None, prefix='', suffix=''):
plot_config=None, plot_dir=None, prefix='', suffix='', custom_var={}):
"""Plot a parameter histogram. Optional: compare datasets with hue as datasets identifier."""
plot_config = make_plot_config() if plot_config is None else plot_config
......@@ -753,10 +758,13 @@ def plot_param_hist(data_dir, param, df=None, hue=None, hue_order=None, hist=Tru
# make label
info = tpr.get_info(data_dir)
param_label = tpr.make_param_label(param, l_unit=info['length_unit'], t_unit=info['time_unit'])
if param in custom_var.keys():
param_label = tpr.make_param_label(None, manual_symbol=custom_var[param]['name'], manual_unit=custom_var[param]['unit'])
else:
param_label = tpr.make_param_label(param, l_unit=info['length_unit'], t_unit=info['time_unit'])
# make sure data is float and finite
df[param] = df[param].astype(np.float)
df[param] = pd.to_numeric(df[param],errors='coerce')
df = df[np.isfinite(df[param])]
kind = "hist" if hist else "kde"
......
......@@ -172,7 +172,7 @@ def make_unit_label(dimension='L', l_unit='um', t_unit='min'):
return label
def make_param_label(param, l_unit='um', t_unit='min', time_der=None, mean=False, only_symbol=False, only_unit=False):
def make_param_label(param, l_unit='um', t_unit='min', time_der=None, mean=False, only_symbol=False, only_unit=False, manual_symbol=None, manual_unit=None):
"""
Make a Latex formatted label of a parameter.
The first and second time derivative and the mean symbol can be used.
......@@ -194,6 +194,10 @@ def make_param_label(param, l_unit='um', t_unit='min', time_der=None, mean=False
:type only_symbol: bool
:param only_unit: return only the parameter unit
:type only_unit: bool
:param manual_symbol: custom symbol to be used with param=None
:type manual_symbol: str or None
:param manual_unit: custom unit to be used with param=None
:type manual_unit: str or None
:return: parameter label: symbol and/or unit
:rtype: str
"""
......@@ -221,61 +225,90 @@ def make_param_label(param, l_unit='um', t_unit='min', time_der=None, mean=False
param_dict['track'] = {'sym': 'track id', 'dim': 'none', 'l_unit': l_unit, 't_unit': t_unit, 'latex': False}
param_dict['area'] = {'sym': 'area', 'dim': 'LL', 'l_unit': l_unit, 't_unit': t_unit, 'latex': True}
# check if mean
if param.endswith('_mean'):
param = param[:param.find('_mean')]
mean = True
# convert param if it ends by _dot or _ddot
raw_param, time_der_ = get_param_time_der(param)
if time_der_ is not None:
time_der = time_der_
param = raw_param
# make symbol adding derivative or mean symbols
symbol = param_dict[param]['sym']
dim_der = ''
if time_der == 'first':
symbol = r'\dot{' + symbol + r'}'
dim_der = '/T' # change dimension
elif time_der == 'sec':
symbol = r'\ddot{' + symbol + r'}'
dim_der = '/TT' # change dimension
if mean:
symbol = r'\langle ' + symbol + r' \rangle'
# make unit
unit_ = make_unit_label(param_dict[param]['dim'] + dim_der, l_unit=param_dict[param]['l_unit'],
t_unit=param_dict[param]['t_unit'])
# output
# output format
if only_symbol and only_unit:
print("Warning: only_symbol and only_unit can't be both True. Making them both False.")
only_symbol = False
only_unit = False
label = ''
# add latex dollar symbols if needed
if param_dict[param]['latex']:
if only_symbol:
label = r'$' + symbol + r'$'
elif only_unit:
label = r'$' + unit_ + r'$'
# output label
label = ''
if param is not None:
# check if mean
if param.endswith('_mean'):
param = param[:param.find('_mean')]
mean = True
# convert param if it ends by _dot or _ddot
raw_param, time_der_ = get_param_time_der(param)
if time_der_ is not None:
time_der = time_der_
param = raw_param
# make symbol adding derivative or mean symbols
symbol = param_dict[param]['sym']
dim_der = ''
if time_der == 'first':
symbol = r'\dot{' + symbol + r'}'
dim_der = '/T' # change dimension
elif time_der == 'sec':
symbol = r'\ddot{' + symbol + r'}'
dim_der = '/TT' # change dimension
if mean:
symbol = r'\langle ' + symbol + r' \rangle'
# make unit
unit_ = make_unit_label(param_dict[param]['dim'] + dim_der, l_unit=param_dict[param]['l_unit'],
t_unit=param_dict[param]['t_unit'])
# add latex dollar symbols if needed
if param_dict[param]['latex']:
if only_symbol:
label = r'$' + symbol + r'$'
elif only_unit:
label = r'$' + unit_ + r'$'
else:
if len(unit_) > 0: # if unit exists
label = r'$' + symbol + r'$ ($' + unit_ + r'$)'
else:
label = r'$' + symbol + r'$'
else:
if len(unit_) > 0: # if unit exists
label = r'$' + symbol + r'$ ($' + unit_ + r'$)'
if only_symbol:
label = symbol
elif only_unit:
label = unit_
else:
label = r'$' + symbol + r'$'
else:
if only_symbol:
label = symbol
elif only_unit:
label = unit_
if len(unit_) > 0: # if unit exists
label = symbol + ' (' + unit_ + ')'
else:
label = symbol
else:
no_symbol = False
no_unit = False
if manual_unit is None:
no_unit = True
else:
if len(unit_) > 0: # if unit exists
label = symbol + ' (' + unit_ + ')'
if len(manual_unit) == 0:
no_unit = True
if manual_symbol is None:
no_symbol = True
else:
if len(manual_symbol) == 0:
no_symbol = True
if only_symbol and not no_symbol:
label = r'' + manual_symbol
elif only_unit and not no_unit:
label = r'' + manual_unit
elif not only_symbol and not only_unit:
if no_symbol and not no_unit:
label = r'' + manual_unit
elif not no_symbol and no_unit:
label = r'' + manual_symbol
elif not no_symbol and not no_unit:
label = r'' + manual_symbol + r' (' + manual_unit + r')'
else:
label = symbol
label = ''
return label
......@@ -592,7 +625,7 @@ def get_info(data_dir):
def get_data(data_dir, df=None, refresh=False, split_traj=False, set_origin_=False, image=None, reset_dim=['x', 'y'],
invert_axes=[]):
invert_axes=[], custom_var={}):
"""
Main function to import data and perform the initial processing (scaling and computing of time derivatives).
It saves the database as a pickle object.
......@@ -616,6 +649,8 @@ def get_data(data_dir, df=None, refresh=False, split_traj=False, set_origin_=Fal
:type reset_dim: list
:param invert_axes: list of dimensions to invert (change sign)
:type invert_axes: list
:param custom_var: dict of custom variables: {'var_i':{'name':'var_name','unit':'var_unit'}}
:type custom_var: dict
:return: dict with dataframe and key info
:rtype: dict
"""
......@@ -641,7 +676,7 @@ def get_data(data_dir, df=None, refresh=False, split_traj=False, set_origin_=Fal
df = pd.read_csv(data_file, sep=sep) # columns must be ['x','y','z','frame','track']
# check data type
df = df.apply(pd.to_numeric, errors='coerce')
df = df.apply(pd.to_numeric, errors='ignore')
dimensions = ['x', 'y', 'z'] if 'z' in df.columns else ['x', 'y']
dim = len(dimensions)
df['frame'] = df['frame'].astype(np.int)
......@@ -670,14 +705,14 @@ def get_data(data_dir, df=None, refresh=False, split_traj=False, set_origin_=Fal
df['z_rel'] = df['z_scaled'] - df['z_scaled'].mean()
# update pickle
data = {'df': df, 'lengthscale': lengthscale, 'timescale': timescale, 'dim': dim, 'dimensions': dimensions}
data = {'df': df, 'lengthscale': lengthscale, 'timescale': timescale, 'dim': dim, 'dimensions': dimensions, 'custom_var':custom_var}
pickle.dump(data, open(pickle_fn, "wb"))
else:
data = pickle.load(open(pickle_fn, "rb"))
# extra check that df is numerci (in case of old pickle)
df = data['df']
df = df.apply(pd.to_numeric, errors='coerce')
df = df.apply(pd.to_numeric, errors='ignore')
data['df'] = df
return data
......
......@@ -58,6 +58,7 @@ def traj_analysis(data_dir, data=None, image=None, refresh=False, parallelize=Fa
df = data['df']
dim = data['dim']
dimensions = data['dimensions']
custom_var = data['custom_var']
# Get plot_config
plot_config = tpl.make_plot_config(data_dir=data_dir, export_config=False) if plot_config is None else plot_config
......@@ -144,7 +145,7 @@ def traj_analysis(data_dir, data=None, image=None, refresh=False, parallelize=Fa
# compute mean track properties
mean_fn = osp.join(sub_dir, 'track_prop.csv')
df_prop = tca.compute_track_prop(df, dimensions)
df_prop = tca.compute_track_prop(df, dimensions,custom_var)
df_prop.to_csv(mean_fn)
# analysis and plotting
......@@ -210,7 +211,7 @@ def traj_analysis(data_dir, data=None, image=None, refresh=False, parallelize=Fa
print("Plotting parameters histograms...")
for p in hist_config['var_list']:
tpl.plot_param_hist(data_dir, p, df, plot_config=plot_config, plot_dir=sub_dir, hue=hue,
hue_order=hue_order)
hue_order=hue_order,custom_var=custom_var)
if 'mean_var_list' in hist_config.keys():
if len(hist_config['mean_var_list']) > 0:
......@@ -218,7 +219,7 @@ def traj_analysis(data_dir, data=None, image=None, refresh=False, parallelize=Fa
print("Plotting whole-track histograms...")
for p in hist_config['mean_var_list']:
tpl.plot_param_hist(data_dir, p, df_prop, plot_config=plot_config, plot_dir=sub_dir, prefix='track_',
hue=hue, hue_order=hue_order)
hue=hue, hue_order=hue_order,custom_var=custom_var)
if scatter_config['run']:
fn = osp.join(config_dir, 'scatter_config.csv')
......@@ -232,7 +233,7 @@ def traj_analysis(data_dir, data=None, image=None, refresh=False, parallelize=Fa
x_param, y_param = param_vs_param
tpl.plot_param_vs_param(data_dir, x_param, y_param, df, plot_dir=sub_dir, plot_config=plot_config,
hue=hue, hue_order=hue_order, x_bin_num=scatter_config["x_bin_num"], ci=scatter_config["ci"],
fit_reg=scatter_config["fit_reg"], scatter=scatter_config["scatter"])
fit_reg=scatter_config["fit_reg"], scatter=scatter_config["scatter"],custom_var=custom_var)
if 'mean_couple_list' in scatter_config.keys():
if len(scatter_config['mean_couple_list']) > 0:
......@@ -242,7 +243,7 @@ def traj_analysis(data_dir, data=None, image=None, refresh=False, parallelize=Fa
x_param, y_param = param_vs_param
tpl.plot_param_vs_param(data_dir, x_param, y_param, df_prop, plot_dir=sub_dir, plot_config=plot_config,
prefix='track_', hue=hue, hue_order=hue_order, x_bin_num=scatter_config["x_bin_num"], ci=scatter_config["ci"],
fit_reg=scatter_config["fit_reg"], scatter=scatter_config["scatter"])
fit_reg=scatter_config["fit_reg"], scatter=scatter_config["scatter"],custom_var=custom_var)
return df_list
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment