diff --git a/src/main/java/fr/pasteur/ida/zellige/gui/controller/MainController.java b/src/main/java/fr/pasteur/ida/zellige/gui/controller/MainController.java
index 3b1f07e25c08487128e6f443514043416e56f4b8..1619f294cfc59ae86e0816e1a20fbd8c3bd60dbf 100644
--- a/src/main/java/fr/pasteur/ida/zellige/gui/controller/MainController.java
+++ b/src/main/java/fr/pasteur/ida/zellige/gui/controller/MainController.java
@@ -53,14 +53,14 @@ import javafx.collections.FXCollections;
 import javafx.collections.ObservableList;
 import javafx.fxml.FXML;
 import javafx.fxml.Initializable;
-import javafx.scene.control.Alert;
-import javafx.scene.control.Button;
-import javafx.scene.control.ComboBox;
-import javafx.scene.control.Label;
+import javafx.scene.control.*;
 import javafx.stage.FileChooser;
 import net.imagej.Dataset;
 import net.imagej.ImgPlus;
+import net.imglib2.RandomAccess;
+import net.imglib2.RandomAccessibleInterval;
 import net.imglib2.img.Img;
+import net.imglib2.img.display.imagej.ImageJFunctions;
 import net.imglib2.type.NativeType;
 import net.imglib2.type.numeric.RealType;
 import net.imglib2.type.numeric.real.FloatType;
@@ -78,25 +78,30 @@ public class MainController< T extends RealType< T > & NativeType< T > > impleme
 {
 
     private final static Logger LOGGER = LoggerFactory.getLogger( MainController.class );
+    private int targetChannel = 1;
     private final SimpleObjectProperty< Dataset > currentDataset = new SimpleObjectProperty<>();
+    private final SimpleObjectProperty< Img< T > > referenceImage = new SimpleObjectProperty<>();
     private final SimpleObjectProperty< ClassifiedImages< FloatType > > images = new SimpleObjectProperty<>();
     private final SimpleObjectProperty< Img< FloatType > > pretreatedImg = new SimpleObjectProperty<>();
     private final SimpleBooleanProperty disableGUI = new SimpleBooleanProperty();
     private final SimpleBooleanProperty changedParameters = new SimpleBooleanProperty();
+
+    AbstractTask< Img< FloatType > > pretreatmentTask = null;
+    @FXML
+    Spinner<Integer> channels;
     @FXML
     ComboBox< Dataset > activeDatasets;
     @FXML
     private Button runButton;
     @FXML
     private ConstructionController< T > constructionController;
-
     @FXML
     private Label logInfo;
     private MainAppFrame mainAppFrame;
     @FXML
     private ProjectionController< T > projectionController;
     @FXML
-    private SelectionController< T > selectionController;
+    private SelectionController selectionController;
 
 
     MainModel< T > model;
@@ -107,7 +112,6 @@ public class MainController< T extends RealType< T > & NativeType< T > > impleme
         alert.setTitle( "Error alert" );
         alert.setHeaderText( exception.getMessage() );
         alert.showAndWait();
-       // throw new RuntimeException(exception);
     }
 
     @Override
@@ -121,10 +125,10 @@ public class MainController< T extends RealType< T > & NativeType< T > > impleme
         activeDatasets.getSelectionModel().selectedItemProperty().addListener( ( observableValue, oldValue, newValue ) ->
         {
             LOGGER.debug( "selectedItemProperty : old value = {}, new value = {} ",  oldValue, newValue );
+            channels.setDisable( false );
             if(newValue!= null && isValid( newValue ))
             {
-
-                    setCurrentDataset( newValue );
+                    currentDataset.setValue( newValue );
             }
             else
             {
@@ -132,13 +136,23 @@ public class MainController< T extends RealType< T > & NativeType< T > > impleme
             }
         } );
 
+        channels.valueProperty().addListener( (observableValue, oldValue, newValue ) ->
+        {
+            if(!channels.isDisable())
+            {
+                setTargetChannel( newValue );
+                setReferenceImage( currentDataset.getValue() );
+            }
+        } );
+
         currentDataset.addListener( ( observableValue, dataset, t1 ) ->
         {
             if ( t1 != null )
             {
-                selectionController.disableParameters();
-                disableGUI.setValue( true );
-                runPretreatment( currentDataset.getValue() );
+                SpinnerValueFactory< Integer > valueFactory =
+                        new SpinnerValueFactory.IntegerSpinnerValueFactory( 1, (int) t1.getChannels(), 1 );
+                channels.setValueFactory( valueFactory );
+                setReferenceImage( currentDataset.getValue() );
             }
             else
             {
@@ -147,6 +161,14 @@ public class MainController< T extends RealType< T > & NativeType< T > > impleme
             }
         } );
 
+        referenceImage.addListener(( observableValue, dataset, t1 ) ->
+        {
+            selectionController.disableParameters();
+            disableGUI.setValue( true );
+            LOGGER.debug("New reference image selected");
+            runPretreatment( referenceImage.getValue()  );
+        } );
+
         pretreatedImg.addListener( ( observable, oldValue, newValue ) ->
         {
             if ( newValue != null )
@@ -188,7 +210,6 @@ public class MainController< T extends RealType< T > & NativeType< T > > impleme
             {
                 LOGGER.debug( "Active button" );
                 disableGUI.setValue( false );
-
             }
         } );
 
@@ -201,7 +222,6 @@ public class MainController< T extends RealType< T > & NativeType< T > > impleme
 
     public void initExtraction()
     {
-
         LOGGER.debug( "Init Extraction" );
         Dataset dataset = mainAppFrame.getImage().getActiveDataset();
         LOGGER.debug( dataset.getImgPlus().firstElement().getClass().toGenericString() );
@@ -238,13 +258,7 @@ public class MainController< T extends RealType< T > & NativeType< T > > impleme
         LOGGER.debug( "###########################################---NEW RUN---###########################################" );
     }
 
-    private void runPretreatment( Dataset dataset )
-    {
-        AbstractTask< Img< FloatType > > task = new PretreatmentTask<>( dataset );
-        task.setOnSucceeded( workerStateEvent ->
-                pretreatedImg.setValue( task.getValue() ) );
-        task.start();
-    }
+
 
     public void setAndDisplayDatasetChoices()
     {
@@ -266,16 +280,80 @@ public class MainController< T extends RealType< T > & NativeType< T > > impleme
         } );
         displayTask.start();
     }
+    private boolean isValid(Dataset dataset)
+    {
+        setTargetChannel( 1 );
+        channels.setDisable( false );
+        int numZ = ( int ) dataset.getDepth();
+        int numFrames = ( int ) dataset.getFrames();
+        int numChannels = (int) dataset.getChannels();
+        LOGGER.debug( "NUmber of dimensions: {}", dataset.numDimensions() );
+        LOGGER.debug( "NUmber of channels: {}", dataset.getChannels() );
+        LOGGER.debug( "Number of frames: {}", dataset.getFrames() );
+        if ( numFrames > 1 )
+        {
+            LOGGER.debug( "TimeLapseException" );
+            showError( new TimeLapseException() );
+            return false;
+        }
+        if ( numZ == 1 )
+        {
+            LOGGER.debug( "NoA3DStackException" );
+            showError( new NoA3DStackException() );
+            return false;
+        }
+        if (numChannels == 1)
+        {
+            channels.setDisable( true );
+            setTargetChannel( 0 );
+        }
+        LOGGER.debug( "The dataset is valid." );
+        return true;
+    }
 
-    @SuppressWarnings( "unchecked" )
-    public void setCurrentDataset( Dataset dataset )
+    public void setReferenceImage( Dataset dataset )
     {
-                currentDataset.setValue( dataset );
-                Img< T > input = ( ImgPlus< T > ) dataset.getImgPlus();
-                constructionController.getConstructionModel().setInput( input );
-                constructionController.getConstructionModel().setFactory( input.factory() );
+        Img< T > input = ( ImgPlus< T > ) dataset.getImgPlus();
+        if (getTargetChannel() == 0)
+        {
+            referenceImage.setValue( input );
+        }
+        else
+        {
+            Img< T > output = input.factory().create( input.dimension( 0 ), input.dimension( 1 ), input.dimension( 3 ) );
+            RandomAccess< T > randomAccess = input.randomAccess();
+            RandomAccess< T > randomAccess1 = output.randomAccess();
+            for ( int x = 0; x < ( int ) input.dimension( 0 ); x++ )
+            {
+                for ( int y = 0; y < ( int ) input.dimension( 1 ); y++ )
+                {
+                    for ( int z = 0; z < ( int ) input.dimension( 3 ); z++ )
+                    {
+                        randomAccess.setPositionAndGet( x, y, getTargetChannel()-1, z );
+                        randomAccess1.setPosition( new int[]{ x, y, z } );
+                        randomAccess1.get().set( randomAccess.get() );
+                    }
+                }
+            }
+            referenceImage.setValue( output );
+        }
+        constructionController.getConstructionModel().setInput( referenceImage.getValue());
+        constructionController.getConstructionModel().setFactory(referenceImage.get().factory() );
 
     }
+    private void runPretreatment( RandomAccessibleInterval<T> input )
+    {
+        ImageJFunctions.show(input , "pretreatment");
+        if (pretreatmentTask!= null)
+        {
+            pretreatmentTask.cancel();
+        }
+        pretreatmentTask = new PretreatmentTask<>( input );
+        pretreatmentTask.setOnSucceeded( workerStateEvent ->
+                pretreatedImg.setValue( pretreatmentTask.getValue() ) );
+
+        pretreatmentTask.start();
+    }
 
     public void computeClassifiedImages()
     {
@@ -372,26 +450,15 @@ public class MainController< T extends RealType< T > & NativeType< T > > impleme
         projectionController.setParameters( parameters );
     }
 
-    private boolean isValid(Dataset dataset)
+    public int getTargetChannel()
     {
-        int numZ = ( int ) dataset.getDepth();
-        int numFrames = ( int ) dataset.getFrames();
-        LOGGER.debug( "NUmber of dimensions: {}", dataset.numDimensions() );
-        LOGGER.debug( "NUmber of channels: {}", dataset.getChannels() );
-        LOGGER.debug( "Number of frames: {}", dataset.getFrames() );
-        if ( numFrames > 1 )
-        {
-            LOGGER.debug( "TimeLapseException" );
-            showError( new TimeLapseException() );
-            return false;
-        }
-        if ( numZ == 1 )
-        {
-            LOGGER.debug( "NoA3DStackException" );
-            showError( new NoA3DStackException() );
-            return false;
-        }
-        return true;
+        return targetChannel;
+    }
+
+    public void setTargetChannel( int targetChannel )
+    {
+        LOGGER.debug( " targetChannel = {}", targetChannel);
+        this.targetChannel = targetChannel;
     }
 }