package mior.model;

import java.util.Arrays;
import java.util.Random;

import mcmas.core.MCM;
import mcmas.core.MCMCommandQueue;
import mcmas.core.MCMCommandQueueProperty;
import mcmas.core.MCMContext;
import mcmas.core.MCMEvent;
import mcmas.core.MCMKernel;
import mcmas.core.MCMMem;
import mcmas.core.MCMProgram;
import mcmas.core.MCMUtils;

import org.jocl.Pointer;
import org.perf4j.slf4j.Slf4JStopWatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * OpenCL implementation of {@link SimpleMiorModel} built on the MCM wrapper.
 * <p>
 * The model steps run as the {@code topology}, {@code carbon_scatter},
 * {@code live} and {@code carbon_reduce} kernels compiled from
 * {@code kernels/mior_model2.cl}. The host arrays {@code mmList},
 * {@code omList}, {@code associations} and {@code world} are mirrored into
 * device buffers.
 * <p>
 * Read-back from the device works as follows:
 * <ul>
 * <li>Outside batch mode, associations and totals are read back after each
 * topology step. The mm, om and association buffers are read back after
 * each live step.</li>
 * <li>In batch mode, those buffers are read back only in
 * {@link #onSimulationFinished()}.</li>
 * <li>The world buffer is read back after every live step in both modes.</li>
 * </ul>
 */
public class OCLMiorModel2 extends SimpleMiorModel {

	/** Path of the OpenCL source compiled by every model instance. */
	private static final String MODEL_SOURCE = "kernels/mior_model2.cl";

	private static final Logger logger = LoggerFactory.getLogger(OCLMiorModel2.class);

	private final MCMContext      context;
	private final MCMProgram      program;
	private final MCMCommandQueue queue;

	private final MCMKernel       topoKernel;
	private final MCMKernel       scatterKernel;
	private final MCMKernel       liveKernel;
	private final MCMKernel       reduceKernel;

	private final MCMMem          mmMem;
	private final MCMMem          omMem;
	private final MCMMem          associationsMem;
	private final MCMMem          worldMem;

	/** Scratch buffer of omList.length * mmList.length ints, passed to the live kernel. */
	private final MCMMem          omIndexesMem;
	/** Device copy of {@link #totals}. */
	private final MCMMem          totalsMem;

	/** One counter per OM entry; zeroed on reset and uploaded to {@link #totalsMem}. */
	private final int[]           totals;
	/** Source of per-step seeds for the live kernel when randomness is enabled. */
	private final Random          rand;

	/** Times each setup phase of the constructor. */
	private final Slf4JStopWatch  watch = new Slf4JStopWatch(logger);

	/**
	 * Creates the OpenCL context, compiles the model program, creates its
	 * kernels and allocates the device buffers backing the model state.
	 * Each phase is timed with {@link #watch}.
	 *
	 * @param world the world state, also uploaded to the device world buffer
	 */
	public OCLMiorModel2(MiorWorld world) {
		super(world);
		this.rand = new Random();

		watch.start("OpenCL setup");
		this.context = new MCMContext();
		// Profiling is enabled so MCMUtils.printEventStats can report kernel timings.
		this.queue = context.createCommandQueue(MCMCommandQueueProperty.ENABLE_PROFILING);
		watch.stop();

		watch.start("Program compilation");
		this.program = MCMUtils.compileFile(context, MODEL_SOURCE);
		watch.stop();

		watch.start("Kernel allocation");
		this.topoKernel = program.createKernel("topology");
		this.scatterKernel = program.createKernel("carbon_scatter");
		this.liveKernel = program.createKernel("live");
		this.reduceKernel = program.createKernel("carbon_reduce");
		watch.stop();

		watch.start("Memory allocation");
		this.mmMem = context.newBuffer().Using(mmList).b();
		this.omMem = context.newBuffer().Using(omList).b();
		this.associationsMem = context.newBuffer().Using(associations).b();
		this.worldMem = context.newBuffer().Using(world).b();

		this.omIndexesMem = context.newBuffer().Size(omList.length * mmList.length, MCM.INT).b();

		// A new int[] is already zero-filled; no explicit Arrays.fill is needed.
		this.totals = new int[omList.length];
		this.totalsMem = context.newBuffer().Using(totals).b();
		watch.stop();
	}

	/**
	 * Resets the host-side model, zeroes {@link #totals}, uploads all host
	 * state to the device buffers, waits for the uploads to finish and then
	 * notifies listeners.
	 */
	@Override
	protected void resetImpl() {
		super.resetImpl();
		Arrays.fill(totals, 0);

		queue.enqueueWriteBuffer(mmMem, mmList, 0, mmMem.getSize());
		queue.enqueueWriteBuffer(omMem, omList, 0, omMem.getSize());
		queue.enqueueWriteBuffer(associationsMem, Pointer.to(associations), 0, associationsMem.getSize());
		queue.enqueueWriteBuffer(worldMem, new MiorWorld[] { world }, 0, worldMem.getSize());
		queue.enqueueWriteBuffer(totalsMem, Pointer.to(totals), 0, totalsMem.getSize());
		queue.finish();

		fireMiorUpdated();
	}

	/**
	 * Runs the {@code topology} kernel over an {@code nbMM x nbOM} range and
	 * waits for it to finish. Outside batch mode it then reads associations
	 * and totals back to the host. Listeners are always notified.
	 */
	@Override
	public void doTopologyImpl() {
		// totalsMem is bound only when the compiled kernel declares a fifth
		// parameter, presumably so kernel variants with and without it both work.
		if (topoKernel.getArgumentsNumber() == 5) {
			topoKernel.setArguments(mmMem, omMem, associationsMem, worldMem, totalsMem);
		} else {
			topoKernel.setArguments(mmMem, omMem, associationsMem, worldMem);
		}

		MCMEvent event = queue.enqueueKernel(topoKernel, 2, new long[] { nbMM, nbOM });
		MCMEvent.waitFor(event);
		MCMUtils.printEventStats("topology", event);

		if (!isBatchModeEnabled()) {
			queue.blockingReadBuffer(associationsMem, Pointer.to(associations), 0, associationsMem.getSize());
			queue.blockingReadBuffer(totalsMem, Pointer.to(totals), 0, totalsMem.getSize());
		}

		fireMiorUpdated();
	}

	/**
	 * Runs one live step on the device. The kernels launch in this order:
	 * <ol>
	 * <li>{@code carbon_scatter} over {@code nbOM} work items</li>
	 * <li>{@code live} over {@code nbMM} work items</li>
	 * <li>{@code carbon_reduce} over {@code nbOM} work items</li>
	 * </ol>
	 * The method blocks until the reduce kernel's event completes. Outside
	 * batch mode it then reads mm, om and associations back. The world buffer
	 * is always read back.
	 */
	@Override
	public void doLiveImpl() {
		scatterKernel.setArguments(mmMem, omMem, associationsMem, worldMem, totalsMem);
		MCMEvent scatteredEvent = queue.enqueue1DKernel(scatterKernel, nbOM);

		// The live kernel receives a fresh random seed per step, or 0 when randomness is disabled.
		long seed = isRandomEnabled() ? rand.nextLong() : 0L;
		liveKernel.setArguments(mmMem, omMem, associationsMem, worldMem, omIndexesMem, seed);

		// Each launch receives the previous launch's event (presumably as its
		// wait list), so waiting on reducedEvent covers the whole chain.
		MCMEvent livedEvent = queue.enqueue1DKernel(liveKernel, nbMM, scatteredEvent);

		reduceKernel.setArguments(mmMem, omMem, associationsMem, worldMem);
		MCMEvent reducedEvent = queue.enqueue1DKernel(reduceKernel, nbOM, livedEvent);

		MCMEvent.waitFor(reducedEvent);

		MCMUtils.printEventStats("scatter", scatteredEvent);
		MCMUtils.printEventStats("live", livedEvent);
		MCMUtils.printEventStats("reduce", reducedEvent);

		if (!isBatchModeEnabled()) {
			readModelStateFromDevice();
		}

		queue.blockingReadBuffer(worldMem, new MiorWorld[] { world }, 0, worldMem.getSize());
	}

	/**
	 * In batch mode the per-step read-backs are skipped, so this method pulls
	 * the final mm, om and association state from the device first. It then
	 * delegates to the superclass.
	 */
	@Override
	protected void onSimulationFinished() {
		if (isBatchModeEnabled()) {
			readModelStateFromDevice();
		}

		super.onSimulationFinished();
	}

	/** Copies the mm, om and association device buffers into their host arrays (blocking reads). */
	private void readModelStateFromDevice() {
		queue.blockingReadBuffer(mmMem, mmList, 0, mmMem.getSize());
		queue.blockingReadBuffer(omMem, omList, 0, omMem.getSize());
		queue.blockingReadBuffer(associationsMem, Pointer.to(associations), 0, associationsMem.getSize());
	}

	/**
	 * Releases every OpenCL object owned by this model. The order is buffers
	 * first, then kernels, program, command queue and finally the context.
	 */
	@Override
	protected void releaseImpl() {
		totalsMem.release();
		omIndexesMem.release();

		associationsMem.release();
		worldMem.release();
		mmMem.release();
		omMem.release();

		reduceKernel.release();
		liveKernel.release();
		scatterKernel.release();
		topoKernel.release();
		program.release();
		queue.release();
		context.release();
	}
}
