package ij.plugin;
import ij.*;
import ij.process.*;
import ij.gui.*;
import ij.measure.Calibration;
import java.awt.*;
import java.awt.image.*;
import java.util.Arrays;

/** This plugin implements the Image/Transform/Bin command.
 * It reduces the size of an image or stack by binning groups of 
 * pixels of user-specified sizes. The resulting pixel can be 
 * calculated as average, median, maximum or minimum.
 *
 * @author Nico Stuurman
 * @author Wayne Rasband
 */
public class Binner implements PlugIn {
    /** Bin methods; each value is also the index of its name in the dialog's method list. */
    public static int AVERAGE=0, MEDIAN=1, MIN=2, MAX=3, SUM=4;
    private static final String[] methods = {"Average", "Median", "Min", "Max", "Sum"};
    private int xshrink=2, yshrink=2, zshrink=1;
    private int method = AVERAGE;
    /** SUM clamp used when the destination is a FloatProcessor standing in for an
     *  8- or 16-bit stack (set only while shrinkZ() runs); 0 means "no clamp". */
    private float maxValue;

    /** Entry point for Image/Transform/Bin: asks for the bin factors and method,
     *  then replaces the current image's stack and calibration with the binned result. */
    public void run(String arg) {
        ImagePlus imp = IJ.getImage();
        if (!showDialog(imp))
            return;
        if (imp.getStackSize()==1)
            Undo.setup(Undo.TYPE_CONVERSION, imp);
        imp.startTiming();
        ImagePlus imp2 = shrink(imp, xshrink, yshrink, zshrink, method);
        IJ.showTime(imp, imp.getStartTime(), "", imp.getStackSize());
        imp.setStack(imp2.getStack());
        imp.setCalibration(imp2.getCalibration());
        if (zshrink>1)
            imp.setSlice(1);
    }

    /** Returns a new image whose pixels are bins of xshrink x yshrink (x zshrink) source pixels.
     *  Z binning is only applied to non-hyperstacks; the Z factor is limited to the stack size.
     *  @param method one of AVERAGE, MEDIAN, MIN, MAX or SUM; out-of-range values mean AVERAGE
     */
    public ImagePlus shrink(ImagePlus imp, int xshrink, int yshrink, int zshrink, int method) {
        if (method<0 || method>=methods.length)
            method = AVERAGE;
        this.xshrink = xshrink;
        this.yshrink = yshrink;
        this.zshrink = zshrink;
        this.method = method;  // shrinkZ() reads the method from this field
        int w = imp.getWidth()/xshrink;
        int h = imp.getHeight()/yshrink;
        ColorModel cm = imp.createLut().getColorModel();
        ImageStack stack = imp.getStack();
        ImageStack stack2 = new ImageStack(w, h, cm);
        int d = stack.getSize();
        for (int z=1; z<=d; z++) {
            IJ.showProgress(z, d);
            ImageProcessor ip = stack.getProcessor(z);
            boolean invertedLut = ip.isInvertedLut();
            if (invertedLut) {
                // Invert a copy: the processor from getProcessor() may share the
                // source stack's pixel array, and the source must stay unchanged.
                ip = ip.duplicate();
                ip.invert();
            }
            ImageProcessor ip2 = shrink(ip, method);
            if (invertedLut)
                ip2.invert();
            stack2.addSlice(stack.getSliceLabel(z), ip2);
        }
        // Z factor actually applied; hyperstacks are not binned in Z here
        // (see shrinkHyperstackZ()), so their calibration must not be scaled in Z.
        int zfactor = 1;
        if (zshrink>1 && !imp.isHyperStack()) {
            zfactor = Math.min(zshrink, d);  // a larger factor would leave no slices
            stack2 = shrinkZ(stack2, zfactor);
        }
        ImagePlus imp2 = imp.createImagePlus();
        imp2.setStack("Reduced "+imp.getShortTitle(), stack2);
        Calibration cal2 = imp2.getCalibration();
        if (cal2.scaled()) {
            cal2.pixelWidth *= xshrink;
            cal2.pixelHeight *= yshrink;
            cal2.pixelDepth *= zfactor;
        }
        imp2.setOpenAsHyperStack(imp.isHyperStack());
        if (method==SUM && imp2.getBitDepth()>8) {
            // Widen the display range so summed values are not shown saturated.
            ImageProcessor ip = imp2.getProcessor();
            ip.setMinAndMax(ip.getMin(), ip.getMax()*xshrink*yshrink*zfactor);
        }
        return imp2;
    }

    /** Bins a stack along Z by zshrink using the current method; trailing slices that
     *  do not fill a whole bin are dropped. The xshrink/yshrink/maxValue fields are
     *  temporarily repurposed and restored before returning. */
    private ImageStack shrinkZ(ImageStack stack, int zshrink) {
        int savedX = xshrink, savedY = yshrink;
        float savedMax = maxValue;
        int bitDepth = stack.getBitDepth();
        // The Z pass works on a FloatProcessor for 8/16-bit data, so SUM needs the
        // destination's range here to clamp like the XY pass does.
        maxValue = bitDepth==8 ? 255f : (bitDepth==16 ? 65535f : 0f);
        try {
            int w = stack.getWidth();
            int h = stack.getHeight();
            int d = stack.getSize();
            int d2 = d/zshrink;
            ImageStack stack2 = new ImageStack(w, h, stack.getColorModel());
            for (int z=1; z<=d2; z++)
                stack2.addSlice(stack.getProcessor(z).duplicate());  // placeholders, overwritten below
            boolean rgb = bitDepth==24;
            // One Z-Y plane per x column: binning it by (zshrink, 1) bins along Z.
            ImageProcessor ip = rgb?new ColorProcessor(d, h):new FloatProcessor(d, h);
            for (int x=0; x<w; x++) {
                IJ.showProgress(x+1, w);
                for (int y=0; y<h; y++) {
                    for (int z=0; z<d; z++)
                        ip.setf(z, y, (float)stack.getVoxel(x, y, z));
                }
                ImageProcessor ip2 = shrink(ip, zshrink, 1, method);
                for (int z2=0; z2<d2; z2++) {
                    for (int y2=0; y2<h; y2++)
                        stack2.setVoxel(x, y2, z2, ip2.getf(z2, y2));
                }
            }
            return stack2;
        } finally {
            xshrink = savedX;
            yshrink = savedY;
            maxValue = savedMax;
        }
    }

    /** Bins a hyperstack along Z only (each channel/frame independently) with the
     *  current method, replacing imp's stack in place. The Z factor is limited to
     *  [1, number of slices]. Calibration is not updated. */
    public ImagePlus shrinkHyperstackZ(ImagePlus imp, int zshrink) {
        int width = imp.getWidth();
        int height = imp.getHeight();
        int channels = imp.getNChannels();
        int slices = imp.getNSlices();
        int frames = imp.getNFrames();
        zshrink = Math.max(1, Math.min(zshrink, slices));
        ImageStack stack = imp.getStack();
        int slices2 = slices/zshrink;
        ImageStack stack2 = new ImageStack(width, height);
        // Built with Z varying fastest, then T, then C; shuffled back to CZT below.
        for (int c=1; c<=channels; c++) {
            for (int t=1; t<=frames; t++) {
                ImageStack tstack = new ImageStack(width, height);
                for (int z=1; z<=slices; z++) {
                    int i = imp.getStackIndex(c, z, t);
                    tstack.addSlice(stack.getSliceLabel(i), stack.getProcessor(i));
                }
                tstack = shrinkZ(tstack, zshrink);
                for (int i=1; i<=tstack.getSize(); i++)
                    stack2.addSlice(tstack.getSliceLabel(i), tstack.getProcessor(i));
            }
        }
        imp.setStack(stack2, channels, slices2, frames);
        new HyperStackConverter().shuffle(imp, HyperStackConverter.ZTC);
        IJ.showProgress(1.0);
        return imp;
    }

    /** Returns a new processor binned by xshrink x yshrink with the given method.
     *  Also stores the factors in this instance's fields. */
    public ImageProcessor shrink(ImageProcessor ip, int xshrink, int yshrink, int method) {
        this.xshrink = xshrink;
        this.yshrink = yshrink;
        return shrink(ip, method);
    }

    /** Bins ip using the current xshrink/yshrink fields; out-of-range methods mean AVERAGE. */
    private ImageProcessor shrink(ImageProcessor ip, int method) {
        if (method<0 || method>=methods.length)
            method = AVERAGE;
        int w = ip.getWidth()/xshrink;
        int h = ip.getHeight()/yshrink;
        ImageProcessor ip2 = ip.createProcessor(w, h);
        if (ip instanceof ColorProcessor)
            return shrinkRGB((ColorProcessor)ip, (ColorProcessor)ip2, method);
        float sumLimit = sumLimit(ip);
        for (int y=0; y<h; y++) {
            for (int x=0; x<w; x++) {
                if (method==AVERAGE)
                    ip2.setf(x, y, getAverage(ip, x, y));
                else if (method==MEDIAN)
                    ip2.setf(x, y, getMedian(ip, x, y));
                else if (method==MIN)
                    ip2.setf(x, y, getMin(ip, x, y));
                else if (method==MAX)
                    ip2.setf(x, y, getMax(ip, x, y));
                else if (method==SUM)
                    ip2.setf(x, y, getSum(ip, x, y, sumLimit));
            }
        }
        return ip2;
    }

    /** Largest SUM result the destination of ip can hold: 8-bit and 16-bit processors
     *  would otherwise wrap around in setf(); other types use maxValue (0 = no clamp). */
    private float sumLimit(ImageProcessor ip) {
        int bitDepth = ip.getBitDepth();
        if (bitDepth==8)
            return 255f;
        if (bitDepth==16)
            return 65535f;
        return maxValue;
    }

    /** Bins each RGB channel separately and stores the results in cp2. */
    private ImageProcessor shrinkRGB(ColorProcessor cp, ColorProcessor cp2, int method) {
        ByteProcessor bp = cp.getChannel(1, null);
        cp2.setChannel(1, (ByteProcessor)shrink(bp, method));
        cp2.setChannel(2, (ByteProcessor)shrink(cp.getChannel(2,bp), method));
        cp2.setChannel(3, (ByteProcessor)shrink(cp.getChannel(3,bp), method));
        return cp2;
    }

    /** Mean of the bin whose top-left corner is (x*xshrink, y*yshrink). */
    private float getAverage(ImageProcessor ip, int x, int y) {
        double sum = 0;  // double accumulator limits rounding error for large bins
        for (int y2=0; y2<yshrink; y2++) {
            for (int x2=0; x2<xshrink; x2++)
                sum += ip.getf(x*xshrink+x2, y*yshrink+y2);
        }
        return (float)(sum/(xshrink*yshrink));
    }

    /** Median of the bin; for an even number of pixels, the mean of the two middle values. */
    private float getMedian(ImageProcessor ip, int x, int y) {
        int n = xshrink*yshrink;
        float[] pixels = new float[n];
        int p = 0;
        for (int y2=0; y2<yshrink; y2++) {
            for (int x2=0; x2<xshrink; x2++)
                pixels[p++] = ip.getf(x*xshrink+x2, y*yshrink+y2);
        }
        Arrays.sort(pixels);
        int half = n/2;
        if ((n&1)==1)
            return pixels[half];
        return (pixels[half-1]+pixels[half])/2f;
    }

    /** Smallest value in the bin. */
    private float getMin(ImageProcessor ip, int x, int y) {
        float min = Float.MAX_VALUE;
        float pixel;
        for (int y2=0; y2<yshrink; y2++) {
            for (int x2=0; x2<xshrink; x2++) {
                pixel = ip.getf(x*xshrink+x2, y*yshrink+y2);
                if (pixel<min)
                    min = pixel;
            }
        }
        return min;
    }

    /** Largest value in the bin. */
    private float getMax(ImageProcessor ip, int x, int y) {
        float max = -Float.MAX_VALUE;
        float pixel;
        for (int y2=0; y2<yshrink; y2++) {
            for (int x2=0; x2<xshrink; x2++) {
                pixel = ip.getf(x*xshrink+x2, y*yshrink+y2);
                if (pixel>max)
                    max = pixel;
            }
        }
        return max;
    }

    /** Sum of the bin, clamped to limit when limit is positive. */
    private float getSum(ImageProcessor ip, int x, int y, float limit) {
        double sum = 0;
        for (int y2=0; y2<yshrink; y2++) {
            for (int x2=0; x2<xshrink; x2++)
                sum += ip.getf(x*xshrink+x2, y*yshrink+y2);
        }
        if (limit>0f && sum>limit)
            sum = limit;
        return (float)sum;
    }

    /** Shows the Bin dialog and stores validated factors and the chosen method in the
     *  fields. Returns false if the dialog was cancelled. */
    private boolean showDialog(ImagePlus imp) {
        boolean stack = imp.getStackSize()>1;
        if (imp.isComposite() && imp.getNChannels()==imp.getStackSize())
            stack = false;
        GenericDialog gd = new GenericDialog("Image Shrink");
        gd.addNumericField("X shrink factor:", xshrink, 0);
        gd.addNumericField("Y shrink factor:", yshrink, 0);
        if (stack)
            gd.addNumericField("Z shrink factor:", zshrink, 0);
        if (method<0 || method>=methods.length)
            method = AVERAGE;
        gd.addChoice ("Bin Method: ", methods, methods[method]);
        if (imp.getStackSize()==1) {
            gd.setInsets(5, 0, 0);
            gd.addMessage("This command supports Undo", null, Color.darkGray);
        }
        gd.showDialog();
        if (gd.wasCanceled())
            return false;
        xshrink = toFactor(gd.getNextNumber(), imp.getWidth());
        yshrink = toFactor(gd.getNextNumber(), imp.getHeight());
        if (stack)
            zshrink = toFactor(gd.getNextNumber(), imp.getStackSize());
        method = gd.getNextChoiceIndex();
        return true;
    }

    /** Truncates a dialog value to a shrink factor in [1, limit]. NaN (unparseable
     *  input) and values below 1 become 1, avoiding division by zero and empty results. */
    private static int toFactor(double value, int limit) {
        if (Double.isNaN(value) || value<1.0)
            return 1;
        return (int)Math.min(value, limit);
    }

}