package ip.gui.frames;

import ip.gui.dialog.AdaptiveLog;
import ip.gui.dialog.ExponentialLog;
import ip.gui.dialog.GrayLog;
import ip.gui.dialog.RayleighLog;
import ip.gui.Histogram;
import ip.gui.TransformTable;
import ip.gui.Print;
import math.Mat;

import java.awt.*;
import java.awt.event.ActionEvent;

/**
 * Image frame that adds point operations (negate, add 10, gamma
 * brighten/darken, linear contrast stretch) and histogram-equalization
 * operations to {@link GrabFrame}.
 *
 * <p>Every operation works in place on the short-valued planes returned by
 * {@code getR()}, {@code getG()} and {@code getB()}, which are indexed
 * {@code [x][y]} with {@code x < getImageWidth()} and
 * {@code y < getImageHeight()}.
 */
public class NegateFrame extends GrabFrame {

    // Most recently built per-channel histograms.
    private Histogram rh,gh,bh;

    // Extremes over all three planes; refreshed by computeStats().
    private int min = Integer.MAX_VALUE;
    private int max = Integer.MIN_VALUE;

    // Tile grid used for adaptive image enhancement (auhe); updated by
    // drawMosaic(NegateFrame, int, int).
    private int mosaicWidth = 2;
    private int mosaicHeight = 2;

    // Per-plane means; refreshed by computeStats().
    private double rBar = 0;
    private double gBar = 0;
    private double bBar = 0;

    private TransformTable tt =
            new TransformTable(256);

    private Menu negateMenu = getMenu("Negate");
    private Menu histogramMenu = getMenu("Histogram");

    // NOTE(review): "[E-d]" is used by both "darken" and "draw Mosaic", and
    // "[E-1]" by both "add 10" and "printPMFs". If the bracket prefix defines a
    // keyboard shortcut, one item of each pair is unreachable from the
    // keyboard — confirm and pick distinct keys.
    private MenuItem negate_mi =
            addMenuItem(negateMenu, "[E-n]egate");
    private MenuItem add10_mi =
            addMenuItem(negateMenu, "[E-1]add 10");
    private MenuItem brighten_mi =
            addMenuItem(negateMenu, "[E-b]righten");
    private MenuItem darken_mi =
            addMenuItem(negateMenu, "[E-d]arken");
    private MenuItem linear_mi =
            addMenuItem(negateMenu, "[E-l]inear transform...");

    private MenuItem histogram_mi =
            addMenuItem(histogramMenu, "[E-h]istogram");
    private MenuItem unahe_mi =
            addMenuItem(histogramMenu,
                    "[E-u]niform non-adaptive histogram equalization");
    private MenuItem enahe_mi =
            addMenuItem(histogramMenu, "[E-e]xponential non-adaptive histogram equalization...");
    private MenuItem rnahe_mi =
            addMenuItem(histogramMenu,
                    "[E-r]ayleigh non-adaptive histogram equalization...");

    private MenuItem auhe_mi =
            addMenuItem(histogramMenu, "[E-a]uhe adaptive uniform histogram equalization");
    private MenuItem drawMosaic_mi =
            addMenuItem(histogramMenu, "[E-d]raw Mosaic");
    private MenuItem printTT_mi =
            addMenuItem(histogramMenu, "[E-T-t]print transform table...");
    private MenuItem printStats_mi =
            addMenuItem(histogramMenu, "[E-T-p]rintStats");
    private MenuItem printPMFs_mi =
            addMenuItem(histogramMenu, "[E-1]printPMFs");
    private MenuItem printCMFs_mi =
            addMenuItem(histogramMenu, "[E-2]printCMF for r");


    /**
     * Dispatches this frame's menu items to their operations; any other event
     * is forwarded to the superclass. The checks run in a fixed order and the
     * first matching item wins.
     *
     * @param e the menu action event
     */
    public void actionPerformed(ActionEvent e) {

        if (match(e, drawMosaic_mi)) {
            drawMosaic(this);
            return;
        }
        if (match(e, rnahe_mi)) {
            rayleighLog();
            return;
        }
        if (match(e, auhe_mi)) {
            auhe();
            return;
        }
        if (match(e, enahe_mi)) {
            eponentialLog();
            return;
        }
        if (match(e, printCMFs_mi)) {
            printCMFs();
            return;
        }
        if (match(e, printPMFs_mi)) {
            printPMFs();
            return;
        }
        if (match(e, printStats_mi)) {
            printStats();
            return;
        }
        if (match(e, unahe_mi)) {
            unahe();
            return;
        }
        if (match(e, printTT_mi)) {
            printTT();
            return;
        }
        if (match(e, linear_mi)) {
            linearLog();
            return;
        }
        if (match(e, darken_mi)) {
            darken();
            return;
        }
        if (match(e, brighten_mi)) {
            brighten();
            return;
        }
        if (match(e, histogram_mi)) {
            histogram();
            return;
        }
        if (match(e, add10_mi)) {
            add10();
            return;
        }
        if (match(e, negate_mi)) {
            negate(this);
            return;
        }
        super.actionPerformed(e);

    }

    /**
     * Applies adaptive uniform histogram equalization to this frame, using
     * the tile grid last stored by {@link #drawMosaic(NegateFrame, int, int)}
     * (2 x 2 until then).
     */
    public void auhe() {
        // auhe(frame, blocksHigh, blocksWide): the height comes first.
        auhe(this, mosaicHeight, mosaicWidth);
    }

    /**
     * Delegates to {@code AdaptiveLog.doit(negateFrame)} — presumably the
     * dialog used to choose the mosaic dimensions (TODO confirm).
     *
     * @param negateFrame the frame the dialog operates on
     */
    public static void drawMosaic(NegateFrame negateFrame) {
        AdaptiveLog.doit(negateFrame);
    }

    /**
     * Adaptive uniform histogram equalization.
     *
     * <p>Splits the image into a {@code blocksHigh x blocksWide} grid of
     * equally sized tiles. Each tile is equalized independently with
     * {@link #unahe()} and copied back. Pixels in the fractional margin along
     * the far x and y edges, which do not fill a whole tile, are left
     * unchanged. The frame is redisplayed at the end.
     *
     * @param negateFrame frame whose image is modified in place
     * @param blocksHigh  number of tile rows; must be positive
     * @param blocksWide  number of tile columns; must be positive
     * @throws IllegalArgumentException if either block count is not positive
     */
    public static void auhe(NegateFrame negateFrame, int blocksHigh, int blocksWide) {
        requirePositive(blocksHigh, "blocksHigh");
        requirePositive(blocksWide, "blocksWide");

        int pelsWide = negateFrame.getImageWidth() / blocksWide;
        int pelsHigh = negateFrame.getImageHeight() / blocksHigh;
        int newWidth = pelsWide * blocksWide;
        int newHeight = pelsHigh * blocksHigh;

        // A zero tile size gives newWidth/newHeight == 0, so the loops do
        // nothing (and cannot spin on a zero step).
        for (int x1 = 0; x1 < newWidth; x1 += pelsWide)
            for (int y1 = 0; y1 < newHeight; y1 += pelsHigh) {
                NegateFrame nf = subFrame(negateFrame, x1, y1, pelsWide, pelsHigh);
                nf.unahe();
                assembleMosaic(negateFrame, nf, x1, y1);
            }
        negateFrame.short2Image();
    }

    /**
     * Copies the planes of tile {@code nf} into {@code negateFrame} so that
     * the tile's (0, 0) lands at ({@code x1}, {@code y1}).
     *
     * @param negateFrame destination frame, modified in place
     * @param nf          source tile (typically produced by {@link #subFrame})
     * @param x1          destination x of the tile origin
     * @param y1          destination y of the tile origin
     */
    public static void assembleMosaic(NegateFrame negateFrame, NegateFrame nf, int x1, int y1) {
        // The copy extent comes from the tile's own arrays. Sizing it by the
        // destination image would index past the end of the tile.
        copyPlaneInto(nf.getR(), negateFrame.getR(), x1, y1);
        copyPlaneInto(nf.getG(), negateFrame.getG(), x1, y1);
        copyPlaneInto(nf.getB(), negateFrame.getB(), x1, y1);
    }

    /** Copies every column of {@code src} into {@code dst} starting at (x0, y0). */
    private static void copyPlaneInto(short[][] src, short[][] dst, int x0, int y0) {
        for (int xs = 0; xs < src.length; xs++) {
            System.arraycopy(src[xs], 0, dst[x0 + xs], y0, src[xs].length);
        }
    }

    /**
     * Records the mosaic grid used by {@link #auhe()} and outlines each tile
     * on the frame's graphics context.
     *
     * @param negateFrame frame to draw on; its mosaic size is updated
     * @param blocksHigh  number of tile rows; must be positive
     * @param blocksWide  number of tile columns; must be positive
     * @throws IllegalArgumentException if either block count is not positive
     */
    public static void drawMosaic(NegateFrame negateFrame, int blocksHigh, int blocksWide) {
        requirePositive(blocksHigh, "blocksHigh");
        requirePositive(blocksWide, "blocksWide");
        negateFrame.mosaicWidth = blocksWide;
        negateFrame.mosaicHeight = blocksHigh;

        int pelsWide = negateFrame.getImageWidth() / blocksWide;
        int pelsHigh = negateFrame.getImageHeight() / blocksHigh;
        int newWidth = pelsWide * blocksWide;
        int newHeight = pelsHigh * blocksHigh;

        Print.println("DrawMosaic" +
                " newWidth=" + newWidth +
                " newHeight=" + newHeight +
                " pelsWide=" + pelsWide +
                " pelsHigh=" + pelsHigh);
        Graphics gx = negateFrame.getGraphics();
        if (gx == null) {
            // Component.getGraphics() returns null when the frame is not
            // displayable; there is nothing to draw on yet.
            return;
        }
        try {
            for (int x1 = 0; x1 < newWidth; x1 += pelsWide)
                for (int y1 = 0; y1 < newHeight; y1 += pelsHigh) {
                    gx.drawRect(x1, y1, pelsWide, pelsHigh);
                }
        } finally {
            gx.dispose();
        }
    }

    /** Throws IllegalArgumentException unless {@code value} is positive. */
    private static void requirePositive(int value, String name) {
        if (value <= 0)
            throw new IllegalArgumentException(
                    name + " must be positive but was " + value);
    }

    /**
     * Copies the {@code w x h} region whose top-left corner is
     * ({@code x1}, {@code y1}) into a new NegateFrame titled "frame".
     *
     * @param negateFrame source frame (not modified)
     * @param x1          left edge of the region
     * @param y1          top edge of the region
     * @param w           region width in pixels
     * @param h           region height in pixels
     * @return a new frame holding copies of the region's planes
     */
    public static NegateFrame subFrame(NegateFrame negateFrame, int x1, int y1, int w, int h) {
        Print.println("Subframe" +
                " x1=" + x1 +
                " y1=" + y1 +
                " x2=" + (x1 + w) +
                " y2=" + (y1 + h));
        short _r[][] = extractPlane(negateFrame.getR(), x1, y1, w, h);
        short _g[][] = extractPlane(negateFrame.getG(), x1, y1, w, h);
        short _b[][] = extractPlane(negateFrame.getB(), x1, y1, w, h);
        return new
                NegateFrame(_r, _g, _b, "frame");
    }

    /** Returns a fresh {@code w x h} copy of {@code src} starting at (x1, y1). */
    private static short[][] extractPlane(short[][] src, int x1, int y1, int w, int h) {
        short[][] dst = new short[w][h];
        for (int xd = 0; xd < w; xd++) {
            System.arraycopy(src[x1 + xd], y1, dst[xd], 0, h);
        }
        return dst;
    }

    /** Nests the Histogram menu under Negate, and Negate under the inherited filter menu. */
    private void doMenus() {
        negateMenu.add(histogramMenu);
        filterMenu.add(negateMenu);
    }

    /**
     * Creates a frame showing the given planes.
     *
     * @param _r    red plane, indexed [x][y]
     * @param _g    green plane, indexed [x][y]
     * @param _b    blue plane, indexed [x][y]
     * @param title frame title
     */
    public NegateFrame(
            short _r[][], short _g[][], short _b[][],
            String title) {
        super(title);
        doMenus();
        System.out.println("New constructor invoked");
        setR(_r);
        setG(_g);
        setB(_b);
        // show image ...very slow
        // but interesting!
        short2Image();
    }

    /** Prints the current transform table. */
    public void printTT() {
        tt.print();
    }

    /**
     * Adds 10 to every sample of every plane and redisplays. Results are
     * not clipped, so samples may exceed 255.
     */
    public void add10() {
        short r[][] = getR();
        short g[][] = getG();
        short b[][] = getB();
        int w = getImageWidth();
        int h = getImageHeight();
        for (int x = 0; x < w; x++)
            for (int y = 0; y < h; y++) {
                r[x][y] = (short) (r[x][y] + 10);
                g[x][y] = (short) (g[x][y] + 10);
                b[x][y] = (short) (b[x][y] + 10);
            }
        short2Image();
    }

    /** Builds red, green and blue histograms of the image and calls myShow() on each. */
    public void histogram() {
        rh = new Histogram(getR(), "Red");
        gh = new Histogram(getG(), "Green");
        bh = new Histogram(getB(), "Blue");
        rh.myShow();
        gh.myShow();
        bh.myShow();
    }

    /**
     * Replaces every sample v with 255 - v in all three planes and redisplays.
     *
     * @param negateFrame frame modified in place
     */
    public static void negate(NegateFrame negateFrame) {
        short r[][] = negateFrame.getR();
        short g[][] = negateFrame.getG();
        short b[][] = negateFrame.getB();
        int w = negateFrame.getImageWidth();
        int h = negateFrame.getImageHeight();
        for (int x = 0; x < w; x++)
            for (int y = 0; y < h; y++) {
                r[x][y] = (short) (255 - r[x][y]);
                g[x][y] = (short) (255 - g[x][y]);
                b[x][y] = (short) (255 - b[x][y]);
            }
        negateFrame.short2Image();
    }

    /** Gamma 0.9 (exponent below 1 raises mid-tones). */
    private void brighten() {
        powImage(0.9);
    }

    /** Gamma 1.5 (exponent above 1 lowers mid-tones). */
    private void darken() {
        powImage(1.5);
    }

    /**
     * Applies v -> 255 * (v / 255)^p to every sample (truncated to short)
     * and redisplays.
     *
     * @param p the exponent
     */
    public void powImage(double p) {
        short r[][] = getR();
        short g[][] = getG();
        short b[][] = getB();
        int w = getImageWidth();
        int h = getImageHeight();
        for (int x = 0; x < w; x++)
            for (int y = 0; y < h; y++) {
                r[x][y] = (short) (255 * Math.pow((r[x][y] / 255.0), p));
                g[x][y] = (short) (255 * Math.pow((g[x][y] / 255.0), p));
                b[x][y] = (short) (255 * Math.pow((b[x][y] / 255.0), p));
            }
        short2Image();
    }

    /**
     * Uniform non-adaptive histogram equalization: maps gray level i to
     * 255 * CMF(i), where CMF is the average of the three channel CMFs.
     */
    public void unahe() {
        short lut[] = tt.getLut();
        double h[] = getAverageCMF();
        for (int i = 0; i < lut.length; i++)
            lut[i] = (short) (255 * h[i]);
        applyLut(lut);
    }

    /**
     * Rayleigh non-adaptive histogram equalization: maps gray level i to
     * 255 * sqrt(2 * alpha^2 * ln(1 / (1 - CMF(i)))), saturated to 0..255.
     *
     * @param alpha Rayleigh parameter
     */
    public void rnahe(double alpha) {
        short lut[] = tt.getLut();
        double h[] = getAverageCMF();
        double alpha2 = 2 * alpha * alpha;
        for (int i = 0; i < lut.length; i++) {
            // The CMF reaches (or, through rounding, slightly exceeds) 1 at the
            // top gray levels, where the log diverges. Capping it keeps the
            // result at +Infinity rather than NaN, so it saturates to 255.
            double p = Math.min(h[i], 1.0);
            double g = alpha2 * Math.log(1 / (1.0 - p));
            lut[i] = toGray(255 * Math.sqrt(g));
        }
        tt.clip();
        applyLut(lut);
    }

    /**
     * Exponential non-adaptive histogram equalization: maps gray level i to
     * 255 * (-ln(1 - CMF(i)) / alpha), saturated to 0..255.
     *
     * @param alpha exponential parameter
     */
    public void enahe(double alpha) {
        short lut[] = tt.getLut();
        double h[] = getAverageCMF();
        for (int i = 0; i < lut.length; i++) {
            // See rnahe: cap the CMF so the top levels saturate instead of
            // wrapping to a negative short.
            double p = Math.min(h[i], 1.0);
            lut[i] = toGray(255 * (-Math.log(1.0 - p) / alpha));
        }
        tt.clip();
        applyLut(lut);
    }

    /**
     * Saturates {@code v} to 0..255 before narrowing. A bare (short) cast
     * would turn +Infinity into -1 and wrap large values. NaN maps to 0, as
     * the bare cast does.
     */
    private static short toGray(double v) {
        if (Double.isNaN(v) || v <= 0) return 0;
        if (v >= 255) return 255;
        return (short) v;
    }


    /**
     * Rebuilds the three channel histograms and returns the element-wise
     * average of their CMFs (via Mat.getAverage).
     *
     * @return the averaged CMF
     */
    public double[] getAverageCMF() {
        rh = new Histogram(getR(), "Red");
        gh = new Histogram(getG(), "Green");
        bh = new Histogram(getB(), "Blue");
        double CMFr[] = rh.getCMF();
        double CMFg[] = gh.getCMF();
        double CMFb[] = bh.getCMF();
        return Mat.getAverage(CMFr, CMFg, CMFb);
    }

    /**
     * Maps every sample of all three planes through the same look-up table,
     * after clipping samples to 0..255, and redisplays.
     *
     * @param lut look-up table indexed by gray level
     */
    public void applyLut(short lut[]) {
        applyLut(lut, lut, lut);
    }

    /**
     * Maps each plane through its own look-up table, after clipping samples
     * to 0..255, and redisplays.
     *
     * @param lutr table for the red plane
     * @param lutg table for the green plane
     * @param lutb table for the blue plane
     */
    public void applyLut(short lutr[], short lutg[], short lutb[]) {
        wellConditioned(); // shorts could be out of range; clip so they are valid LUT indices
        short r[][] = getR();
        short g[][] = getG();
        short b[][] = getB();
        int w = getImageWidth();
        int h = getImageHeight();
        for (int x = 0; x < w; x++)
            for (int y = 0; y < h; y++) {
                r[x][y] = lutr[r[x][y]];
                g[x][y] = lutg[g[x][y]];
                b[x][y] = lutb[b[x][y]];
            }
        short2Image();
    }

    /** Clips every sample of every plane to 0..255, logging each clipped sample. */
    public void wellConditioned() {
        short r[][] = getR();
        short g[][] = getG();
        short b[][] = getB();
        int w = getImageWidth();
        int h = getImageHeight();
        for (int x = 0; x < w; x++)
            for (int y = 0; y < h; y++) {
                r[x][y] = inRange(r[x][y], x, y);
                g[x][y] = inRange(g[x][y], x, y);
                b[x][y] = inRange(b[x][y], x, y);
            }
    }

    /**
     * Clips {@code v} to 0..255, printing a message when clipping occurs.
     *
     * @param v sample value
     * @param x sample x (for the message only)
     * @param y sample y (for the message only)
     * @return v clipped to 0..255
     */
    public short inRange(short v, int x, int y) {
        if (v > 255) {
            Print.println(
                    "out of range x=" + x + " y=" + y +
                    " v=" + v + " clipping to 255");
            return 255;
        }
        if (v < 0) {
            Print.println(
                    "out of range x=" + x + " y=" + y +
                    " v=" + v + " clipping to 0");
            return 0;
        }
        return v;
    }

    /**
     * Computes c * v + b, clipped to 0..255 and truncated to short.
     *
     * @param v gray value
     * @param c contrast (gain)
     * @param b brightness (offset)
     * @return the mapped value
     */
    public short linearMap(short v,
                           double c, double b) {
        double f = c * v + b;
        // clip f into the displayable 0..255 range
        if (f > 255) f = 255;
        if (f < 0) f = 0;
        return (short) f;
    }

    /**
     * Stretches contrast so that the smallest sample over all three planes
     * maps to 0 and the largest to 255, then redisplays. A flat (or empty)
     * image is left unchanged.
     */
    public void linearTransform() {
        double cb[] = autoContrastCoefficients();
        linearTransform(this, cb[0], cb[1]);
        // Every other transform in this class refreshes the display.
        short2Image();
    }

    /**
     * Calls computeStats() and returns {c, b} such that c * min + b == 0 and
     * c * max + b == 255. When max <= min (flat or empty image) it returns
     * the identity {1, 0} instead of dividing by zero.
     */
    private double[] autoContrastCoefficients() {
        computeStats();
        int vMin = getMinimum();
        int vMax = getMaximum();
        int dMin = 0;
        int dMax = 255;
        // double arithmetic: on an empty image min/max are the int sentinels
        // and int subtraction would overflow.
        double deltaV = (double) vMax - vMin;
        if (deltaV <= 0)
            return new double[]{1.0, 0.0};
        double deltaD = dMax - dMin;
        double c = deltaD / deltaV;
        double b = ((double) dMin * vMax - (double) dMax * vMin) / deltaV;
        return new double[]{c, b};
    }

    /**
     * Replaces every sample v with (short) (c * v + br) in all three planes.
     * Results are neither clipped nor redisplayed.
     *
     * @param negateFrame frame modified in place
     * @param c           contrast (gain)
     * @param br          brightness (offset)
     */
    public static void linearTransform(NegateFrame negateFrame, double c, double br) {
        short r[][] = negateFrame.getR();
        short g[][] = negateFrame.getG();
        short b[][] = negateFrame.getB();
        int w = negateFrame.getImageWidth();
        int h = negateFrame.getImageHeight();
        for (int x = 0; x < w; x++)
            for (int y = 0; y < h; y++) {
                r[x][y] = (short) (c * r[x][y] + br);
                g[x][y] = (short) (c * g[x][y] + br);
                b[x][y] = (short) (c * b[x][y] + br);
            }
    }

    /**
     * LUT-based linear transform: builds a table with {@link #linearMap} and
     * applies it through {@link #applyLut(short[])}. That call first clips
     * samples to 0..255, so out-of-range input cannot overrun the table.
     *
     * @param c  contrast (gain)
     * @param br brightness (offset)
     */
    public void linearTransform2(double c, double br) {
        short lut[] = tt.getLut();
        for (int i = 0; i < lut.length; i++)
            lut[i] = linearMap((short) i, c, br);
        applyLut(lut);
    }

    /**
     * Recomputes the minimum and maximum over all three planes and the mean
     * of each plane. The means stay 0 for an empty image.
     */
    public void computeStats() {
        min = Integer.MAX_VALUE;
        max = Integer.MIN_VALUE;
        rBar = 0;
        gBar = 0;
        bBar = 0;
        short r[][] = getR();
        short g[][] = getG();
        short b[][] = getB();
        int w = getImageWidth();
        int h = getImageHeight();
        double N = (double) w * h;
        for (int x = 0; x < w; x++)
            for (int y = 0; y < h; y++) {
                short rv = r[x][y];
                short gv = g[x][y];
                short bv = b[x][y];
                rBar += rv;
                gBar += gv;
                bBar += bv;
                min = Math.min(Math.min(rv, min), Math.min(gv, bv));
                max = Math.max(Math.max(rv, max), Math.max(gv, bv));
            }
        if (N > 0) {
            rBar /= N;
            gBar /= N;
            bBar /= N;
        }
    }

    /** Rebuilds the red histogram and prints its PMF. */
    public void printPMFr() {
        rh = new Histogram(getR(), "Red");
        rh.printPMF();
    }

    /** Rebuilds the red histogram and prints its CMF (red channel only). */
    public void printCMFs() {
        rh = new Histogram(getR(), "Red");
        rh.printCMF();
    }

    /** Rebuilds the green histogram and prints its PMF. */
    public void printPMFg() {
        gh = new Histogram(getG(), "Green");
        gh.printPMF();
    }

    /** Rebuilds the blue histogram and prints its PMF. */
    public void printPMFb() {
        bh = new Histogram(getB(), "Blue");
        bh.printPMF();
    }

    /** Prints the red, green and blue PMFs in that order. */
    public void printPMFs() {
        printPMFr();
        printPMFg();
        printPMFb();
    }

    /** Recomputes and prints the min/max sample and the per-plane means. */
    public void printStats() {
        computeStats();
        Print.println(
                "Min Vij=" + getMinimum() + "\n" +
                "Max Vij=" + getMaximum() + "\n" +
                "rBar = " + getRBar() + "\n" +
                "gBar = " + getGBar() + "\n" +
                "bBar = " + getBBar()
        );

    }

    /** @return mean of the red plane as of the last computeStats() */
    public double getRBar() {
        return rBar;
    }

    /** @return mean of the green plane as of the last computeStats() */
    public double getGBar() {
        return gBar;
    }

    /** @return mean of the blue plane as of the last computeStats() */
    public double getBBar() {
        return bBar;
    }

    /** @return smallest sample over all planes as of the last computeStats() */
    public int getMinimum() {
        return min;
    }

    /** @return largest sample over all planes as of the last computeStats() */
    public int getMaximum() {
        return max;
    }

    /**
     * Opens the exponential-transform dialog (default alpha 4.0).
     * The misspelled name is kept because it is public API.
     */
    public void eponentialLog() {
        String prompts[] = {
            "alpha = "
        };
        String defaults[] = {
            "4.0"};
        String title = "Exponential Transform Dialog";


        new ExponentialLog(
                this,
                title,
                prompts,
                defaults, 9);
    }

    /** Opens the Rayleigh-transform dialog (default alpha 4.0). */
    public void rayleighLog() {
        String prompts[] = {
            "alpha = "
        };
        String defaults[] = {
            "4.0"};
        String title = "Rayleigh Transform Dialog";
        new RayleighLog(
                this,
                title,
                prompts,
                defaults, 9);
    }

    /**
     * Opens the linear grayscale transform dialog. The defaults are the
     * contrast/brightness that stretch the current min..max to 0..255, or
     * the identity for a flat image.
     */
    public void linearLog() {
        String prompts[] = {
            "Contrast = c =",
            "Brightness = b ="
        };
        double cb[] = autoContrastCoefficients();
        double c = cb[0];
        double b = cb[1];
        Print.println("C=" + c + " b=" + b);

        String defaults[] = {
            Double.toString(c),
            Double.toString(b)};
        String title = "Linear Grayscale Transform Dialog";
        new GrayLog(
                this,
                title,
                prompts,
                defaults, 9);

    }

    /**
     * Creates an empty frame with the Negate/Histogram menus installed.
     *
     * @param title frame title
     */
    public NegateFrame(String title) {
        super(title);
        doMenus();
    }

    public static void main(String args[]) {
        new NegateFrame("NegateFrame");
    }
}