package collembola;

import java.io.File;
import java.net.MalformedURLException;
import java.net.URL;

import joptsimple.OptionParser;
import joptsimple.OptionSet;
import mcmas.core.MCMChrono;
import mcmas.core.MCMCommandQueue;
import mcmas.core.MCMCommandQueueProperty;
import mcmas.core.MCMContext;
import mcmas.core.MCMEvent;
import mcmas.core.MCMKernel;
import mcmas.core.MCMMem;
import mcmas.core.MCMProgram;
import mcmas.core.MCMUtils;

import org.jocl.CL;
import org.jocl.Pointer;

import collembola.model.CollemLoader;
import collembola.model.CollemModel;
//import collembola.model.CollemUtils;

/**
 * Runs the Collembola population model on an OpenCL device through the MCM wrapper.
 *
 * <p>Each iteration enqueues, chained by events, the {@code newArrivals},
 * {@code reproduction} and {@code diffusion} kernels compiled from
 * {@code kernels/collembola2.cl}, copies {@code newPopulations} back over the patch
 * populations buffer, runs the {@code death} kernel and then reads the plot populations,
 * patch populations and patch overflows back into host arrays.
 *
 * <p>{@link #release()} frees every OpenCL object this instance created <em>and</em> the
 * {@link MCMContext} it was constructed with, including one supplied by the caller.
 */
public class CollemSimulation2 {

	/** OpenCL source file holding the simulation kernels. */
	private static final String KERNEL_FILE = "kernels/collembola2.cl";

	/** Individuals placed on every patch of the middle row before the first iteration. */
	private static final int INITIAL_ROW_POPULATION = 100;

	/** Model providing the grid dimensions and the arrays uploaded to the device. */
	private final CollemModel model;

	private final MCMContext context;
	private final MCMCommandQueue queue;
	private final MCMProgram program;

	private final MCMKernel reproduceKernel;
	private final MCMKernel newArrivalsKernel;
	private final MCMKernel diffuseKernel;
	private final MCMKernel deathKernel;

	/** Host copy of per-patch populations (the model's array); refreshed every iteration. */
	private final int [] patchesPopulations;
	/** Host backing array for the diffusion output buffer ({@code height * width} ints). */
	private final int [] newPopulations;
	/** Host copy of per-patch overflows, sized {@code (height + 2) * (width + 2)}. */
	private final int [] patchesOverflows;
	/** Host copy of per-plot populations, one entry per plot surface. */
	private final int [] plotsPopulations;

	private final MCMMem patchesPopulationsMem;
	private final MCMMem patchesOverflowsMem;
	private final MCMMem patchesOwnersMem;
	private final MCMMem patchesTypesMem;

	private final MCMMem newPopulationsMem;

	private final MCMMem plotsPopulationsMem;
	private final MCMMem plotsSurfacesMem;

	private final MCMMem worldMem;

	/**
	 * Creates a simulation on a freshly created default {@link MCMContext}.
	 *
	 * @param model the model to simulate
	 */
	public CollemSimulation2(CollemModel model) {
		this(model, new MCMContext());
	}

	/**
	 * Creates a simulation on the given context: compiles the kernels, seeds the middle row
	 * of the patch populations and uploads all model arrays to device buffers.
	 *
	 * <p>NOTE(review): the seed is written into the array returned by
	 * {@code model.getPatchesPopulations()}; if that is the model's own array, the model is
	 * mutated as well.
	 *
	 * @param model   the model to simulate
	 * @param context the OpenCL context to use; it is released by {@link #release()}
	 */
	public CollemSimulation2(CollemModel model, MCMContext context) {
		this.model = model;
		this.context = context;
		this.queue = context.createCommandQueue(MCMCommandQueueProperty.ENABLE_PROFILING);

		final int width = model.getWidth();
		final int height = model.getHeight();

		// Grid dimensions are baked into the kernels as preprocessor constants.
		this.program = MCMUtils.compileFile(context, KERNEL_FILE,
				" -DWIDTH=" + width + " -DHEIGHT=" + height);

		this.reproduceKernel = program.createKernel("reproduction");
		this.newArrivalsKernel = program.createKernel("newArrivals");
		this.diffuseKernel = program.createKernel("diffusion");
		this.deathKernel = program.createKernel("death");

		this.patchesPopulations = model.getPatchesPopulations();
		this.newPopulations = new int[height * width];
		// Presumably a one-patch halo around the grid, hence the +2 in both dimensions.
		this.patchesOverflows = new int[(height + 2) * (width + 2)];

		// Seed the middle row. The loop bound uses the same width as the row stride so the
		// seed cannot spill into the following row.
		final int seedRowStart = height / 2 * width;
		for (int x = 0; x < width; x++) {
			this.patchesPopulations[seedRowStart + x] = INITIAL_ROW_POPULATION;
		}

		this.plotsPopulations = new int[model.getPlotsSurfaces().length];

		this.patchesPopulationsMem = context.newBuffer().Using(patchesPopulations).b();
		this.newPopulationsMem = context.newBuffer().Using(newPopulations).b();

		this.patchesOwnersMem = context.newBuffer().Using(model.getPatchesOwners()).b();
		this.patchesTypesMem = context.newBuffer().Using(model.getPatchesTypes()).b();
		this.patchesOverflowsMem = context.newBuffer().Using(patchesOverflows).b();

		this.plotsSurfacesMem = context.newBuffer().Using(model.getPlotsSurfaces()).b();
		this.plotsPopulationsMem = context.newBuffer().Using(plotsPopulations).b();
		this.worldMem = context.newBuffer().Using(model.getWorld()).b();
	}

	/**
	 * Runs {@code nbIterations} simulation steps and prints the total elapsed time.
	 *
	 * <p>After every step the plot populations, patch populations and patch overflows are
	 * read back (blocking) into the host arrays. A non-positive {@code nbIterations} runs
	 * no steps.
	 *
	 * @param nbIterations number of steps to execute
	 */
	public void run(int nbIterations) {
		newArrivalsKernel.setArguments(worldMem, patchesPopulationsMem, patchesOwnersMem, plotsPopulationsMem, plotsSurfacesMem, patchesTypesMem);
		reproduceKernel.setArguments(worldMem, patchesPopulationsMem, patchesOwnersMem, plotsPopulationsMem, patchesTypesMem);
		diffuseKernel.setArguments(worldMem, patchesPopulationsMem, patchesOverflowsMem, newPopulationsMem, patchesTypesMem);
		deathKernel.setArguments(worldMem, patchesPopulationsMem, patchesTypesMem);

		MCMChrono totalTime = new MCMChrono("totalTime");
		totalTime.start();

		for (int i = 0; i < nbIterations; i++) {
			// Each stage waits on the event of the previous one.
			MCMEvent newArrivalFinished = enqueueKernel(newArrivalsKernel);
			MCMEvent reproduceFinished = enqueueKernel(reproduceKernel, newArrivalFinished);
			MCMEvent diffuseFinished = enqueueKernel(diffuseKernel, reproduceFinished);

			// Swap step: diffusion presumably writes into newPopulationsMem; copy it back over
			// the patch populations. NOTE(review): copies patchesPopulationsMem.getSize() bytes,
			// which assumes the model's populations array has at most height * width entries.
			MCMEvent copyFinished = queue.enqueueCopyBuffer(newPopulationsMem, patchesPopulationsMem,
					0L, 0L, patchesPopulationsMem.getSize(), diffuseFinished);
			MCMEvent deathFinished = enqueueKernel(deathKernel, copyFinished);

			queue.blockingReadBuffer(plotsPopulationsMem, Pointer.to(plotsPopulations), 0, plotsPopulationsMem.getSize(), deathFinished);
			queue.blockingReadBuffer(patchesPopulationsMem, Pointer.to(patchesPopulations), 0, patchesPopulationsMem.getSize(), deathFinished);
			queue.blockingReadBuffer(patchesOverflowsMem, Pointer.to(patchesOverflows), 0, patchesOverflowsMem.getSize(), deathFinished);

			// For population pictures, save the host arrays here, e.g.
			// CollemUtils.savePicture(patchesPopulations, model.getWidth(), "pop", i);
			// CollemUtils.savePicture(patchesOverflows, model.getWidth() + 2, "over", i);
		}

		queue.finish();

		System.out.println("For " + nbIterations + " iterations, " + totalTime.stop());
	}

	/**
	 * Releases all buffers, kernels, the program, the command queue and the context.
	 * The instance must not be used afterwards.
	 */
	public void release() {
		worldMem.release();

		patchesOverflowsMem.release();
		patchesTypesMem.release();
		patchesOwnersMem.release();
		patchesPopulationsMem.release();
		newPopulationsMem.release();

		plotsPopulationsMem.release();
		plotsSurfacesMem.release();

		newArrivalsKernel.release();
		reproduceKernel.release();
		diffuseKernel.release();
		deathKernel.release();

		program.release();
		queue.release();
		context.release();
	}

	/**
	 * Command-line entry point.
	 *
	 * <p>Options: {@code --cpu} forces an OpenCL CPU context, {@code --scale} (default 1,
	 * must be at least 1) multiplies the 256x256 base grid, and {@code --n} (default 500,
	 * must not be negative) sets the iteration count. Loads {@code data/site4.shp}, resolved
	 * against the working directory.
	 *
	 * @param args command-line arguments
	 * @throws MalformedURLException if the shapefile path cannot be turned into a URL
	 * @throws IllegalArgumentException if {@code --scale} or {@code --n} is out of range
	 */
	public static void main(String[] args) throws MalformedURLException {
		OptionParser parser = new OptionParser();

		parser.accepts("cpu", "Force OpenCL CPU mode");
		parser.accepts("scale", "Scale to use").withRequiredArg()
				.ofType(Integer.class).defaultsTo(1);
		parser.accepts("n", "Number of iterations to execute").withRequiredArg()
				.ofType(Integer.class).defaultsTo(500);

		OptionSet options = parser.parse(args);

		final int nbIterations = (Integer) options.valueOf("n");
		final int scale = (Integer) options.valueOf("scale");

		// Validate before any OpenCL object exists so nothing needs releasing on failure.
		if (scale < 1) {
			throw new IllegalArgumentException("scale must be >= 1, got " + scale);
		}
		if (nbIterations < 0) {
			throw new IllegalArgumentException("n must be >= 0, got " + nbIterations);
		}

		MCMContext context = createContext(options.has("cpu"));
		CollemSimulation2 sim = null;
		try {
			System.out.println("Using scale: " + scale);

			// Do not write this as "file://data/site4.shp": there "data" is parsed as the URL
			// authority (host) rather than a directory. Building it from a File yields a
			// proper absolute file: URL for the relative path.
			URL shapefile = new File("data/site4.shp").toURI().toURL();

			CollemModel model = CollemLoader.load(scale * 256, scale * 256, shapefile);
			model.getWorld().growthRate = 3f;
			model.getWorld().diffusionThreshold = 0f;
			model.getWorld().diffusionRate = 0.5f;

			sim = new CollemSimulation2(model, context);
			sim.run(nbIterations);
		} finally {
			// release() on the simulation also releases the context; if construction never
			// completed, release the bare context ourselves.
			if (sim != null) {
				sim.release();
			} else {
				context.release();
			}
		}
	}

	/**
	 * Creates the OpenCL context, using {@code CL_DEVICE_TYPE_CPU} when {@code forceCpu}
	 * is set and the default {@link MCMContext} otherwise.
	 */
	private static MCMContext createContext(boolean forceCpu) {
		if (forceCpu) {
			System.out.println("Forcing OpenCL CPU mode...");
			return new MCMContext(CL.CL_DEVICE_TYPE_CPU);
		}
		return new MCMContext();
	}

	/**
	 * Enqueues {@code kernel} over a 2-D range of {@code height x width} work items,
	 * waiting on {@code events} first.
	 *
	 * @return the event signalled when the kernel completes
	 */
	private MCMEvent enqueueKernel(MCMKernel kernel, MCMEvent... events) {
		return queue.enqueueKernel(kernel, 2, new long[] { model.getHeight(), model.getWidth() }, events);
	}

}
