diff --git a/setup.cfg b/setup.cfg index 244e9350604ad2dbb206b338feeab46b58810249..6c188968ad16f8df4d917ca1298185136cc34fcf 100755 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = napari-3dtimereg -version = 0.0.6 +version = 0.0.7 description = Registration of 3D movies applied to all channels long_description = file: README.md long_description_content_type = text/markdown diff --git a/src/napari_3dtimereg/movieRegistration.py b/src/napari_3dtimereg/movieRegistration.py index 231ed0cd07535efea77e029b63d7d2844542e506..3fecb151d0fdcb0b23a3d6e3f4f9c8010ba36662 100755 --- a/src/napari_3dtimereg/movieRegistration.py +++ b/src/napari_3dtimereg/movieRegistration.py @@ -37,21 +37,30 @@ def get_filename(): return None def start(): - global viewer, aligndir, imagedir, ptsdir + global viewer + viewer = napari.current_viewer() + viewer.title = "3dTimeReg" + filename = get_filename() + if filename is None: + print("No file selected") + return + open_file( filename ) + return getChanels() + +def start_noshow(): + global viewer + viewer = napari.Viewer( show=False ) + +def open_file( filename, show_images=True ): global refimg, refchanel - global resimg global imagename + global resimg global scaleXY, scaleZ + global aligndir, imagedir, ptsdir global colchan, dim global refpts refpts = None refchanel = 0 - viewer = napari.current_viewer() - viewer.title = "3dTimeReg" - filename = get_filename() - if filename is None: - print("No file selected") - return refimg, scaleXY, scaleZ, names = ut.open_image(filename, verbose=True) global pixel_spacing #pixel_spacing = [scaleXY, scaleXY] @@ -74,18 +83,18 @@ def start(): imagename, imagedir, aligndir = ut.extract_names( filename, subname="aligned" ) ptsdir = os.path.join(imagedir, "regpoints") update_save_history(imagedir) - for chan in range(refimg.shape[colchan]): - cmap = ut.colormapname(chan) - if dim == 3: - cview = viewer.add_image( refimg[:,:,chan,:,:], name="Movie_"+"C"+str(chan), blending="additive", colormap = cmap ) - quants = tuple( np.quantile( refimg[:,:,chan,:,:], [0.01, 0.9999]) ) - else: - cview = viewer.add_image( refimg[:,chan,:,:], name="Movie_"+"C"+str(chan), blending="additive", colormap = cmap ) - quants = tuple( np.quantile( refimg[:,chan,:,:], [0.01, 0.9999]) ) - cview.contrast_limits = quants - cview.gamma = 0.95 + if show_images: + for chan in range(refimg.shape[colchan]): + cmap = ut.colormapname(chan) + if dim == 3: + cview = viewer.add_image( refimg[:,:,chan,:,:], name="Movie_"+"C"+str(chan), blending="additive", colormap = cmap ) + quants = tuple( np.quantile( refimg[:,:,chan,:,:], [0.01, 0.9999]) ) + else: + cview = viewer.add_image( refimg[:,chan,:,:], name="Movie_"+"C"+str(chan), blending="additive", colormap = cmap ) + quants = tuple( np.quantile( refimg[:,chan,:,:], [0.01, 0.9999]) ) + cview.contrast_limits = quants + cview.gamma = 0.95 - return getChanels() def show_help_chanel(): """ Open the gitlab page with the documentation """ @@ -119,12 +128,7 @@ def getChanels(): viewer.layers.remove(layname) else: viewer.layers.remove(layname) ## tmp - # img = (refimg[:,:,chan,:,:]) - # if layname in viewer.layers: - # viewer.layers.remove(layname) - # viewer.add_image( img, name=layname, blending="additive", colormap = "red" ) - # else: - # viewer.layers[layname].colormap = "red" + if "Do registration" not in viewer.window._dock_widgets: if dim == 2: resimg = np.copy(refimg[:,refchanel,:,:]) @@ -141,6 +145,22 @@ def getChanels(): wid = viewer.window.add_dock_widget(get_chanel, name="Choose chanel") return wid +def getChanels_noshow( reference_chanel=0 ): + """ Do the chanel step without interface """ + global refchanel + global resimg + global colchan + refchanel = reference_chanel + + if dim == 2: + resimg = np.copy(refimg[:,refchanel,:,:]) + else: + resimg = np.copy(refimg[:,:,refchanel,:,:]) + resimg[0] = resimg[0] - np.min(resimg[0]) + + iterative_registration() + + def itk_to_layer(img, name, color): lay = layer_from_image(img) lay.blending = "additive" @@ -156,118 +176,115 @@ def img_to_itk(img): image_itk = image_itk.astype(itk.F) return image_itk -def iterative_registration(): - """ use Elastix to perform registration with possible deformation, iteratively in time """ - - def rigid_map(): - """ Set-up rigid (affine) transformation parameters """ - - parameter_object = itk.ParameterObject.New() - parameter_map_rigid = parameter_object.GetDefaultParameterMap('rigid') - parameter_map_rigid['MaximumNumberOfIterations'] = [str(get_paras.iterations.value)] - parameter_map_rigid['MaximumStepLength'] = ['2.0'] - parameter_map_rigid["NumberOfResolutions"] = [str(get_paras.rigid_resolution.value)] - parameter_map_rigid['NumberOfSpatialSamples'] = ['10000'] - parameter_map_rigid['MaximumNumberOfSamplingAttempts'] = ['10'] - parameter_map_rigid['RequiredRatioOfValidSamples'] = ['0.05'] - parameter_map_rigid['CheckNumberOfSamples'] = ['false'] - final = int(get_paras.rigid_final_spacing.value) - parameter_map_rigid['FinalGridSpacingInPhysicalUnits'] = [str(final)] - parameter_map_rigid['Registration'] = ['MultiMetricMultiResolutionRegistration'] - parameter_map_rigid["AutomaticTransformInitialization"] = ['true'] - parameter_map_rigid["AutomaticTransformInitializationMethod"] = ['CenterOfGravity'] +def rigid_map( iterations, rig_resolution, rig_final_spacing, use_points=True ): + """ Set-up rigid (affine) transformation parameters """ + parameter_object = itk.ParameterObject.New() + parameter_map_rigid = parameter_object.GetDefaultParameterMap('rigid') + parameter_map_rigid['MaximumNumberOfIterations'] = [iterations] + parameter_map_rigid['MaximumStepLength'] = ['2.0'] + parameter_map_rigid["NumberOfResolutions"] = [rig_resolution] + parameter_map_rigid['NumberOfSpatialSamples'] = ['10000'] + parameter_map_rigid['MaximumNumberOfSamplingAttempts'] = ['10'] + parameter_map_rigid['RequiredRatioOfValidSamples'] = ['0.05'] + parameter_map_rigid['CheckNumberOfSamples'] = ['false'] + parameter_map_rigid['FinalGridSpacingInPhysicalUnits'] = [str(rig_final_spacing)] + parameter_map_rigid['Registration'] = ['MultiMetricMultiResolutionRegistration'] + parameter_map_rigid["AutomaticTransformInitialization"] = ['true'] + parameter_map_rigid["AutomaticTransformInitializationMethod"] = ['CenterOfGravity'] - original_metric = parameter_map_rigid['Metric'] - #if get_paras.use_reference_points.value==True: - # parameter_map_rigid['Metric'] = [original_metric[0], 'CorrespondingPointsEuclideanDistanceMetric'] + original_metric = parameter_map_rigid['Metric'] + if use_points==True: + parameter_map_rigid['Metric'] = [original_metric[0], 'CorrespondingPointsEuclideanDistanceMetric'] - return parameter_map_rigid + return parameter_map_rigid +def bspline_map( spline_resolution, iterations, final_order, spline_final_spacing ): + """ Set-up bspline transformation parameters """ + preset = "bspline" + parameter_object = itk.ParameterObject.New() + parameter_map = parameter_object.GetDefaultParameterMap(preset) - def bspline_map(): - """ Set-up bspline transformation parameters """ - preset = "bspline" - parameter_object = itk.ParameterObject.New() - parameter_map = parameter_object.GetDefaultParameterMap(preset) + parameter_map["NumberOfResolutions"] = [spline_resolution] + parameter_map["WriteIterationInfo"] = ["false"] + parameter_map['MaximumStepLength'] = ['2.0'] + parameter_map['NumberOfSpatialSamples'] = ['10000'] + parameter_map['MaximumNumberOfSamplingAttempts'] = ['10'] + parameter_map['RequiredRatioOfValidSamples'] = ['0.05'] + parameter_map['MaximumNumberOfIterations'] = [iterations] + parameter_map['FinalBSplineInterpolationOrder'] = [final_order] + parameter_map['BSplineInterpolationOrder'] = ['3'] + parameter_map['HowToCombineTransform'] = ['Compose'] + nres = int(spline_resolution) + spaces = [] + for step in range(nres): + spaces.append( math.pow(2, nres-1-step) ) + parameter_map['GridSpacingSchedule'] = [str(v) for v in spaces ] + parameter_map['FinalGridSpacingInPhysicalUnits'] = [str(v) for v in [spline_final_spacing]*int(spline_resolution)] + + return parameter_map + +def time_registration( do_rigid, do_bspline, iterations, rigid_resolution, rigid_final_spacing, use_reference_points, spline_resolution, spline_final_spacing, final_order, show_log=True ): + """ Go for frame by frame registration """ - parameter_map["NumberOfResolutions"] = [str(get_paras.spline_resolution.value)] - parameter_map["WriteIterationInfo"] = ["false"] - parameter_map['MaximumStepLength'] = ['2.0'] - parameter_map['NumberOfSpatialSamples'] = ['10000'] - parameter_map['MaximumNumberOfSamplingAttempts'] = ['10'] - parameter_map['RequiredRatioOfValidSamples'] = ['0.05'] - parameter_map['MaximumNumberOfIterations'] = [str(get_paras.iterations.value)] - parameter_map['FinalBSplineInterpolationOrder'] = [str(get_paras.final_order.value)] - parameter_map['BSplineInterpolationOrder'] = ['3'] - parameter_map['HowToCombineTransform'] = ['Compose'] - nres = int(get_paras.spline_resolution.value) - spaces = [] - for step in range(nres): - spaces.append( math.pow(2, nres-1-step) ) - parameter_map['GridSpacingSchedule'] = [str(v) for v in spaces ] - parameter_map['FinalGridSpacingInPhysicalUnits'] = [str(v) for v in [get_paras.spline_final_spacing.value]*int(get_paras.spline_resolution.value)] - - return parameter_map - - def time_registration(): - """ Go for frame by frame registration """ - - ## Build registration parameter maps from GUI parameters - registration_parameter_object = itk.ParameterObject.New() - nmap = 0 - if get_paras.do_rigid.value: - pmap_rigid = rigid_map() - registration_parameter_object.AddParameterMap(pmap_rigid) - nmap = nmap + 1 - if get_paras.do_bspline.value: - pmap_spline = bspline_map() - registration_parameter_object.AddParameterMap(pmap_spline) - nmap = nmap + 1 + ## Build registration parameter maps from GUI parameters + registration_parameter_object = itk.ParameterObject.New() + nmap = 0 + if do_rigid: + pmap_rigid = rigid_map( iterations=str(iterations), rig_resolution=str(rigid_resolution), rig_final_spacing=int(rigid_final_spacing), use_points=use_reference_points ) + registration_parameter_object.AddParameterMap(pmap_rigid) + nmap = nmap + 1 + if do_bspline: + pmap_spline = bspline_map( spline_resolution=str(spline_resolution), iterations=str(iterations), final_order=str(final_order), spline_final_spacing=int(spline_final_spacing) ) + registration_parameter_object.AddParameterMap(pmap_spline) + nmap = nmap + 1 - ## apply "alignement" to first frame - apply_registration(0, None) + ## apply "alignement" to first frame + apply_registration(0, None) - # initialise a parameter object to which the transforms will be appended that result from the pairwise slice registrations - curr_transform_object = itk.ParameterObject.New() + # initialise a parameter object to which the transforms will be appended that result from the pairwise slice registrations + curr_transform_object = itk.ParameterObject.New() - # the first fixed image will be the reference slice - fixed_image_itk = img_to_itk(resimg[0]) + # the first fixed image will be the reference slice + fixed_image_itk = img_to_itk(resimg[0]) - ## Register all frames to previous one and add it - for t in range(resimg.shape[0]): - print("Calculate registration for time point "+str(t)) + ## Register all frames to previous one and add it + for t in range(resimg.shape[0]): + print("Calculate registration for time point "+str(t)) - if t > 0: - # the moving image is the current slice - moving_image_itk = img_to_itk(resimg[t]) + if t > 0: + # the moving image is the current slice + moving_image_itk = img_to_itk(resimg[t]) - # perform the pairwise registration between two slices - elastix_object = itk.ElastixRegistrationMethod.New(fixed_image_itk, moving_image_itk) - elastix_object.SetParameterObject(registration_parameter_object) + # perform the pairwise registration between two slices + elastix_object = itk.ElastixRegistrationMethod.New(fixed_image_itk, moving_image_itk) + elastix_object.SetParameterObject(registration_parameter_object) - #if get_paras.use_reference_points.value: - # elastix_object.SetFixedPointSetFileName(os.path.join(aligndir, imagename+"_refpts_fixed.txt")) - # elastix_object.SetMovingPointSetFileName(os.path.join(aligndir, imagename+"_refpts_moving.txt")) + if use_reference_points: + get_ref_points(t-1, t) + elastix_object.SetFixedPointSetFileName(os.path.join(aligndir, imagename+"_refpts_fixed.txt")) + elastix_object.SetMovingPointSetFileName(os.path.join(aligndir, imagename+"_refpts_moving.txt")) - elastix_object.SetLogToConsole(get_paras.show_log.value==True) + elastix_object.SetLogToConsole( show_log==True ) - # Update filter object (required) - elastix_object.UpdateLargestPossibleRegion() + # Update filter object (required) + elastix_object.UpdateLargestPossibleRegion() - # Results of Registration - #affimage = elastix_object.GetOutput() - results_transform_parameters = elastix_object.GetTransformParameterObject() + # Results of Registration + #affimage = elastix_object.GetOutput() + results_transform_parameters = elastix_object.GetTransformParameterObject() - # set the current moving image as the fixed image for the registration in the next iteration - fixed_image_itk = moving_image_itk + # set the current moving image as the fixed image for the registration in the next iteration + fixed_image_itk = moving_image_itk - # append the obtained transform to the transform parameter object - for i in range(nmap): - curr_transform_object.AddParameterMap(results_transform_parameters.GetParameterMap(i)) + # append the obtained transform to the transform parameter object + for i in range(nmap): + curr_transform_object.AddParameterMap(results_transform_parameters.GetParameterMap(i)) - # transform the current slice and append it to the reconstructed stack - apply_registration(t, curr_transform_object) + # transform the current slice and append it to the reconstructed stack + apply_registration(t, curr_transform_object) +def iterative_registration(): + """ use Elastix to perform registration with possible deformation, iteratively in time """ @magicgui(call_button="Go", rigid_resolution={"widget_type":"LiteralEvalLineEdit"}, @@ -280,8 +297,8 @@ def iterative_registration(): ) def get_paras( show_log = True, - #use_reference_points = False, - #refpoints_file = pathlib.Path(os.path.join(imagedir, imagename+"_reference_points.csv")), + use_reference_points = False, + refpoints_file = pathlib.Path(os.path.join(imagedir, imagename+"_reference_points.csv")), do_rigid = True, do_bspline = True, show_advanced_parameters = False, @@ -296,12 +313,11 @@ def iterative_registration(): global move_points reslay = viewer.layers["ResMovie"] - use_reference_points = False + #use_reference_points = False if use_reference_points: - read_points() - move_points = True - - time_registration() + read_points( refpoints_file ) + #move_points = True + time_registration( do_rigid=do_rigid, do_bspline=do_bspline, iterations=iterations, rigid_resolution=rigid_resolution, rigid_final_spacing=rigid_final_spacing, use_reference_points=use_reference_points, spline_resolution=spline_resolution, spline_final_spacing=spline_final_spacing, final_order=final_order, show_log=show_log ) finish_image() def show_advanced(booly): @@ -324,14 +340,33 @@ def iterative_registration(): wid = viewer.window.add_dock_widget(get_paras, name="Calculate alignement") +def read_points( refpoints_file ): + """ Read the TrackMate file containing all the points coordinates """ + global refpts + global move_points + move_points = False + ptsfile = refpoints_file + if not os.path.exists(ptsfile): + print("Reference points file "+ptsfile+" not found") + refpts = [] + with open(ptsfile, "r") as infile: + csvreader = csv.DictReader(infile) + for row in csvreader: + cres = [] + if row["TrackID"].isdigit(): + for col in ["TrackID", "X", "Y", "Z", "T"]: + cres.append(int(float(row[col]))) + refpts.append(cres) + refpts = np.array(refpts) + def get_ref_points(time0, time1): - """ Get the reference points common between time0 and time1 and put them to file """ - global refpts - pttime0 = refpts[refpts[:,4]==time0,] - pttime1 = refpts[refpts[:,4]==time1,] - inter, ind1, ind0 = np.intersect1d(pttime1[:,0], pttime0[:,0], return_indices=True) - write_ref_file(ind1, pttime1, "moving") - write_ref_file(ind0, pttime0, "fixed") + """ Get the reference points common between time0 and time1 and put them to file """ + global refpts + pttime0 = refpts[refpts[:,4]==time0,] + pttime1 = refpts[refpts[:,4]==time1,] + inter, ind1, ind0 = np.intersect1d(pttime1[:,0], pttime0[:,0], return_indices=True) + write_ref_file(ind1, pttime1, "moving") + write_ref_file(ind0, pttime0, "fixed") def get_closest_label(pts, tid): """ Find closest pt id to tid """ @@ -520,35 +555,39 @@ def create_result_image(): @magicgui(call_button = "Concatenate aligned images",) def get_files(): - resimg = np.zeros(refimg.shape) - if dim == 2: - nchans = refimg.shape[1] - else: - nchans = refimg.shape[2] - - for chan in range(nchans): - for time in range(refimg.shape[0]): - filename = os.path.join(aligndir, imagename+"_C"+str(chan)+"_T"+"{:04d}".format(time)+".tif") - img, tscaleXY, tscaleZ, names = ut.open_image(filename, verbose=False) - if dim == 2: - resimg[time, chan, :,:] = img - else: - resimg[time, :, chan, :,:] = img - os.remove(filename) - - viewer.add_image(resimg, name="Res", blending="additive") - for lay in viewer.layers: - if lay.name != "Res": - remove_layer(lay) - imgname = os.path.join(aligndir, imagename+".tif") - resimg = np.array(resimg, "uint16") - # move the chanel axis after the Z axis (imageJ format) - if dim == 3: - resimg = np.moveaxis(resimg, 0, 1) - print(resimg.shape) - tifffile.imwrite(imgname, resimg, imagej=True, resolution=[1./scaleXY, 1./scaleXY], metadata={'PhysicalSizeX': scaleXY, 'spacing': scaleZ, 'unit': 'um', 'axes': 'TZCYX'}) - else: - tifffile.imwrite(imgname, resimg, imagej=True, resolution=[1./scaleXY, 1./scaleXY], metadata={'PhysicalSizeX': scaleXY, 'unit': 'um', 'axes': 'TCYX'}) - show_info("Image "+imgname+" saved") - + save_result_image() viewer.window.add_dock_widget(get_files, name="Concatenate") + +def save_result_image(): + resimg = np.zeros(refimg.shape) + if dim == 2: + nchans = refimg.shape[1] + else: + nchans = refimg.shape[2] + + for chan in range(nchans): + for time in range(refimg.shape[0]): + filename = os.path.join(aligndir, imagename+"_C"+str(chan)+"_T"+"{:04d}".format(time)+".tif") + img, tscaleXY, tscaleZ, names = ut.open_image(filename, verbose=False) + if dim == 2: + resimg[time, chan, :,:] = img + else: + resimg[time, :, chan, :,:] = img + os.remove(filename) + + viewer.add_image(resimg, name="Res", blending="additive") + for lay in viewer.layers: + if lay.name != "Res": + remove_layer(lay) + imgname = os.path.join(aligndir, imagename+".tif") + resimg = np.array(resimg, "uint16") + # move the chanel axis after the Z axis (imageJ format) + if dim == 3: + resimg = np.moveaxis(resimg, 0, 1) + print(resimg.shape) + tifffile.imwrite(imgname, resimg, imagej=True, resolution=[1./scaleXY, 1./scaleXY], metadata={'PhysicalSizeX': scaleXY, 'spacing': scaleZ, 'unit': 'um', 'axes': 'TZCYX'}) + else: + tifffile.imwrite(imgname, resimg, imagej=True, resolution=[1./scaleXY, 1./scaleXY], metadata={'PhysicalSizeX': scaleXY, 'unit': 'um', 'axes': 'TCYX'}) + show_info("Image "+imgname+" saved") + +