package mcmas.plugins.axb;

import java.util.HashMap;
import java.util.Map;
import java.util.Random;

import org.jocl.Pointer;
import org.perf4j.StopWatch;
import org.perf4j.slf4j.Slf4JStopWatch;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import mcmas.api.MCMASContext;
import mcmas.api.MCMASPlugin;
import mcmas.core.MCMCommandQueue;
import mcmas.core.MCMCommandQueueProperty;
import mcmas.core.MCMContext;
import mcmas.core.MCMEvent;
import mcmas.core.MCMKernel;
import mcmas.core.MCMMem;
import mcmas.core.MCMProgram;
import mcmas.core.MCMProgramOptions;
import mcmas.core.MCMType;
import mcmas.core.MCMUtils;

/**
 * MCMAS plugin that runs the {@code transform} and {@code clamped_transform} OpenCL kernels
 * from {@code kernels/plugins/axb.cl} over int or float vectors, writing the results back
 * into the caller's array.
 * <p>
 * The kernels presumably compute {@code a * x + b}, with the clamped variant limiting the
 * result to {@code [min, max]}. {@link #cpu(int[], float, float, int, int)} is a CPU
 * reference for the clamped int case. NOTE(review): confirm the kernels against axb.cl.
 * <p>
 * The element type is chosen at kernel compile time through the {@code TYPE} program
 * define, so each (kernel name, type) pair is compiled and cached separately.
 * <p>
 * Instances keep mutable state (the kernel cache and the stopwatches) without any
 * synchronization. Do not share an instance between threads without external locking.
 */
public class AXBPlugin extends MCMASPlugin<AXBPlugin> {

	private static final Logger logger = LoggerFactory.getLogger(AXBPlugin.class);

	/** OpenCL source that holds the {@code transform} and {@code clamped_transform} kernels. */
	private static final String SOURCE = "kernels/plugins/axb.cl";

	/** Times each public call as a whole. */
	private final StopWatch watch = new Slf4JStopWatch().setNormalPriority(Slf4JStopWatch.DEBUG_LEVEL);
	/** Times the build and kernel phases of each call. */
	private final StopWatch runWatch = new Slf4JStopWatch().setNormalPriority(Slf4JStopWatch.TRACE_LEVEL);

	/**
	 * Compiled kernels, keyed by {@code "<kernel name>:<TYPE>"}. A cache keyed only by call
	 * site would hand an int kernel to a float call (or {@code transform} to a
	 * {@code clampedTransform} call), because TYPE is fixed at compile time.
	 */
	private final Map<String, MCMKernel> kernels = new HashMap<String, MCMKernel>();

	public AXBPlugin(MCMASContext context) {
		super(context);
	}

	/**
	 * Creates a new plugin bound to {@code context}.
	 *
	 * @param context the MCMAS context the new plugin will use
	 * @return a fresh {@code AXBPlugin}; never {@code null}
	 */
	@Override
	public AXBPlugin newInstance(MCMASContext context) {
		return new AXBPlugin(context);
	}

	/**
	 * Transforms {@code vector} on the OpenCL device with the {@code transform} kernel
	 * compiled for {@code int}, then copies the result back into {@code vector}.
	 *
	 * @param vector input data; overwritten with the kernel output
	 * @param a first kernel coefficient (presumably the multiplier)
	 * @param b second kernel coefficient (presumably the offset)
	 */
	public void transform(int[] vector, float a, float b) {
		watch.start("axb_transform");
		MCMMem vectorMem = getContext().getContext().newBuffer().Using(vector).b();
		transform("int", vectorMem, Pointer.to(vector), vector.length, a, b);
		watch.stop();
	}

	/**
	 * Transforms {@code vector} on the OpenCL device with the {@code transform} kernel
	 * compiled for {@code float}, then copies the result back into {@code vector}.
	 *
	 * @param vector input data; overwritten with the kernel output
	 * @param a first kernel coefficient (presumably the multiplier)
	 * @param b second kernel coefficient (presumably the offset)
	 */
	public void transform(float[] vector, float a, float b) {
		watch.start("axb_transform");
		// Build the buffer from the array so the input is actually uploaded. A size-only
		// buffer would leave the kernel working on uninitialized device memory.
		MCMMem vectorMem = getMCMContext().newBuffer().Using(vector).b();
		transform("float", vectorMem, Pointer.to(vector), vector.length, a, b);
		watch.stop();
	}

	/**
	 * Runs the {@code transform} kernel for element type {@code type} over {@code size} work
	 * items and blocks until {@code mem} has been read back into {@code pointer}.
	 * {@code mem} and the completion event are always released, even on failure.
	 */
	private void transform(String type, MCMMem mem, Pointer pointer, int size, Object a, Object b) {
		MCMCommandQueue q = getContext().getQueue();
		MCMEvent finished = null;

		try {
			runWatch.start("axb_build");
			MCMKernel kernel = cachedKernel("transform", type);
			runWatch.stop();

			runWatch.start("axb_kernel");
			kernel.setArguments(mem, a, b);
			finished = q.enqueue1DKernel(kernel, size);
			q.blockingReadBuffer(mem, pointer, 0, mem.getSize(), finished);
			runWatch.stop();

			// Formatting the event stats is skipped when the level is disabled.
			if (logger.isDebugEnabled()) {
				logger.debug(MCMUtils.formatEventStats("axb_opencl", finished));
			}
		} finally {
			releaseResources(finished, mem);
		}
	}

	/**
	 * Clamped transform of an int vector on the OpenCL device (kernel
	 * {@code clamped_transform}, TYPE {@code int}). The result is copied back into
	 * {@code vector}.
	 *
	 * @param vector input data; overwritten with the kernel output
	 * @param a first kernel coefficient (presumably the multiplier)
	 * @param b second kernel coefficient (presumably the offset)
	 * @param min lower clamp bound passed to the kernel
	 * @param max upper clamp bound passed to the kernel
	 */
	public void clampedTransform(int[] vector, float a, int b, int min, int max) {
		watch.start("axb_clamped_transform " + vector.length);
		MCMMem vectorMem = getMCMContext().newBuffer().Using(vector).b();
		clampedTransform("int", vectorMem, Pointer.to(vector), vector.length, a, b, min, max);
		watch.stop();
	}

	/**
	 * Clamped transform of a float vector on the OpenCL device (kernel
	 * {@code clamped_transform}, TYPE {@code float}). The result is copied back into
	 * {@code vector}.
	 *
	 * @param vector input data; overwritten with the kernel output
	 * @param a first kernel coefficient (presumably the multiplier)
	 * @param b second kernel coefficient (presumably the offset)
	 * @param min lower clamp bound passed to the kernel
	 * @param max upper clamp bound passed to the kernel
	 */
	public void clampedTransform(float[] vector, float a, float b, float min, float max) {
		watch.start("axb_clamped_transform");
		MCMMem vectorMem = getMCMContext().newBuffer().Using(vector).b();
		clampedTransform("float", vectorMem, Pointer.to(vector), vector.length, a, b, min, max);
		watch.stop();
	}

	/**
	 * Runs the {@code clamped_transform} kernel for element type {@code type} over
	 * {@code size} work items and blocks until {@code mem} has been read back into
	 * {@code pointer}. {@code mem} and the completion event are always released, even on
	 * failure.
	 */
	private void clampedTransform(String type, MCMMem mem, Pointer pointer, int size, Object a, Object b,
			Object min, Object max) {
		MCMCommandQueue q = getContext().getQueue();
		MCMEvent finished = null;

		try {
			runWatch.start("axb_build");
			MCMKernel kernel = cachedKernel("clamped_transform", type);
			runWatch.stop();

			runWatch.start("axb_kernel");
			kernel.setArguments(mem, a, b, min, max);
			finished = q.enqueue1DKernel(kernel, size);
			q.blockingReadBuffer(mem, pointer, 0, mem.getSize(), finished);
			runWatch.stop();

			if (logger.isTraceEnabled()) {
				logger.trace(MCMUtils.formatEventStats("opencl_transform", finished));
			}
		} finally {
			releaseResources(finished, mem);
		}
	}

	/**
	 * Returns the kernel {@code name} compiled with the {@code TYPE} define set to
	 * {@code type}. The kernel is built and cached on first use.
	 */
	private MCMKernel cachedKernel(String name, String type) {
		String key = name + ":" + type;
		MCMKernel kernel = kernels.get(key);
		if (kernel != null) {
			return kernel;
		}

		MCMProgramOptions options = new MCMProgramOptions();
		options.define("TYPE", type);
		MCMContext c = getContext().getContext();
		MCMProgram program = c.createProgramFromFile(SOURCE, options);
		try {
			kernel = program.createKernel(name);
			kernels.put(key, kernel);
			return kernel;
		} finally {
			// Only the kernel is kept. In OpenCL a program is not freed while kernels created
			// from it still exist. NOTE(review): confirm MCMProgram/MCMKernel follow this.
			program.release();
		}
	}

	/**
	 * Releases the completion event (when one was created) and then the device buffer.
	 * The buffer is released even if releasing the event throws.
	 */
	private static void releaseResources(MCMEvent finished, MCMMem mem) {
		try {
			if (finished != null) {
				finished.release();
			}
		} finally {
			mem.release();
		}
	}

	/**
	 * CPU reference for the clamped int transform. Replaces every element {@code x} of
	 * {@code storage} with {@code (int) (x * a + b)} clamped to {@code [min, max]}.
	 *
	 * @param storage data to transform in place
	 * @param a multiplier
	 * @param b offset added after multiplication
	 * @param min lower clamp bound
	 * @param max upper clamp bound
	 */
	public void cpu(int[] storage, float a, float b, int min, int max) {
		for (int i = 0; i < storage.length; i++) {
			int value = (int) (storage[i] * a + b);
			storage[i] = Math.max(min, Math.min(max, value));
		}
	}

	/**
	 * Ad-hoc benchmark: times the OpenCL clamped transform, then the CPU reference, on a
	 * 50M-element random int vector.
	 */
	public static void main(String[] args) {
		StopWatch w = new Slf4JStopWatch().setNormalPriority(Slf4JStopWatch.DEBUG_LEVEL);

		MCMASContext context = new MCMASContext(MCMCommandQueueProperty.ENABLE_PROFILING);
		AXBPlugin plugin = new AXBPlugin(context);

		Random rng = new Random();

		int[] vector = new int[50000000];

		for (int i = 0; i < vector.length; i++) {
			vector[i] = rng.nextInt(20);
		}

		w.start("opencl");
		plugin.clampedTransform(vector, 2, 5, 0, 100);
		System.out.println(w.stop());

		// NOTE: the CPU pass runs on the OpenCL output, not on the original random data.
		w.start("java");
		plugin.cpu(vector, 2, 5, 0, 100);
		System.out.println(w.stop());
	}

}
