// OCLCSRModel.java 5.84 KB
package mior.model;

import mcmas.core.MCM;
import mcmas.core.MCMCommandQueue;
import mcmas.core.MCMCommandQueueProperty;
import mcmas.core.MCMContext;
import mcmas.core.MCMEvent;
import mcmas.core.MCMKernel;
import mcmas.core.MCMMem;
import mcmas.core.MCMProgram;
import mcmas.core.MCMUtils;

import org.jocl.Pointer;
import org.perf4j.slf4j.Slf4JStopWatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * MIOR model that runs its simulation steps through OpenCL kernels compiled
 * from {@code kernels/mior_model_csr.cl}.
 *
 * <p>Four kernels are used: {@code topology}, {@code scatter}, {@code live}
 * and {@code gather}. The topology kernel receives two index arrays obtained
 * from {@link MiorUtils#allocateCSRStorage} (one per agent population) in
 * addition to the association buffer; the scatter/live/gather kernels then
 * work from those index buffers and a shared {@code parts} buffer.
 *
 * <p>When batch mode is enabled, the host-side agent arrays are not refreshed
 * after every step; they are read back once in {@link #onSimulationFinished()}.
 */
public class OCLCSRModel extends SimpleMiorModel {

	private static final Logger logger = LoggerFactory.getLogger(OCLCSRModel.class);

	/** Path of the OpenCL program handed to {@code MCMUtils.compileFile}. */
	private static final String MODEL_SOURCE = "kernels/mior_model_csr.cl";

	// OpenCL runtime objects
	final MCMContext context;
	final MCMCommandQueue queue;
	final MCMProgram program;

	// One kernel per simulation phase
	final MCMKernel topoKernel;
	final MCMKernel scatterKernel;
	final MCMKernel liveKernel;
	final MCMKernel gatherKernel;

	// Device buffers backed by the agent lists and the world description
	final MCMMem mmMem;
	final MCMMem omMem;
	final MCMMem worldMem;

	// Device buffers backed by the index arrays and the association table
	final MCMMem mmCSRMem;
	final MCMMem omCSRMem;
	final MCMMem associationsMem;

	// Device-only scratch buffer shared by scatter, live and gather
	final MCMMem partsMem;

	// Host-side copies of the two index arrays
	private final int[] mmCSR;
	private final int[] omCSR;

	/** Stopwatch used to time the individual setup stages of the constructor. */
	private final Slf4JStopWatch watch = new Slf4JStopWatch(logger);

	/**
	 * Creates a model for {@code world} without the {@code COPY_PARTS}
	 * compile-time flag.
	 *
	 * @param world the world this model simulates
	 */
	public OCLCSRModel(MiorWorld world) {
		this(world, false);
	}

	/**
	 * Creates a model for {@code world}, compiling the OpenCL program and
	 * allocating every kernel and buffer used by the simulation.
	 *
	 * @param world     the world this model simulates
	 * @param copyParts when {@code true}, {@code -DCOPY_PARTS} is appended to
	 *                  the OpenCL build options
	 */
	public OCLCSRModel(MiorWorld world, boolean copyParts) {
		super(world);

		// --- context and profiling-enabled command queue ---
		watch.start("OpenCL setup");
		this.context = new MCMContext();
		this.queue = this.context.createCommandQueue(MCMCommandQueueProperty.ENABLE_PROFILING);
		watch.stop();

		// --- program build; population sizes are baked in as macros ---
		watch.start("Program compilation");
		String buildOptions = " -DNBMM=" + nbMM + " -DNBOM=" + nbOM;
		if (copyParts) {
			buildOptions += " -DCOPY_PARTS";
		}
		this.program = MCMUtils.compileFile(this.context, MODEL_SOURCE, buildOptions);
		watch.stop();

		// --- kernels ---
		watch.start("Kernel allocation");
		this.topoKernel = this.program.createKernel("topology");
		this.scatterKernel = this.program.createKernel("scatter");
		this.liveKernel = this.program.createKernel("live");
		this.gatherKernel = this.program.createKernel("gather");
		watch.stop();

		// --- device buffers ---
		watch.start("Memory allocation");
		this.mmMem = this.context.newBuffer().Using(mmList).b();
		this.omMem = this.context.newBuffer().Using(omList).b();
		this.worldMem = this.context.newBuffer().Using(world).b();

		this.mmCSR = MiorUtils.allocateCSRStorage(nbOM, nbMM);
		this.mmCSRMem = this.context.newBuffer().Using(this.mmCSR).b();

		this.omCSR = MiorUtils.allocateCSRStorage(nbMM, nbOM);
		this.omCSRMem = this.context.newBuffer().Using(this.omCSR).b();

		this.associationsMem = this.context.newBuffer().Using(associations).b();

		// NOTE(review): the product is computed before being handed to Size();
		// if nbMM and nbOM are ints it could overflow for large populations — confirm.
		this.partsMem = this.context.newBuffer().Size(nbMM * nbOM, MCM.INT).b();
		watch.stop();
	}

	/**
	 * Resets the host-side state, then uploads it again to every device buffer
	 * that is backed by host data. Returns once the queue has finished.
	 */
	@Override
	protected void resetImpl() {
		super.resetImpl();

		MiorUtils.initCSRStorage(mmCSR);
		MiorUtils.initCSRStorage(omCSR);

		final MiorWorld[] worldHolder = { world };

		queue.enqueueWriteBuffer(mmMem, mmList, 0, mmMem.getSize());
		queue.enqueueWriteBuffer(omMem, omList, 0, omMem.getSize());
		queue.enqueueWriteBuffer(mmCSRMem, Pointer.to(mmCSR), 0, mmCSRMem.getSize());
		queue.enqueueWriteBuffer(omCSRMem, Pointer.to(omCSR), 0, omCSRMem.getSize());
		queue.enqueueWriteBuffer(worldMem, worldHolder, 0, worldMem.getSize());
		queue.enqueueWriteBuffer(associationsMem, Pointer.to(associations), 0, associationsMem.getSize());

		// The writes above are only enqueued; wait for all of them here.
		queue.finish();
	}

	/**
	 * Runs the {@code topology} kernel over a 2D range of
	 * {@code nbMM x nbOM} work-items and waits for it to complete. Outside
	 * batch mode, the index arrays and the association table are then copied
	 * back to the host.
	 */
	@Override
	protected void doTopologyImpl() {
		topoKernel.setArguments(mmMem, omMem, worldMem, mmCSRMem, omCSRMem, associationsMem);

		final long[] globalSize = { nbMM, nbOM };
		final MCMEvent topologyDone = queue.enqueueKernel(topoKernel, 2, globalSize);

		MCMEvent.waitFor(topologyDone);
		MCMUtils.printEventStats("topology", topologyDone);

		if (isBatchModeEnabled()) {
			// Batch mode: host copies are left untouched after this step.
			return;
		}

		queue.blockingReadBuffer(mmCSRMem, Pointer.to(mmCSR), 0, mmCSRMem.getSize());
		queue.blockingReadBuffer(omCSRMem, Pointer.to(omCSR), 0, omCSRMem.getSize());
		queue.blockingReadBuffer(associationsMem, Pointer.to(associations), 0, associationsMem.getSize());
	}

	/**
	 * Runs one simulation step as three chained 1D kernels: {@code scatter}
	 * (one work-item per OM), {@code live} (one per MM, after scatter) and
	 * {@code gather} (one per OM, after live). Outside batch mode the agent
	 * arrays and association table are read back; the world is always read back.
	 */
	@Override
	protected void doLiveImpl() {
		// Phase 1: scatter
		scatterKernel.setArguments(omMem, worldMem, mmCSRMem, partsMem);
		final MCMEvent scatterDone = queue.enqueue1DKernel(scatterKernel, nbOM);

		// Phase 2: live, chained on scatter
		liveKernel.setArguments(mmMem, worldMem, omCSRMem, partsMem);
		final MCMEvent liveDone = queue.enqueue1DKernel(liveKernel, nbMM, scatterDone);

		// Phase 3: gather, chained on live
		gatherKernel.setArguments(omMem, worldMem, mmCSRMem, partsMem);
		final MCMEvent gatherDone = queue.enqueue1DKernel(gatherKernel, nbOM, liveDone);

		// Waiting on the last event of the chain covers all three kernels.
		MCMEvent.waitFor(gatherDone);

		MCMUtils.printEventStats("scatter", scatterDone);
		MCMUtils.printEventStats("live", liveDone);
		MCMUtils.printEventStats("reduce", gatherDone);

		if (!isBatchModeEnabled()) {
			readBackAgentState();
		}

		// Required to detect termination in autolive
		final MiorWorld[] worldHolder = { world };
		queue.blockingReadBuffer(worldMem, worldHolder, 0, worldMem.getSize());
	}

	/**
	 * Releases every OpenCL object owned by this model: buffers first, then
	 * kernels, then the program, the queue and finally the context.
	 */
	@Override
	public void releaseImpl() {
		final MCMMem[] buffers = {
			partsMem,
			associationsMem, omCSRMem, mmCSRMem,
			worldMem, omMem, mmMem
		};
		for (MCMMem buffer : buffers) {
			buffer.release();
		}

		final MCMKernel[] kernels = { gatherKernel, liveKernel, scatterKernel, topoKernel };
		for (MCMKernel kernel : kernels) {
			kernel.release();
		}

		program.release();
		queue.release();
		context.release();
	}

	/**
	 * In batch mode, fetches the agent arrays and association table from the
	 * device (they were not refreshed during the run), then delegates to the
	 * superclass.
	 */
	@Override
	protected void onSimulationFinished() {
		if (isBatchModeEnabled()) {
			readBackAgentState();
		}

		super.onSimulationFinished();
	}

	/**
	 * Copies the MM list, the OM list and the association table from their
	 * device buffers into the corresponding host arrays, blocking on each read.
	 */
	private void readBackAgentState() {
		queue.blockingReadBuffer(mmMem, mmList, 0, mmMem.getSize());
		queue.blockingReadBuffer(omMem, omList, 0, omMem.getSize());
		queue.blockingReadBuffer(associationsMem, Pointer.to(associations), 0, associationsMem.getSize());
	}

}