package tracer;

import tracer.geometry.Ray3d;
import tracer.geometry.Vector3d;
import tracer.primatives.Isect;
import tracer.primatives.Light;
import tracer.primatives.Primitive;
import tracer.primatives.Surface;

import java.awt.*;
import java.awt.image.ColorModel;
import java.awt.image.ImageConsumer;
import java.awt.image.ImageProducer;
import java.util.Hashtable;

/**
 * A small recursive ray tracer (shadow, reflection and transmission rays) that
 * renders its scene directly into a single {@link ImageConsumer}.
 *
 * <p>The scene holds at most {@code lights.length} (8) lights and
 * {@code prims.length} (16) primitives; further additions are silently ignored.
 * Rendering happens synchronously inside {@link #addConsumer}.
 */
public class Tracer implements ImageProducer {
  /** Maximum recursion level; deeper rays contribute black. */
  private static final int MAX_DEPTH = 6;

  /** Scene lights; only the first {@code nlights} entries are valid. */
  Light lights[] = new Light[8];
  int nlights;
  /** Scene primitives; only the first {@code nprims} entries are valid. */
  Primitive prims[] = new Primitive[16];
  int nprims;

  /** Not referenced within this class. */
  int row[];
  view theView;

  int width, height;
  ColorModel model;
  Hashtable props;
  /** Consumer currently receiving pixels, or null when no render is in progress. */
  private ImageConsumer theConsumer;

  /**
   * Creates a tracer producing a {@code w} x {@code h} image of the scene as
   * seen through {@code v}, using the default RGB color model.
   */
  public Tracer(int w, int h, view v) {
    width = w;
    height = h;
    props = new Hashtable();
    theView = v;
    model = ColorModel.getRGBdefault();
  }

  /*
   * This ImageProducer boilerplate should be in
   * another class we can extend.
   */

  /**
   * Makes {@code ic} the sole consumer and immediately renders the whole image
   * into it. Any exception raised while rendering is reported to the consumer
   * as {@link ImageConsumer#IMAGEERROR} rather than propagated. The consumer is
   * always detached again when this method returns.
   */
  @Override
  public synchronized void addConsumer(ImageConsumer ic) {
    theConsumer = ic;
    try {
      produce();
    } catch (Exception e) {
      // Best effort: signal the failure through the ImageConsumer protocol.
      if (theConsumer != null) {
        theConsumer.imageComplete(ImageConsumer.IMAGEERROR);
      }
    } finally {
      theConsumer = null;
    }
  }

  /** Returns whether {@code ic} is the consumer currently being rendered to. */
  @Override
  public synchronized boolean isConsumer(ImageConsumer ic) {
    return ic == theConsumer;
  }

  /** Detaches {@code ic} if it is the current consumer; otherwise does nothing. */
  @Override
  public synchronized void removeConsumer(ImageConsumer ic) {
    if (ic == theConsumer) {
      theConsumer = null;
    }
  }

  /** Equivalent to {@link #addConsumer}: registers {@code ic} and renders into it. */
  @Override
  public void startProduction(ImageConsumer ic) {
    addConsumer(ic);
  }

  /** Ignored: the image is always delivered top-down, left-right in one pass. */
  @Override
  public void requestTopDownLeftRightResend(ImageConsumer ic) {
  }

  /**
   * Finds the closest intersection of {@code r} with any scene primitive.
   *
   * @param r    the ray to test
   * @param maxt only hits whose ray parameter is strictly below this count
   * @return a fresh {@link Isect} describing the nearest qualifying hit, or
   *         {@code null} if no primitive is hit before {@code maxt}
   */
  Isect intersect(Ray3d r, double maxt) {
    Isect nearest = new Isect();
    // Start from maxt so callers (notably shadow rays) can bound the search.
    nearest.t = maxt;
    boolean found = false;
    for (int i = 0; i < nprims; i++) {
      Isect tp = prims[i].intersect(r);
      if (tp != null && tp.t < nearest.t) {
        nearest.t = tp.t;
        nearest.prim = tp.prim;
        nearest.surf = tp.surf;
        nearest.enter = tp.enter;
        found = true;
      }
    }
    return found ? nearest : null;
  }

  /**
   * Shadow test along {@code r}.
   *
   * @return 1 if no primitive is hit before {@code tmax} (the point is lit),
   *         0 if something lies in between
   */
  int Shadow(Ray3d r, double tmax) {
    return intersect(r, tmax) == null ? 1 : 0;
  }

  /**
   * Returns the normalized mirror-reflection direction of incident direction
   * {@code I} about normal {@code N}.
   *
   * <p>NOTE(review): the formula {@code I/|I.N| + 2N} is a reflection only when
   * {@code N} faces against {@code I} (as {@link #trace} arranges) and is unit
   * length — confirm for other callers.
   */
  Vector3d SpecularDirection(Vector3d I, Vector3d N) {
    Vector3d r = Vector3d.comb(1.0 / Math.abs(Vector3d.dot(I, N)), I, 2.0, N);
    r.normalize();
    return r;
  }

  /**
   * Returns the normalized refracted (transmitted) direction of {@code I} when
   * passing from medium {@code m1} into {@code m2}, following Snell's law.
   * A {@code null} surface stands for a medium with index of refraction 1.0.
   *
   * <p>Assumes {@code I} and {@code N} are unit vectors with {@code N} facing
   * against {@code I} (so {@code -I.N} is the incident cosine).
   *
   * @return the transmitted direction, or {@code null} on total internal reflection
   */
  Vector3d TransDir(Surface m1, Surface m2, Vector3d I, Vector3d N) {
    double n1 = m1 == null ? 1.0 : m1.ior;
    double n2 = m2 == null ? 1.0 : m2.ior;
    double eta = n1 / n2;
    double c1 = -Vector3d.dot(I, N);
    double cs2 = 1.0 - eta * eta * (1.0 - c1 * c1);
    if (cs2 < 0.0) {
      return null; // total internal reflection
    }
    Vector3d r = Vector3d.comb(eta, I, eta * c1 - Math.sqrt(cs2), N);
    r.normalize();
    return r;
  }

  /*
   * Straight out of Glassner via MTV ...
   */

  /**
   * Computes the color at hit point {@code P}: diffuse and specular highlights
   * from every unshadowed light, plus recursively traced reflection and
   * transmission weighted by the surface's {@code ks} and {@code kt}.
   *
   * @param level  current recursion depth
   * @param weight accumulated contribution of this ray to the final pixel;
   *               secondary rays are skipped when it becomes negligible
   * @param P      hit point
   * @param N      surface normal at {@code P}, facing against {@code I}
   * @param I      incoming ray direction
   * @param hit    intersection record supplying the surface and enter flag
   */
  Vector3d shade(int level, double weight, Vector3d P, Vector3d N, Vector3d I, Isect hit) {
    Surface surf = hit.surf;
    Vector3d col = new Vector3d(0, 0, 0);
    Vector3d R = new Vector3d(0, 0, 0);
    if (surf.shine > 1e-6) {
      R = SpecularDirection(I, N);
    }

    for (int l = 0; l < nlights; l++) {
      Vector3d L = Vector3d.sub(lights[l].pos, P);
      if (Vector3d.dot(N, L) >= 0.0) {
        // NOTE(review): normalize() is assumed to return the pre-normalization
        // length, i.e. the distance to the light, which bounds the shadow ray.
        double t = L.normalize();
        if (Shadow(new Ray3d(P, L), t) > 0) {
          double diff = Vector3d.dot(N, L) * surf.kd * lights[l].brightness;
          col = Vector3d.adds(diff, surf.color, col);
          if (surf.shine > 1e-6) {
            double spec = Vector3d.dot(R, L);
            if (spec > 1e-6) {
              // Highlight is added equally to all three channels.
              spec = Math.pow(spec, surf.shine);
              col.setX(col.getX() + spec);
              col.setY(col.getY() + spec);
              col.setZ(col.getZ() + spec);
            }
          }
        }
      }
    }

    Ray3d tray = new Ray3d(P, new Vector3d(0, 0, 0));
    if (surf.ks * weight > 1e-3) {
      tray.setDirection(SpecularDirection(I, N));
      Vector3d tcol = trace(level + 1, surf.ks * weight, tray);
      col = Vector3d.adds(surf.ks, tcol, col);
    }
    if (surf.kt * weight > 1e-3) {
      Vector3d tdir = hit.enter > 0
          ? TransDir(null, surf, I, N)
          : TransDir(surf, null, I, N);
      // TransDir yields null on total internal reflection: no transmitted ray.
      if (tdir != null) {
        tray.setDirection(tdir);
        Vector3d tcol = trace(level + 1, surf.kt * weight, tray);
        col = Vector3d.adds(surf.kt, tcol, col);
      }
    }
    return col;
  }

  /**
   * Traces {@code r} into the scene and returns its color.
   *
   * @return black if the recursion limit is exceeded or nothing is hit,
   *         otherwise the shaded color of the nearest hit
   */
  Vector3d trace(int level, double weight, Ray3d r) {
    if (level > MAX_DEPTH) {
      return new Vector3d(0, 0, 0);
    }
    Isect hit = intersect(r, 1e6);
    if (hit == null) {
      return new Vector3d(0, 0, 0);
    }
    Vector3d P = r.point(hit.t);
    Vector3d N = hit.prim.normal(P);
    // Flip the normal so it faces the incoming ray (the visible side).
    if (Vector3d.dot(r.getDirection(), N) >= 0.0) {
      N.negate();
    }
    return shade(level, weight, P, N, r.getDirection(), hit);
  }

  /**
   * Renders rows {@code 0..yres-1}, columns {@code [xmin, xmax)} and delivers
   * each row to the current consumer as it is finished. Stops early if the
   * consumer is (or becomes) detached.
   */
  void scan(Vector3d eye, Vector3d viewvec, Vector3d upvec, Vector3d leftvec,
            int xres, int yres, int xmin, int xmax) {
    Ray3d r = new Ray3d(eye, new Vector3d(0, 0, 0));
    System.out.println("scan: " + xres + "," + yres);
    System.out.println("scan: viewvec " + viewvec.toString());
    System.out.println("scan: upvec " + upvec.toString());
    System.out.println("scan: leftvec " + leftvec.toString());
    int span = xmax - xmin;
    for (int y = 0; y < yres; y++) {
      ImageConsumer consumer = theConsumer;
      if (consumer == null) {
        return; // nobody to deliver pixels to
      }
      double ylen = (2.0 * y) / yres - 1.0;
      int[] rowPixels = renderRow(xmin, xmax, xres, leftvec, ylen, upvec, r, viewvec);
      consumer.setPixels(xmin, y, span, 1, model, rowPixels, 0, span);
    }
  }

  /**
   * Traces one image row, columns {@code [xmin, xmax)}, at vertical offset
   * {@code ylen} in [-1, 1]. Reuses {@code r} by overwriting its direction.
   *
   * @return packed opaque ARGB pixels; element 0 corresponds to column {@code xmin}
   */
  private int[] renderRow(
      int xmin, int xmax, int xres,
      Vector3d leftvec, double ylen, Vector3d upvec, Ray3d r, Vector3d viewvec) {
    int[] pixels = new int[xmax - xmin];
    for (int x = xmin; x < xmax; x++) {
      double xlen = (2.0 * x) / xres - 1.0;
      Vector3d dir = Vector3d.add(Vector3d.comb(xlen, leftvec, ylen, upvec), viewvec);
      dir.normalize();
      r.setDirection(dir);
      Vector3d col = trace(0, 1.0, r);
      // The buffer covers only [xmin, xmax), so index relative to xmin.
      pixels[x - xmin] = (255 << 24)
          | (toChannel(col.getX()) << 16)
          | (toChannel(col.getY()) << 8)
          | toChannel(col.getZ());
    }
    return pixels;
  }

  /**
   * Converts a color component (nominally in [0, 1]) to an 8-bit channel,
   * clamped to [0, 255] so it cannot spill into neighboring bytes of a packed pixel.
   */
  private static int toChannel(double c) {
    int v = (int) (c * 255.0);
    return v < 0 ? 0 : Math.min(v, 255);
  }

  /**
   * Sends the image header to the consumer, builds the camera basis from
   * {@code theView}, renders the full image and signals completion.
   */
  private void produce() {
    if (theConsumer != null) {
      theConsumer.setDimensions(width, height);
      theConsumer.setProperties(props);
      theConsumer.setColorModel(model);
      theConsumer.setHints(ImageConsumer.TOPDOWNLEFTRIGHT |
                           ImageConsumer.COMPLETESCANLINES |
                           ImageConsumer.SINGLEPASS |
                           ImageConsumer.SINGLEFRAME);
    }

    // Camera basis: view direction, up made perpendicular to it, and left.
    Vector3d viewvec = Vector3d.sub(theView.getAt(), theView.getFrom());
    viewvec.normalize();

    Vector3d tmpvec = new Vector3d(viewvec);
    tmpvec.scale(Vector3d.dot(theView.getUp(), viewvec));
    Vector3d upvec = Vector3d.sub(theView.getUp(), tmpvec);
    upvec.normalize();

    Vector3d leftvec = Vector3d.cross(theView.getUp(), viewvec);
    leftvec.normalize();

    // Scale to the image-plane half-extents. Up is negated because row 0 maps
    // to ylen = -1 and rows are delivered top-down.
    double frustumWidth = theView.getDist() * Math.tan(theView.getAngle());
    upvec.scale(-frustumWidth);
    leftvec.scale(theView.getAspect() * frustumWidth);

    scan(theView.getFrom(), viewvec, upvec, leftvec, width, height, 0, width);

    if (theConsumer != null) {
      theConsumer.imageComplete(ImageConsumer.STATICIMAGEDONE);
    }
  }

  /** Adds a light at {@code p} with brightness {@code b}; ignored once the table is full. */
  void newLight(Vector3d p, double b) {
    if (nlights >= lights.length) {
      return;
    }
    Light light = new Light();
    light.pos = p;
    light.brightness = b;
    lights[nlights++] = light;
  }

  /** Adds primitive {@code p} to the scene; ignored once the table is full. */
  void newPrim(Primitive p) {
    if (nprims >= prims.length) {
      return;
    }
    prims[nprims++] = p;
  }
}