package mior.model;

import mcmas.core.MCM;
import mcmas.core.MCMCommandQueue;
import mcmas.core.MCMCommandQueueProperty;
import mcmas.core.MCMContext;
import mcmas.core.MCMEvent;
import mcmas.core.MCMKernel;
import mcmas.core.MCMMem;
import mcmas.core.MCMProgram;
import mcmas.core.MCMUtils;

import org.jocl.Pointer;
import org.perf4j.slf4j.Slf4JStopWatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * MIOR model whose topology and autolive steps are executed by the
 * {@code topology} and {@code autolive} kernels compiled from
 * {@code kernels/mior_model_csr.cl}.
 *
 * <p>The instance owns a context, a command queue, a program, two kernels and
 * seven device buffers; all of them are released by {@link #releaseImpl()}.
 */
public class OCLCSRModel2 extends SimpleMiorModel {

	private static final Logger logger = LoggerFactory.getLogger(OCLCSRModel2.class);

	/** Path of the kernel source passed to {@code MCMUtils.compileFile}. */
	private static final String MODEL_SOURCE = "kernels/mior_model_csr.cl";

	/** Stop watch used to time the individual construction phases. */
	private final Slf4JStopWatch watch = new Slf4JStopWatch(logger);

	// --- OpenCL objects ---------------------------------------------------

	final MCMContext context;
	final MCMCommandQueue queue;
	final MCMProgram program;

	final MCMKernel topoKernel;
	final MCMKernel autoliveKernel;

	// --- Device buffers backed by the model's host-side data ---------------

	final MCMMem mmMem;
	final MCMMem omMem;
	final MCMMem worldMem;

	final MCMMem mmCSRMem;
	final MCMMem omCSRMem;
	final MCMMem associationsMem;

	/** Device-only buffer of {@code nbMM * nbOM} ints handed to the autolive kernel. */
	final MCMMem partsMem;

	// --- Host-side CSR storage mirrored by mmCSRMem / omCSRMem -------------

	private final int[] mmCSR;
	private final int[] omCSR;

	/**
	 * Creates the OpenCL context and queue, compiles the kernels and allocates
	 * every device buffer. Each phase is timed with the instance stop watch.
	 *
	 * @param world the world forwarded to the superclass and used to back
	 *              {@code worldMem}
	 */
	public OCLCSRModel2(MiorWorld world) {
		super(world);

		watch.start("OpenCL setup");
		context = new MCMContext();
		queue = context.createCommandQueue(MCMCommandQueueProperty.ENABLE_PROFILING);
		watch.stop();

		watch.start("Program compilation");
		// The agent counts are passed to the kernel source as preprocessor defines.
		final String buildOptions = " -DNBMM=" + nbMM + " -DNBOM=" + nbOM;
		program = MCMUtils.compileFile(context, MODEL_SOURCE, buildOptions);
		watch.stop();

		watch.start("Kernel allocation");
		topoKernel = program.createKernel("topology");
		autoliveKernel = program.createKernel("autolive");
		watch.stop();

		watch.start("Memory allocation");
		mmMem = context.newBuffer().Using(mmList).b();
		omMem = context.newBuffer().Using(omList).b();
		worldMem = context.newBuffer().Using(world).b();

		mmCSR = MiorUtils.allocateCSRStorage(nbOM, nbMM);
		mmCSRMem = context.newBuffer().Using(mmCSR).b();

		omCSR = MiorUtils.allocateCSRStorage(nbMM, nbOM);
		omCSRMem = context.newBuffer().Using(omCSR).b();

		associationsMem = context.newBuffer().Using(associations).b();

		partsMem = context.newBuffer().Size(nbMM * nbOM, MCM.INT).b();
		watch.stop();
	}

	/**
	 * Resets the host-side state (superclass reset plus re-initialisation of
	 * both CSR arrays) and then uploads it to the device.
	 */
	@Override
	protected void resetImpl() {
		super.resetImpl();

		MiorUtils.initCSRStorage(mmCSR);
		MiorUtils.initCSRStorage(omCSR);

		pushStateToDevice();
	}

	/**
	 * Enqueues a write of every host-backed buffer and waits for the queue to
	 * drain before returning.
	 */
	private void pushStateToDevice() {
		final MiorWorld[] worldHolder = { world };

		queue.enqueueWriteBuffer(mmMem, mmList, 0, mmMem.getSize());
		queue.enqueueWriteBuffer(omMem, omList, 0, omMem.getSize());
		queue.enqueueWriteBuffer(mmCSRMem, Pointer.to(mmCSR), 0, mmCSRMem.getSize());
		queue.enqueueWriteBuffer(omCSRMem, Pointer.to(omCSR), 0, omCSRMem.getSize());
		queue.enqueueWriteBuffer(worldMem, worldHolder, 0, worldMem.getSize());
		queue.enqueueWriteBuffer(associationsMem, Pointer.to(associations), 0, associationsMem.getSize());
		queue.finish();
	}

	/**
	 * Runs the {@code topology} kernel over a two-dimensional
	 * {@code nbMM x nbOM} range, waits for it and prints its event statistics.
	 * Outside batch mode the CSR arrays and the associations are then read
	 * back into host memory.
	 */
	@Override
	protected void doTopologyImpl() {
		topoKernel.setArguments(mmMem, omMem, worldMem, mmCSRMem, omCSRMem, associationsMem);

		final long[] globalSize = { nbMM, nbOM };
		final MCMEvent topologyDone = queue.enqueueKernel(topoKernel, 2, globalSize);

		MCMEvent.waitFor(topologyDone);
		MCMUtils.printEventStats("topology", topologyDone);

		if (isBatchModeEnabled()) {
			// Batch mode leaves the topology on the device at this point.
			return;
		}

		queue.blockingReadBuffer(mmCSRMem, Pointer.to(mmCSR), 0, mmCSRMem.getSize());
		queue.blockingReadBuffer(omCSRMem, Pointer.to(omCSR), 0, omCSRMem.getSize());
		queue.blockingReadBuffer(associationsMem, Pointer.to(associations), 0, associationsMem.getSize());
	}

	/**
	 * Not supported by this model.
	 *
	 * @throws UnsupportedOperationException always
	 */
	@Override
	protected void doLiveImpl() {
		throw new UnsupportedOperationException();
	}

	/**
	 * Releases every OpenCL object held by this instance: buffers first, then
	 * kernels, then the program, the queue and finally the context.
	 */
	@Override
	public void releaseImpl() {
		// Buffers
		partsMem.release();

		associationsMem.release();
		omCSRMem.release();
		mmCSRMem.release();

		worldMem.release();
		omMem.release();
		mmMem.release();

		// Kernels
		autoliveKernel.release();
		topoKernel.release();

		// Program, queue, context
		program.release();
		queue.release();
		context.release();
	}

	/**
	 * Resets the model, computes the topology, then runs the {@code autolive}
	 * kernel over {@code max(nbOM, nbMM)} work items, waits for it, prints its
	 * event statistics and finally calls {@link #onSimulationFinished()}.
	 */
	@Override
	public void doAutoLive() {
		reset();
		doTopology();

		autoliveKernel.setArguments(mmMem, omMem, worldMem, mmCSRMem, omCSRMem, partsMem);
		final MCMEvent autoliveDone = queue.enqueue1DKernel(autoliveKernel, Math.max(nbOM, nbMM));

		MCMEvent.waitFor(autoliveDone);
		MCMUtils.printEventStats("autolive", autoliveDone);

		onSimulationFinished();
	}

	/**
	 * In batch mode, reads the simulation state back from the device before
	 * delegating to the superclass hook.
	 */
	@Override
	protected void onSimulationFinished() {
		if (isBatchModeEnabled()) {
			pullSimulationStateFromDevice();
		}

		super.onSimulationFinished();
	}

	/**
	 * Performs blocking reads of the MM list, the OM list, the world and the
	 * associations from their device buffers.
	 */
	private void pullSimulationStateFromDevice() {
		final MiorWorld[] worldHolder = { world };

		queue.blockingReadBuffer(mmMem, mmList, 0, mmMem.getSize());
		queue.blockingReadBuffer(omMem, omList, 0, omMem.getSize());
		queue.blockingReadBuffer(worldMem, worldHolder, 0, worldMem.getSize());
		queue.blockingReadBuffer(associationsMem, Pointer.to(associations), 0, associationsMem.getSize());
	}

}
