package mior.model;

import java.nio.ByteBuffer;
import java.nio.IntBuffer;
import java.util.Arrays;

import mcmas.core.MCMCommandQueue;
import mcmas.core.MCMCommandQueueProperty;
import mcmas.core.MCMContext;
import mcmas.core.MCMEvent;
import mcmas.core.MCMKernel;
import mcmas.core.MCMMem;
import mcmas.core.MCMProgram;
import mcmas.core.MCMUtils;
import mior.model.dist.IMiorDistribution;

import org.jocl.Pointer;
import org.jocl.Sizeof;
import org.jocl.struct.Buffers;
import org.perf4j.slf4j.Slf4JStopWatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * MIOR model implementation that runs {@code nbSim} simulations on an OpenCL
 * device through the kernels compiled from {@link #MODEL_SOURCE}
 * ({@code topology}, {@code simulate} and {@code autolive}).
 *
 * <p>The Java-side arrays ({@code mmList}, {@code omList}, {@code worlds})
 * hold the data of all simulations back to back; each of them is mirrored by
 * a device buffer and by a host {@link ByteBuffer} used to stage transfers.
 * The per-simulation population sizes are derived from the distribution
 * passed to the constructor and are baked into the OpenCL program as the
 * {@code NB_MM} / {@code NB_OM} compile-time definitions.
 */
public class OCLParaCSRModelAlt extends AbstractMiorModel {

	private static final Logger logger = LoggerFactory.getLogger(OCLParaCSRModelAlt.class);

	/** Path of the OpenCL source file handed to {@code MCMUtils.compileFile}. */
	private static final String MODEL_SOURCE = "kernels/mior_model_multisim3.cl";

	/*
	 * Java view
	 */

	/** Number of MM slots reserved per simulation. */
	private final int nbMM;
	/** Number of OM slots reserved per simulation. */
	private final int nbOM;

	/** MM of all simulations ({@code nbSim * nbMM} entries). */
	private final MiorMM [] mmList;
	/** OM of all simulations ({@code nbSim * nbOM} entries). */
	private final MiorOM [] omList;
	/** One world descriptor per simulation. */
	private final MiorWorld [] worlds;

	/*
	 * OpenCL implementation
	 */
	final MCMContext context;
	final MCMCommandQueue queue;
	final MCMProgram program;

	final MCMKernel topoKernel;
	final MCMKernel simulateKernel;
	final MCMKernel autoliveKernel;

	final MCMMem mmMem;
	final MCMMem omMem;
	final MCMMem worldsMem;

	final MCMMem mmCSRMem;
	final MCMMem omCSRMem;
	final MCMMem partsMem;

	/*
	 * Host staging buffers used as source/destination of device transfers.
	 */
	private final ByteBuffer mmBuffer;
	private final ByteBuffer omBuffer;
	private final ByteBuffer worldBuffer;

	private final ByteBuffer mmCSRBuffer;
	private final ByteBuffer omCSRBuffer;
	private final ByteBuffer partsBuffer;

	private final int nbSim;
	/** Work-group size used for the {@code simulate} and {@code autolive} kernels. */
	private int blockSize;

	private final Slf4JStopWatch watch = new Slf4JStopWatch(logger);

	private final IMiorDistribution dist;

	/**
	 * Builds the Java structures, the host staging buffers and all OpenCL
	 * objects (context, queue, program, kernels, device buffers).
	 *
	 * @param onbMM ignored: the MM count is derived from {@code dist}
	 * @param onbOM ignored: the OM count is derived from {@code dist}
	 * @param nbSim number of simulations to run side by side, at least 1
	 * @param dist  distribution providing the mean population sizes and the
	 *              maximum factor used to size the per-simulation arrays
	 * @throws NullPointerException     if {@code dist} is {@code null}
	 * @throws IllegalArgumentException if {@code nbSim} is smaller than 1
	 * @throws ArithmeticException      if a buffer size does not fit in an {@code int}
	 */
	public OCLParaCSRModelAlt(int onbMM, int onbOM, int nbSim, IMiorDistribution dist) {
		if (dist == null) {
			throw new NullPointerException("dist must not be null");
		}
		if (nbSim < 1) {
			// worlds[0] is dereferenced below, so at least one simulation is required.
			throw new IllegalArgumentException("nbSim must be >= 1, got " + nbSim);
		}
		
		this.dist = dist;
		System.out.println(dist);
		
		this.nbOM = (int) (dist.getMaxFactor() * dist.getMeanOM());
		this.nbMM = (int) (dist.getMaxFactor() * dist.getMeanMM());
		
		this.nbSim = nbSim;
		this.blockSize = Math.max(nbOM, nbMM);
		
		// Sizes are computed with overflow checks: a silently wrapped product
		// would under-allocate buffers that the kernels write into.
		final int totalMM = checkedMultiply(nbSim, nbMM);
		final int totalOM = checkedMultiply(nbSim, nbOM);
		final int totalParts = checkedMultiply(totalMM, nbOM);
		
		// Allocate Java MIOR structures
		this.mmList       = new MiorMM[totalMM];
		this.omList       = new MiorOM[totalOM];
		this.worlds       = new MiorWorld[nbSim];
		
		MiorUtils.initWorldArray(worlds, dist);
		MiorUtils.initRandomMMArray(mmList, worlds[0].width);
		MiorUtils.initRandomOMArray(omList, worlds[0].width);
		
		this.mmBuffer = Buffers.allocateBuffer(mmList);
		this.omBuffer = Buffers.allocateBuffer(omList);
		this.worldBuffer = Buffers.allocateBuffer(worlds);
		
		this.mmCSRBuffer = allocateLists(totalOM, nbMM);
		this.omCSRBuffer = allocateLists(totalMM, nbOM);
		this.partsBuffer = allocateDirectIntBuffer(totalParts);
		
		// OpenCL allocations
		watch.start("OpenCL setup");
		
		this.context = new MCMContext();
		this.queue = context.createCommandQueue(
				MCMCommandQueueProperty.ENABLE_PROFILING,
				MCMCommandQueueProperty.ENABLE_OOO_MODE
		);
		
		watch.stop();
		
		System.out.println("NB_MM: " + nbMM + ", NB_OM: " + nbOM);
		watch.start("Program compilation");
		this.program = MCMUtils.compileFile(context, MODEL_SOURCE, " -DNB_MM=" + nbMM + " -DNB_OM=" + nbOM);
		watch.stop();
		
		watch.start("Kernel allocation");
		
		this.topoKernel = program.createKernel("topology");
		this.simulateKernel = program.createKernel("simulate");
		this.autoliveKernel = program.createKernel("autolive");
		
		watch.stop();
		watch.start("Memory allocation");
		
		this.mmMem = context.newBuffer().Using(mmList).b();
		this.omMem = context.newBuffer().Using(omList).b();
		this.worldsMem = context.newBuffer().Using(worlds).b();
		System.out.println(nbSim * nbOM * nbMM);
		System.out.println("capa " + mmCSRBuffer.capacity());
		this.mmCSRMem = context.newBuffer().Using(mmCSRBuffer).b();
		this.omCSRMem = context.newBuffer().Using(omCSRBuffer).b();
		this.partsMem = context.newBuffer().Using(partsBuffer).b();
		
		watch.stop();
		
	}
	
	/**
	 * Multiplies two ints, failing instead of wrapping around.
	 *
	 * @throws ArithmeticException if the product does not fit in an {@code int}
	 */
	private static int checkedMultiply(int a, int b) {
		final long product = (long) a * (long) b;
		if (product != (int) product) {
			throw new ArithmeticException("integer overflow: " + a + " * " + b);
		}
		return (int) product;
	}
	
	/**
	 * Allocates a direct buffer holding {@code nbList} lists of
	 * {@code listSize + 1} ints each.
	 * NOTE(review): the extra slot per list is presumably the list length
	 * header expected by the kernels; confirm against the OpenCL source.
	 */
	private static ByteBuffer allocateLists(int nbList, int listSize) {
		return allocateDirectIntBuffer(checkedMultiply(nbList, listSize + 1));
	}
	
	/**
	 * Allocates a direct buffer large enough for {@code capacity} values of
	 * {@code Sizeof.cl_int} bytes.
	 *
	 * @throws IllegalArgumentException if {@code capacity} is negative
	 * @throws ArithmeticException      if the byte size does not fit in an {@code int}
	 */
	private static ByteBuffer allocateDirectIntBuffer(int capacity) {
		if (capacity < 0) {
			throw new IllegalArgumentException("negative buffer capacity: " + capacity);
		}
		return ByteBuffer.allocateDirect(checkedMultiply(Sizeof.cl_int, capacity));
	}
	
	/**
	 * Ensures the work-group size can hold the larger of the two populations,
	 * as required before launching {@code simulate} or {@code autolive}.
	 *
	 * @throws RuntimeException if {@code blockSize} is too small
	 */
	private void checkBlockSize() {
		if (blockSize < Math.max(nbOM, nbMM)) {
			throw new RuntimeException("blockSize (" + blockSize + ") is too small to execute the simulation");
		}
	}

	@Override
	public void setBlockSize(int blockSize) {
		this.blockSize = blockSize;
	}
	
	@Override
	public int getBlockSize() {
		return blockSize;
	}
	
	
	@Override
	public int getNbSimulations() {
		return nbSim;
	}
	
	/**
	 * Sets every element of the given view to zero. Absolute puts are used,
	 * so the position of the view is left untouched.
	 */
	private static void clearIntBuffer(IntBuffer b) {
		final int length = b.capacity();
		for (int i = 0; i < length; i++) {
			b.put(i, 0);
		}
	}
	
	/**
	 * Re-initialises the Java structures, clears the list/parts buffers and
	 * uploads everything to the device, waiting for all writes to complete.
	 */
	@Override
	protected void resetImpl() {
		// Initialise models
		MiorUtils.initWorldArray(worlds, dist);
		MiorUtils.initRandomMMArray(mmList, worlds[0].width);
		MiorUtils.initRandomOMArray(omList, worlds[0].width);
		
		// Serialise the structures into the staging buffers; rewind so the
		// buffers are positioned at their start again.
		Buffers.writeToBuffer(mmBuffer, mmList); mmBuffer.rewind();
		Buffers.writeToBuffer(omBuffer, omList); omBuffer.rewind();
		Buffers.writeToBuffer(worldBuffer, worlds); worldBuffer.rewind();
		
		MCMEvent w1 = queue.enqueueWriteBuffer(mmMem, Pointer.to(mmBuffer), 0, mmMem.getSize());
		MCMEvent w2 = queue.enqueueWriteBuffer(omMem, Pointer.to(omBuffer), 0, omMem.getSize());
		MCMEvent w3 = queue.enqueueWriteBuffer(worldsMem, Pointer.to(worldBuffer), 0, worldsMem.getSize());
		
		clearIntBuffer(mmCSRBuffer.asIntBuffer());
		clearIntBuffer(omCSRBuffer.asIntBuffer());
		clearIntBuffer(partsBuffer.asIntBuffer());
		
		MCMEvent w4 = queue.enqueueWriteBuffer(mmCSRMem, Pointer.to(mmCSRBuffer), 0, mmCSRMem.getSize());
		MCMEvent w5 = queue.enqueueWriteBuffer(omCSRMem, Pointer.to(omCSRBuffer), 0, omCSRMem.getSize());
		MCMEvent w6 = queue.enqueueWriteBuffer(partsMem, Pointer.to(partsBuffer), 0, partsMem.getSize());
		MCMEvent.waitFor(w1, w2, w3, w4, w5, w6);
	}
	
	/**
	 * Runs the {@code topology} kernel over a 2D range of
	 * {@code (nbSim * nbMM) x nbOM} work-items and, unless batch mode is
	 * enabled, copies both list buffers back to the host.
	 */
	@Override
	protected void doTopologyImpl() {
		topoKernel.setArguments(mmMem, omMem, worldsMem, mmCSRMem, omCSRMem);
		
		MCMEvent event = queue.enqueueKernel(topoKernel, 2,
				new long[] { (long) nbSim * nbMM , nbOM} // Global size
		);
		
		MCMEvent.waitFor(event);
		MCMUtils.printEventStats("topology", event);
		
		if (! isBatchModeEnabled()) {
			MCMEvent r1 = queue.enqueueReadBuffer(mmCSRMem, Pointer.to(mmCSRBuffer), 0, mmCSRMem.getSize());
			MCMEvent r2 = queue.enqueueReadBuffer(omCSRMem, Pointer.to(omCSRBuffer), 0, omCSRMem.getSize());
			MCMEvent.waitFor(r1, r2);
		}
	}
	
	/**
	 * Runs one {@code simulate} kernel launch with one work-group of
	 * {@code blockSize} work-items per simulation and, unless batch mode is
	 * enabled, refreshes the Java structures from the device.
	 *
	 * @throws RuntimeException if {@code blockSize} is too small
	 */
	@Override
	protected void doLiveImpl() {
		checkBlockSize();
		
		simulateKernel.setArguments(mmMem, omMem, worldsMem, mmCSRMem, omCSRMem, partsMem);
		
		MCMEvent event = queue.enqueue1DKernel(simulateKernel, checkedMultiply(nbSim, blockSize), blockSize);
		
		MCMEvent.waitFor(event);
		MCMUtils.printEventStats("simulate", event);
		
		if (! isBatchModeEnabled()) {
			// Read the device state into the staging buffers first, then
			// deserialise them. Deserialising the staging buffers without
			// refreshing them would restore the values uploaded by the last
			// reset and hide every change made by the kernel.
			MCMEvent r1 = queue.enqueueReadBuffer(mmMem, Pointer.to(mmBuffer), 0, mmMem.getSize());
			MCMEvent r2 = queue.enqueueReadBuffer(omMem, Pointer.to(omBuffer), 0, omMem.getSize());
			MCMEvent r3 = queue.enqueueReadBuffer(worldsMem, Pointer.to(worldBuffer), 0, worldsMem.getSize());
			MCMEvent.waitFor(r1, r2, r3);
			
			Buffers.readFromBuffer(mmBuffer, mmList); mmBuffer.rewind();
			Buffers.readFromBuffer(omBuffer, omList); omBuffer.rewind();
			Buffers.readFromBuffer(worldBuffer, worlds); worldBuffer.rewind();
			
			System.out.println("copy");
		}
	}
	
	/**
	 * Runs a complete simulation: reset, topology, then either repeated
	 * {@code doLive()} steps until the CO2 value of the first world stops
	 * changing (interactive mode) or a single {@code autolive} kernel launch
	 * (batch mode). Results are read back by {@link #onSimulationFinished()}.
	 *
	 * @throws RuntimeException if {@code blockSize} is too small; checked
	 *                          before any device work is started
	 */
	@Override
	public void doAutoLive() {
		checkBlockSize();
		
		resetImpl();
		doTopologyImpl();
		
		if (! isBatchModeEnabled()) {
			int previousCO2 = -1;
			
			// Step until a step leaves the CO2 of the first world unchanged.
			while (previousCO2 == -1 || previousCO2 != worlds[0].CO2) {
				previousCO2 = worlds[0].CO2;
				System.out.println("CO2 = " + previousCO2);
				doLive();
			}
			
		} else {
			autoliveKernel.setArguments(mmMem, omMem, worldsMem, mmCSRMem, omCSRMem, partsMem);
			
			MCMEvent event = queue.enqueue1DKernel(autoliveKernel, checkedMultiply(nbSim, blockSize), blockSize);
			
			MCMEvent.waitFor(event);
			MCMUtils.printEventStats("autolive", event);
		}
		
		onSimulationFinished();
	}
	
	/**
	 * Reads the MM, OM, world and parts data back from the device, waits for
	 * the transfers and then delegates to the parent implementation.
	 */
	@Override
	protected void onSimulationFinished() {
		
		MCMEvent r1 = queue.enqueueReadBuffer(mmMem, mmList, 0, mmMem.getSize());
		MCMEvent r2 = queue.enqueueReadBuffer(omMem, omList, 0, omMem.getSize());
		MCMEvent r3 = queue.enqueueReadBuffer(worldsMem, worlds, 0, worldsMem.getSize());
		MCMEvent r4 = queue.enqueueReadBuffer(partsMem, Pointer.to(partsBuffer), 0, partsMem.getSize());
		
		MCMEvent.waitFor(r1, r2, r3, r4);
		super.onSimulationFinished();
	}
	
	/**
	 * Releases every OpenCL object in reverse order of creation: device
	 * buffers, kernels, program, queue and finally the context.
	 */
	@Override
	protected void releaseImpl() {
		partsMem.release();
		
		omCSRMem.release();
		mmCSRMem.release();
		
		worldsMem.release();
		omMem.release();
		mmMem.release();
		
		autoliveKernel.release();
		simulateKernel.release();
		topoKernel.release();
		
		program.release();
		queue.release();
		context.release();
	}

	/** Returns the world descriptor of the first simulation. */
	@Override
	public MiorWorld getWorld() {
		return worlds[0];
	}

	/** Returns a copy of the first {@code nbOM} entries of the OM array. */
	@Override
	public MiorOM[] getOMList() {
		return Arrays.copyOfRange(omList, 0, nbOM);
	}

	/** Returns a copy of the first {@code nbMM} entries of the MM array. */
	@Override
	public MiorMM[] getMMList() {
		return Arrays.copyOfRange(mmList, 0, nbMM);
	}

	/**
	 * Always returns {@code false}: this implementation keeps no
	 * association matrix on the Java side.
	 */
	@Override
	public boolean isAccessible(int iMM, int iOM) {
		return false;
	}

	/**
	 * Always returns {@code false}: this implementation keeps no
	 * association matrix on the Java side.
	 */
	@Override
	public boolean isAccessible(int iMM, int iOM, int iSim) {
		return false;
	}
	
}
