package mior.model; import mcmas.core.MCM; import mcmas.core.MCMCommandQueue; import mcmas.core.MCMCommandQueueProperty; import mcmas.core.MCMContext; import mcmas.core.MCMEvent; import mcmas.core.MCMKernel; import mcmas.core.MCMMem; import mcmas.core.MCMProgram; import mcmas.core.MCMUtils; import org.jocl.Pointer; import org.perf4j.slf4j.Slf4JStopWatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; public class OCLCSRModel extends SimpleMiorModel { final MCMContext context; final MCMCommandQueue queue; final MCMProgram program; final MCMKernel topoKernel; final MCMKernel scatterKernel; final MCMKernel liveKernel; final MCMKernel gatherKernel; final MCMMem mmMem; final MCMMem omMem; final MCMMem worldMem; final MCMMem mmCSRMem; final MCMMem omCSRMem; final MCMMem associationsMem; final MCMMem partsMem; private final int[] mmCSR; private final int[] omCSR; private final static Logger logger = LoggerFactory.getLogger(OCLCSRModel.class); private final Slf4JStopWatch watch = new Slf4JStopWatch(logger); private final static String MODEL_SOURCE = "kernels/mior_model_csr.cl"; public OCLCSRModel(MiorWorld world) { this(world, false); } public OCLCSRModel(MiorWorld world, boolean copyParts) { super(world); watch.start("OpenCL setup"); this.context = new MCMContext(); this.queue = context.createCommandQueue(MCMCommandQueueProperty.ENABLE_PROFILING); watch.stop(); watch.start("Program compilation"); this.program = MCMUtils.compileFile(context, MODEL_SOURCE, " -DNBMM=" + nbMM + " -DNBOM=" + nbOM + (copyParts ? " -DCOPY_PARTS" : "")); watch.stop(); watch.start("Kernel allocation"); this.topoKernel = program.createKernel("topology"); this.scatterKernel = program.createKernel("scatter"); this.liveKernel = program.createKernel("live"); this.gatherKernel = program.createKernel("gather"); watch.stop(); watch.start("Memory allocation"); this.mmMem = context.newBuffer().Using(mmList).b(); this.omMem = context.newBuffer().Using(omList).b(); this.worldMem = context.newBuffer().Using(world).b(); this.mmCSR = MiorUtils.allocateCSRStorage(nbOM, nbMM); this.mmCSRMem = context.newBuffer().Using(mmCSR).b(); this.omCSR = MiorUtils.allocateCSRStorage(nbMM, nbOM); this.omCSRMem = context.newBuffer().Using(omCSR).b(); this.associationsMem = context.newBuffer().Using(associations).b(); this.partsMem = context.newBuffer().Size(nbMM * nbOM, MCM.INT).b(); watch.stop(); } @Override protected void resetImpl() { super.resetImpl(); MiorUtils.initCSRStorage(mmCSR); MiorUtils.initCSRStorage(omCSR); queue.enqueueWriteBuffer(mmMem, mmList, 0, mmMem.getSize()); queue.enqueueWriteBuffer(omMem, omList, 0, omMem.getSize()); queue.enqueueWriteBuffer(mmCSRMem, Pointer.to(mmCSR), 0, mmCSRMem.getSize()); queue.enqueueWriteBuffer(omCSRMem, Pointer.to(omCSR), 0, omCSRMem.getSize()); queue.enqueueWriteBuffer(worldMem, new MiorWorld[] { world }, 0, worldMem.getSize()); queue.enqueueWriteBuffer(associationsMem, Pointer.to(associations), 0, associationsMem.getSize()); queue.finish(); } @Override protected void doTopologyImpl() { topoKernel.setArguments(mmMem, omMem, worldMem, mmCSRMem, omCSRMem, associationsMem); /*OCLEvent event = queue.enqueueKernel( topoKernel, 2, new long[] { 0, 0 }, new long[] { mmList.length, omList.length }, new long[] { 1, 1 } );*/ MCMEvent event = queue.enqueueKernel(topoKernel, 2, new long[] { nbMM, nbOM }); MCMEvent.waitFor(event); MCMUtils.printEventStats("topology", event); //queue.enqueueCopyBuffer(mmCSRMem, mmCSRMemConst, 0, 0, mmCSRMem.getSize()); //queue.enqueueCopyBuffer(omCSRMem, omCSRMemConst, 0, 0, omCSRMem.getSize()); if (! isBatchModeEnabled()) { queue.blockingReadBuffer(mmCSRMem, Pointer.to(mmCSR), 0, mmCSRMem.getSize()); queue.blockingReadBuffer(omCSRMem, Pointer.to(omCSR), 0, omCSRMem.getSize()); queue.blockingReadBuffer(associationsMem, Pointer.to(associations), 0, associationsMem.getSize()); } } @Override protected void doLiveImpl() { // SCATTER scatterKernel.setArguments(omMem, worldMem, mmCSRMem/*Const*/, partsMem); MCMEvent scatteredEvent = queue.enqueue1DKernel(scatterKernel, nbOM); // LIVE liveKernel.setArguments(mmMem, worldMem, omCSRMem/*Const*/, partsMem); MCMEvent livedEvent = queue.enqueue1DKernel(liveKernel, nbMM, scatteredEvent); // GATHER gatherKernel.setArguments(omMem, worldMem, mmCSRMem/*Const*/, partsMem); MCMEvent reducedEvent = queue.enqueue1DKernel(gatherKernel, nbOM, livedEvent); MCMEvent.waitFor(reducedEvent); MCMUtils.printEventStats("scatter", scatteredEvent); MCMUtils.printEventStats("live", livedEvent); MCMUtils.printEventStats("reduce", reducedEvent); if (! isBatchModeEnabled()) { queue.blockingReadBuffer(mmMem, mmList, 0, mmMem.getSize()); queue.blockingReadBuffer(omMem, omList, 0, omMem.getSize()); queue.blockingReadBuffer(associationsMem, Pointer.to(associations), 0, associationsMem.getSize()); } // Required to detect termination in autolive queue.blockingReadBuffer(worldMem, new MiorWorld[] {world}, 0, worldMem.getSize()); } @Override public void releaseImpl() { partsMem.release(); associationsMem.release(); omCSRMem.release(); mmCSRMem.release(); worldMem.release(); omMem.release(); mmMem.release(); gatherKernel.release(); liveKernel.release(); scatterKernel.release(); topoKernel.release(); program.release(); queue.release(); context.release(); } @Override protected void onSimulationFinished() { if (isBatchModeEnabled()) { queue.blockingReadBuffer(mmMem, mmList, 0, mmMem.getSize()); queue.blockingReadBuffer(omMem, omList, 0, omMem.getSize()); queue.blockingReadBuffer(associationsMem, Pointer.to(associations), 0, associationsMem.getSize()); } super.onSimulationFinished(); } }