package ip.gui;

import transforms.fft.FFT2d;
import transforms.fft.IFFT2d;
import transforms.fft.ImageUtils;

import java.awt.*;
import java.awt.image.ImageObserver;
import java.awt.image.MemoryImageSource;
import java.awt.image.PixelGrabber;

// Calculates the FFT of an input image, using transforms.fft.FFT2d.

public class FFTImage implements ImageObserver {

    // Packed pixels: the source image after initialisation, replaced by the
    // log-magnitude spectrum after fft() and by the reconstructed image after ifft().
    int intImage[];
    int imageWidth, imageHeight;
    // intImage.length, i.e. imageWidth * imageHeight
    int N;
    // Divides the FFT input (and multiplies the IFFT output) to prevent
    // overflow; always set to N = imageWidth * imageHeight.
    int scale;
    // Scale the FFT output magnitude to get the best display result.
    float magScale;

    // If true, the zero frequency is shifted to the centre of the spectrum.
    boolean fftShift;

    // Per-pixel alpha, passed through to every output pixel array.
    short alpha[];
    // Real and imaginary parts of each colour channel. FFT2d / IFFT2d are
    // expected to transform these arrays in place (results are read back from them).
    float redRe[], greenRe[], blueRe[];
    float redIm[], greenIm[], blueIm[];

    /**
     * Re-reads the pixels once a complete image or animation frame has been
     * delivered.
     *
     * @return {@code false} once the image is fully loaded or loading failed,
     *         so no further notifications are requested; {@code true} otherwise
     */
    @Override
    public boolean imageUpdate(Image img, int infoflags,
                               int x, int y, int width, int height) {
        // Re-grabbing on every partial update would redo the whole conversion
        // for each batch of scanlines; only react to complete frames/images.
        if ((infoflags & (ALLBITS | FRAMEBITS)) != 0) {
            initImage(img);
        }
        return (infoflags & (ALLBITS | ERROR | ABORT)) == 0;
    }

    /**
     * Creates an FFT helper from an AWT image, blocking until its pixels have
     * been grabbed. The zero frequency is shifted to the centre and the
     * magnitude display scale is 100.
     *
     * @param img source image; if {@code null}, or if grabbing fails or is
     *            interrupted, the instance is left uninitialised
     */
    public FFTImage(Image img) {
        initImage(img);
    }

    /**
     * Filters one image by another in the frequency domain. Both images are
     * transformed, their spectra are multiplied channel by channel as complex
     * numbers, and the product is inverse-transformed.
     *
     * @param img1 image to filter
     * @param img2 filter image; must have the same dimensions as img1
     * @return the filtered image
     * @throws IllegalArgumentException if the two images differ in size
     */
    public static Image filter(Image img1, Image img2) {
        FFTImage fimg1 = new FFTImage(img1);
        FFTImage fimg2 = new FFTImage(img2);
        fimg1.fft();
        fimg2.fft();
        fimg1.mult(fimg2);
        fimg1.ifft();
        return fimg1.getImage();
    }

    /**
     * Wraps the current pixel buffer ({@link #intImage}) in a new AWT image.
     *
     * @return an image of size imageWidth x imageHeight backed by intImage
     */
    public Image getImage() {
        // The pixels were grabbed with forceRGB, i.e. in the default RGB colour
        // model, and ImageUtils.argbToInt presumably packs the same layout
        // (TODO confirm). So use the constructor that assumes that model,
        // rather than the screen's model. The screen's model may use a
        // different layout and throws HeadlessException in headless mode.
        Toolkit tk = Toolkit.getDefaultToolkit();
        return tk.createImage(
                new MemoryImageSource(
                        imageWidth,
                        imageHeight,
                        intImage, 0,
                        imageWidth));
    }

    /**
     * Grabs the pixels of img in the default RGB colour model and
     * (re)initialises the per-channel FFT buffers from them.
     * The current state is left untouched if img is null, or if the grab is
     * interrupted or fails.
     */
    private void initImage(Image img) {
        if (img == null) return;
        // A size of -1 (not yet known) makes PixelGrabber grab the whole
        // image once its dimensions are delivered.
        int width = img.getWidth(this);
        int height = img.getHeight(this);

        PixelGrabber grabber =
                new PixelGrabber(
                        img, 0, 0,
                        width, height, true);
        try {
            if (!grabber.grabPixels()) {
                return;  // aborted, error or timeout: no usable pixels
            }
        } catch (InterruptedException e) {
            // Preserve the interrupt for callers; the pixel buffer is incomplete.
            Thread.currentThread().interrupt();
            return;
        }
        // Query the size only after the grab: before it, a -1 request is
        // still reported as -1.
        initVariables((int[]) grabber.getPixels(), grabber.getWidth(), 100.0f, true);
    }


    /**
     * Complex-multiplies this spectrum in place by another, element by element,
     * for each colour channel. This is the frequency-domain counterpart of
     * (circular) convolution, which is what filtering requires. The other
     * spectrum is divided by 255 (presumably to normalise 8-bit channel
     * values). Both objects are expected to have been transformed with fft().
     *
     * @param f spectrum to multiply by; must have the same dimensions
     * @throws IllegalArgumentException if the dimensions differ
     */
    public void mult(FFTImage f) {
        if (f.imageWidth != imageWidth || f.imageHeight != imageHeight) {
            throw new IllegalArgumentException("image sizes differ: "
                    + imageWidth + "x" + imageHeight + " vs "
                    + f.imageWidth + "x" + f.imageHeight);
        }
        complexMultiply(redRe, redIm, f.redRe, f.redIm);
        complexMultiply(greenRe, greenIm, f.greenRe, f.greenIm);
        complexMultiply(blueRe, blueIm, f.blueRe, f.blueIm);
    }

    // (re + i*im) *= (fRe + i*fIm) / 255, element-wise, for one channel.
    // All operands are read before writing, so f may be this object.
    private static void complexMultiply(float re[], float im[], float fRe[], float fIm[]) {
        for (int i = 0; i < re.length; i++) {
            float a = re[i];
            float b = im[i];
            float c = fRe[i] / 255.0f;
            float d = fIm[i] / 255.0f;
            re[i] = a * c - b * d;
            im[i] = a * d + b * c;
        }
    }

    /**
     * Creates an FFT helper from packed pixels.
     * Typical arguments are magScale 100 and fftShift true (the values the
     * Image constructor uses).
     *
     * @param intImage   packed pixels, row-major, length a multiple of imageWidth
     * @param imageWidth image width in pixels; must be positive
     * @param magScale   multiplier for the displayed log-magnitude
     * @param fftShift   whether to shift the zero frequency to the centre
     * @throws IllegalArgumentException if imageWidth is not positive or does
     *                                  not divide the pixel count
     */
    public FFTImage(int intImage[], int imageWidth,
                    float magScale, boolean fftShift) {
        initVariables(intImage, imageWidth, magScale, fftShift);

    }

    // Stores the pixels and parameters, splits the pixels into channels and
    // fills the real parts (scaled, optionally shift-modulated); imaginary
    // parts start at zero.
    private void initVariables(int[] intImage, int imageWidth, float magScale, boolean fftShift) {
        if (imageWidth <= 0) {
            throw new IllegalArgumentException("imageWidth must be positive: " + imageWidth);
        }
        if (intImage.length % imageWidth != 0) {
            throw new IllegalArgumentException("pixel count " + intImage.length
                    + " is not a multiple of imageWidth " + imageWidth);
        }
        this.intImage = intImage;
        this.imageWidth = imageWidth;
        N = intImage.length;
        imageHeight = N / imageWidth;
        this.magScale = magScale;
        this.fftShift = fftShift;
        scale = N;

        alpha = ImageUtils.getAlpha(intImage);
        short red[] = ImageUtils.getRed(intImage);
        short green[] = ImageUtils.getGreen(intImage);
        short blue[] = ImageUtils.getBlue(intImage);

        // If fftShift is true, shift the zero frequency to the center.
        redRe = fftReorder(red);
        greenRe = fftReorder(green);
        blueRe = fftReorder(blue);
        redIm = new float[N];
        greenIm = new float[N];
        blueIm = new float[N];
    }

    /**
     * Forward-transforms the channel buffers and replaces {@link #intImage}
     * with the displayable log-magnitude spectrum.
     */
    public void fft() {
        intImage = getFftIntArray();
    }

    /**
     * Forward-transforms the channel buffers and returns the log-magnitude
     * spectrum as packed pixels. The buffers are transformed again on every
     * call (presumably in place by FFT2d).
     */
    public int[] getFftIntArray() {
        new FFT2d(redRe, redIm, imageWidth);
        new FFT2d(greenRe, greenIm, imageWidth);
        new FFT2d(blueRe, blueIm, imageWidth);

        float resultRed[] = magnitude(redRe, redIm);
        float resultGreen[] = magnitude(greenRe, greenIm);
        float resultBlue[] = magnitude(blueRe, blueIm);

        int resultImage[] = ImageUtils.argbToInt(alpha, resultRed,
                resultGreen, resultBlue);
        return resultImage;
    }

    /**
     * Inverse-transforms the channel buffers and replaces {@link #intImage}
     * with the reconstructed image.
     */
    public void ifft() {
        intImage = getIfftIntArray();
    }

    /**
     * Inverse-transforms the channel buffers and returns the reconstructed
     * image (real parts, un-scaled and un-shifted) as packed pixels.
     */
    public int[] getIfftIntArray() {
        new IFFT2d(redRe, redIm, imageWidth);
        new IFFT2d(greenRe, greenIm, imageWidth);
        new IFFT2d(blueRe, blueIm, imageWidth);

        short resultRed[] = ifftReorder(redRe);
        short resultGreen[] = ifftReorder(greenRe);
        short resultBlue[] = ifftReorder(blueRe);

        int resultImage[] = ImageUtils.argbToInt(alpha, resultRed,
                resultGreen, resultBlue);
        return resultImage;
    }

    // Sign applied to pixel (row, col): when fftShift is set, modulating the
    // spatial data with an alternating +/-1 pattern moves the zero frequency
    // to the centre of the spectrum. Returns 1 when no shift is requested.
    private float shiftSign(int row, int col) {
        if (!fftShift) return 1f;
        return ((row + col) % 2 == 0) ? -1f : 1f;
    }

    // Prepares colour data as FFT input:
    // 1. Converts it from short to float.
    // 2. Divides it by scale.
    // 3. If fftShift is true, applies the alternating sign that centres the
    //    zero frequency.
    private float[] fftReorder(short color[]) {
        float result[] = new float[N];
        for (int row = 0; row < imageHeight; row++) {
            for (int col = 0; col < imageWidth; col++) {
                int idx = row * imageWidth + col;
                result[idx] = color[idx] * shiftSign(row, col) / scale;
            }
        }
        return result;
    } // End of function fftReorder().

    // Inverse of fftReorder for IFFT output: undoes the alternating sign (if
    // fftShift), multiplies by scale and truncates to short.
    private short[] ifftReorder(float re[]) {
        short result[] = new short[N];
        for (int row = 0; row < imageHeight; row++) {
            for (int col = 0; col < imageWidth; col++) {
                int idx = row * imageWidth + col;
                result[idx] = (short) (re[idx] * shiftSign(row, col) * scale);
            }
        }
        return result;
    } // End of function ifftReorder().

    // Scale the FFT output magnitude to get the best display result:
    // magScale * log(1 + |re + i*im|), clamped to at most 255.
    private float[] magnitude(float re[], float im[]) {
        float result[] = new float[N];
        for (int i = 0; i < N; i++) {
            result[i] = (float) (magScale *
                    Math.log(1 + Math.sqrt(re[i] * re[i] + im[i] * im[i])));
            if (result[i] > 255)
                result[i] = 255;
        }
        return result;
    } // End of function magnitude().

} // End of class FFTImage.