diff --git a/src/main/java/fr/pasteur/ida/zellige/utils/MedianSkipZeros.java b/src/main/java/fr/pasteur/ida/zellige/utils/MedianSkipZeros.java new file mode 100644 index 0000000000000000000000000000000000000000..2d11fadcaf4a1ddc41b77522e2c24160c8391589 --- /dev/null +++ b/src/main/java/fr/pasteur/ida/zellige/utils/MedianSkipZeros.java @@ -0,0 +1,197 @@ +package fr.pasteur.ida.zellige.utils; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.stream.Collectors; + +import ij.IJ; +import ij.ImagePlus; +import net.imagej.ImageJ; +import net.imglib2.Cursor; +import net.imglib2.Interval; +import net.imglib2.RandomAccess; +import net.imglib2.RandomAccessibleInterval; +import net.imglib2.algorithm.neighborhood.Neighborhood; +import net.imglib2.algorithm.neighborhood.RectangleShape; +import net.imglib2.algorithm.neighborhood.RectangleShape.NeighborhoodsAccessible; +import net.imglib2.img.Img; +import net.imglib2.img.ImgFactory; +import net.imglib2.img.display.imagej.ImageJFunctions; +import net.imglib2.loops.IntervalChunks; +import net.imglib2.type.NativeType; +import net.imglib2.type.numeric.RealType; +import net.imglib2.util.Util; +import net.imglib2.view.IntervalView; +import net.imglib2.view.Views; + +public class MedianSkipZeros +{ + + /** + * Apply a 3x3 median to a nD image, possibly skipping for pixel with + * 0-values. The computation is multithreaded. + * + * @param <T> + * the pixel type of the source image. + * @param source + * the source image. + * @param skipZeros + * if <code>true</code>, pixels with 0-value will be skipped. + * @return a new image of the same pixel type that of the source. + */ + public static final < T extends RealType< T > & NativeType< T > > Img< T > median( final RandomAccessibleInterval< T > source, final boolean skipZeros ) + { + // Prepare output. + final ImgFactory< T > factory = Util.getArrayOrCellImgFactory( source, Util.getTypeFromInterval( source ) ); + final Img< T > output = factory.create( source ); + + // Create the neighborhoods. + final RectangleShape kernel = new RectangleShape( 1, false ); + final NeighborhoodsAccessible< T > neighborhoods = kernel.neighborhoodsRandomAccessible( Views.extendMirrorDouble( source ) ); + + // Divide into chunks. + final int numberOfChunks = Runtime.getRuntime().availableProcessors() / 2; + final List< Interval > chunks = IntervalChunks.chunkInterval( source, numberOfChunks ); + + // Create tasks. + final List< Runnable > runnables = ( skipZeros ) + ? chunks.stream() + .map( chunk -> new MedianOperatorSkipZeros<>( chunk, output, source, neighborhoods ) ) + .collect( Collectors.toList() ) + : chunks.stream() + .map( chunk -> new MedianOperator<>( chunk, output, source, neighborhoods ) ) + .collect( Collectors.toList() ); + + // Launch computation on several threads + final ExecutorService executorService = Executors.newFixedThreadPool( numberOfChunks ); + runnables.forEach( executorService::submit ); + executorService.shutdown(); + + return output; + } + + private static final class MedianOperator< T extends RealType< T > & NativeType< T > > implements Runnable + { + + private final Interval chunk; + + private final Img< T > output; + + private final RandomAccessibleInterval< T > source; + + private final NeighborhoodsAccessible< T > neighborhoods; + + private MedianOperator( final Interval chunk, final Img< T > output, final RandomAccessibleInterval< T > source, final NeighborhoodsAccessible< T > neighborhoods ) + { + this.chunk = chunk; + this.output = output; + this.source = source; + this.neighborhoods = neighborhoods; + } + + @Override + public void run() + { + final IntervalView< T > interval = Views.interval( source, chunk ); + + final Cursor< T > sourceCursor = interval.localizingCursor(); + final RandomAccess< T > raOutput = output.randomAccess( chunk ); + final RandomAccess< Neighborhood< T > > raNeighborhoods = neighborhoods.randomAccess( chunk ); + + final int size = ( int ) raNeighborhoods.get().size(); + final double[] values = new double[ size ]; + + while ( sourceCursor.hasNext() ) + { + sourceCursor.fwd(); + + raNeighborhoods.setPosition( sourceCursor ); + int index = 0; + for ( final T pixel : raNeighborhoods.get() ) + values[ index++ ] = pixel.getRealDouble(); + + Arrays.sort( values, 0, index ); + final double median = values[ ( index - 1 ) / 2 ]; + + raOutput.setPosition( sourceCursor ); + raOutput.get().setReal( median ); + } + } + } + + private static final class MedianOperatorSkipZeros< T extends RealType< T > & NativeType< T > > implements Runnable + { + + private final Interval chunk; + + private final Img< T > output; + + private final RandomAccessibleInterval< T > source; + + private final NeighborhoodsAccessible< T > neighborhoods; + + private MedianOperatorSkipZeros( final Interval chunk, final Img< T > output, final RandomAccessibleInterval< T > source, final NeighborhoodsAccessible< T > neighborhoods ) + { + this.chunk = chunk; + this.output = output; + this.source = source; + this.neighborhoods = neighborhoods; + } + + @Override + public void run() + { + final IntervalView< T > interval = Views.interval( source, chunk ); + + final Cursor< T > sourceCursor = interval.localizingCursor(); + final RandomAccess< T > raOutput = output.randomAccess( chunk ); + final RandomAccess< Neighborhood< T > > raNeighborhoods = neighborhoods.randomAccess( chunk ); + + final int size = ( int ) raNeighborhoods.get().size(); + final double[] values = new double[ size ]; + + while ( sourceCursor.hasNext() ) + { + sourceCursor.fwd(); + if ( sourceCursor.get().getRealDouble() == 0 ) + { + raOutput.setPosition( sourceCursor ); + raOutput.get().set( sourceCursor.get() ); + continue; + } + + raNeighborhoods.setPosition( sourceCursor ); + int index = 0; + for ( final T pixel : raNeighborhoods.get() ) + values[ index++ ] = pixel.getRealDouble(); + + Arrays.sort( values, 0, index ); + final double median = values[ ( index - 1 ) / 2 ]; + + raOutput.setPosition( sourceCursor ); + raOutput.get().setReal( median ); + } + } + } + + /* + * DEMO + */ + + public static < T extends RealType< T > & NativeType< T > > void main( final String[] args ) + { + final ImageJ ij = new ImageJ(); + ij.launch( args ); + + final String imgFile = "samples/STK_epithelium.tif"; + final ImagePlus imp = IJ.openImage( imgFile ); + imp.show(); + + final Img< T > wrap = ImageJFunctions.wrap( imp ); + final Img< T > output = MedianSkipZeros.median( wrap, true ); + ImageJFunctions.show( output, "Median skip 0" ); + + } +}