package mior.model;

import java.nio.ByteBuffer;
import java.util.Random;

import org.jocl.Sizeof;
import org.perf4j.slf4j.Slf4JStopWatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import mcmas.core.MCMCommandQueue;
import mcmas.core.MCMContext;
import mcmas.core.MCMEvent;
import mcmas.core.MCMKernel;
import mcmas.core.MCMMem;
import mcmas.core.MCMProgram;
import mcmas.core.MCMUtils;
import mior.model.dist.IMiorDistribution;

/**
 * A batch of {@code nbSim} MIOR simulations prepared on the host and run on an OpenCL device
 * through the MCMAS wrappers.
 *
 * <p>The constructor builds the host-side agent/world arrays and the CSR and partition buffers;
 * {@link #execute} uploads them, enqueues the {@code topology} then {@code autolive} kernels,
 * and enqueues read-backs of the agents and worlds into the host arrays.
 */
public class MiorSimulation {
	
	private final int nbMM;
	private final int nbOM;
	private final int nbSim;
	
	/** MM agents of all simulations, {@code nbMM} consecutive entries per simulation. */
	public final MiorMM[] mmList;
	/** OM agents of all simulations, {@code nbOM} consecutive entries per simulation. */
	public final MiorOM[] omList;
	/** One world per simulation. */
	public final MiorWorld[] worldList;
	
	/** Event of the {@code topology} kernel enqueued by the last {@link #execute} call. */
	public MCMEvent topoEvent;
	/** Event of the {@code autolive} kernel enqueued by the last {@link #execute} call. */
	public MCMEvent autoEvent;
	/** Events of the read-backs of {@code mmList}, {@code omList} and {@code worldList}, in that order. */
	public MCMEvent [] copyEvents;
	
	private final ByteBuffer mmCSRBuffer;
	private final ByteBuffer omCSRBuffer;
	private final ByteBuffer partsBuffer;
	
	private final static Logger logger = LoggerFactory.getLogger(MiorSimulation.class);
	private final Slf4JStopWatch watch = new Slf4JStopWatch(logger);
	
	private final static String MODEL_SOURCE = "kernels/mior_model_multisim3.cl";
	
	/** Shared model program, compiled lazily by {@link #execute}; guarded by {@link #programLock}. */
	public static MCMProgram program;
	public static final Object programLock = new Object();
	
	// Build options and context the cached `program` was compiled for (guarded by programLock).
	// NB_MM / NB_OM are compile-time constants of the kernel, so a program built for other sizes
	// or for another context must not be reused. Both stay null if `program` was set externally.
	private static String programOptions;
	private static MCMContext programContext;
	
	/**
	 * Creates {@code nbSim} simulations whose agent counts and worlds are derived from {@code dist}.
	 *
	 * @param onbMM ignored; the MM count per simulation is {@code dist.getMaxFactor() * dist.getMeanMM()}
	 * @param onbOM ignored; the OM count per simulation is {@code dist.getMaxFactor() * dist.getMeanOM()}
	 * @param nbSim number of simulations in the batch
	 * @param dist  distribution used for the agent counts and to initialise the worlds
	 * @throws ArithmeticException if the array or buffer sizes overflow {@code int}
	 */
	public MiorSimulation(int onbMM, int onbOM, int nbSim, IMiorDistribution dist) {
		this.nbOM = (int) (dist.getMaxFactor() * dist.getMeanOM());
		this.nbMM = (int) (dist.getMaxFactor() * dist.getMeanMM());
		this.nbSim = nbSim;
		
		this.mmList = new MiorMM[Math.multiplyExact(nbSim, nbMM)];
		this.omList = new MiorOM[Math.multiplyExact(nbSim, nbOM)];
		this.worldList = new MiorWorld[nbSim];
		
		MiorUtils.initWorldArray(worldList, dist);
		
		Random rng = new Random();
		
		// Both coordinates of every agent are drawn uniformly in [0, width) of its own world.
		for (int i = 0; i < worldList.length; i++) {
			final int width = worldList[i].width;
			
			for (int iMM = 0; iMM < nbMM; iMM++) {
				mmList[i * nbMM + iMM] = new MiorMM(rng.nextFloat() * width, rng.nextFloat() * width);
			}
			
			for (int iOM = 0; iOM < nbOM; iOM++) {
				omList[i * nbOM + iOM] = new MiorOM(rng.nextFloat() * width, rng.nextFloat() * width);
			}
		}
		
		// ByteBuffer.allocate zero-fills its content, so no explicit clearing is needed.
		this.mmCSRBuffer = allocateLists(Math.multiplyExact(nbSim, nbOM), nbMM);
		this.omCSRBuffer = allocateLists(Math.multiplyExact(nbSim, nbMM), nbOM);
		this.partsBuffer = allocateIntBuffer(Math.multiplyExact(Math.multiplyExact(nbSim, nbMM), nbOM));
	}
	
	/**
	 * Uploads this batch, enqueues the {@code topology} and {@code autolive} kernels and enqueues
	 * read-backs of {@link #mmList}, {@link #omList} and {@link #worldList}.
	 *
	 * <p>No {@code finish()} is issued here; the host arrays should only be read once
	 * {@link #copyEvents} have completed (presumably what {@link MiorFuture} waits on — confirm).
	 *
	 * @param context OpenCL context used to compile the program and allocate device buffers
	 * @param queue   command queue the kernels and read-backs are enqueued on
	 * @return a future wrapping this simulation
	 */
	public MiorFuture execute(MCMContext context, MCMCommandQueue queue) {
		watch.start("Program compilation");
		final MCMProgram prog = getProgram(context);
		logger.info("nbMM: {}, nbOM: {}", nbMM, nbOM);
		watch.stop();
		
		MCMKernel topoKernel = null;
		MCMKernel autoliveKernel = null;
		MCMMem mmMem = null;
		MCMMem omMem = null;
		MCMMem worldsMem = null;
		MCMMem mmCSRMem = null;
		MCMMem omCSRMem = null;
		MCMMem partsMem = null;
		
		try {
			watch.start("Kernel allocation");
			topoKernel = prog.createKernel("topology");
			autoliveKernel = prog.createKernel("autolive");
			watch.stop();
			
			watch.start("Memory allocation");
			mmMem = context.newBuffer().Using(mmList).b();
			omMem = context.newBuffer().Using(omList).b();
			worldsMem = context.newBuffer().Using(worldList).b();
			mmCSRMem = context.newBuffer().Using(mmCSRBuffer).b();
			omCSRMem = context.newBuffer().Using(omCSRBuffer).b();
			partsMem = context.newBuffer().Using(partsBuffer).b();
			watch.stop();
			
			// topology: 2D global range { nbSim * nbMM, nbOM }.
			topoKernel.setArguments(mmMem, omMem, worldsMem, mmCSRMem, omCSRMem);
			this.topoEvent = queue.enqueueKernel(topoKernel, 2,
					new long[] { nbSim * nbMM, nbOM } // Global size
			);
			
			// autolive waits on topology; presumably global size nbSim * nbOM, local size nbOM.
			autoliveKernel.setArguments(mmMem, omMem, worldsMem, mmCSRMem, omCSRMem, partsMem);
			this.autoEvent = queue.enqueue1DKernel(autoliveKernel, nbSim * nbOM, nbOM, this.topoEvent);
			
			this.copyEvents = new MCMEvent[3];
			this.copyEvents[0] = queue.enqueueReadBuffer(mmMem, mmList, 0, mmMem.getSize(), this.autoEvent);
			this.copyEvents[1] = queue.enqueueReadBuffer(omMem, omList, 0, omMem.getSize(), this.autoEvent);
			this.copyEvents[2] = queue.enqueueReadBuffer(worldsMem, worldList, 0, worldsMem.getSize(), this.autoEvent);
		} finally {
			// Released even on failure. Assuming release() maps to clReleaseMemObject /
			// clReleaseKernel, OpenCL defers actual deletion until enqueued commands using the
			// objects have completed, so releasing before the read-backs finish is safe.
			releaseAll(mmMem, omMem, worldsMem, mmCSRMem, omCSRMem, partsMem);
			if (autoliveKernel != null) {
				autoliveKernel.release();
			}
			if (topoKernel != null) {
				topoKernel.release();
			}
		}
		
		return new MiorFuture(this);
	}
	
	/**
	 * Returns the shared model program, compiling it if none exists yet or if the cached one was
	 * compiled by this class for another context or other NB_MM / NB_OM values.
	 *
	 * <p>A replaced program is not released, because other simulations may still be using
	 * kernels created from it.
	 */
	private MCMProgram getProgram(MCMContext context) {
		final String options = " -DNB_MM=" + nbMM + " -DNB_OM=" + nbOM;
		synchronized (programLock) {
			// programOptions == null with a non-null program means it was set externally: keep it.
			boolean stale = programOptions != null
					&& (!programOptions.equals(options) || programContext != context);
			if (program == null || stale) {
				logger.info("compiling program");
				program = MCMUtils.compileFile(context, MODEL_SOURCE, options);
				programOptions = options;
				programContext = context;
			}
			// Read under the lock so a concurrent recompile cannot swap it before kernel creation.
			return program;
		}
	}
	
	/** Releases every non-null buffer in {@code mems}. */
	private static void releaseAll(MCMMem... mems) {
		for (MCMMem mem : mems) {
			if (mem != null) {
				mem.release();
			}
		}
	}
	
	/**
	 * Allocates {@code nbList} int lists of {@code listSize} entries plus one extra int each.
	 *
	 * @throws ArithmeticException if the total size overflows {@code int}
	 */
	private static ByteBuffer allocateLists(int nbList, int listSize) {
		return allocateIntBuffer(Math.multiplyExact(nbList, Math.addExact(listSize, 1)));
	}
	
	/**
	 * Allocates a zero-filled heap buffer holding {@code capacity} OpenCL ints.
	 *
	 * @throws ArithmeticException if the byte size overflows {@code int}
	 */
	private static ByteBuffer allocateIntBuffer(int capacity) {
		return ByteBuffer.allocate(Math.multiplyExact(Sizeof.cl_int, capacity));
	}
}
