package mcmas.core;

import mcmas.utils.Objects7;

import org.jocl.CL;
import org.jocl.CLException;
import org.jocl.Pointer;
import org.jocl.Sizeof;
import org.jocl.cl_kernel;
import org.jocl.cl_mem;

/**
 * MCM representation of an OpenCL kernel.
 */
public class MCMKernel extends MCMObject {
	
	/** Native JOCL kernel handle wrapped by this object. */
	private final cl_kernel kernel;
	
	/**
	 * Create a new MCMKernel from the specified program and function name.
	 * @param program compiled program containing the kernel (must not be null;
	 *        checked through {@code Objects7.requireNonNull})
	 * @param name name of the function to use as kernel
	 */
	public MCMKernel(MCMProgram program, String name) {
		program = Objects7.requireNonNull(program, "Can't create a kernel from a null program");
		this.kernel = MCMHelpers.createKernel(program.getProgram(), name);
	}
	
	/**
	 * Retrieve the native cl_kernel object for interoperability with JOCL.
	 * @return the wrapped native kernel handle
	 */
	public cl_kernel getKernel() {
		return kernel;
	}
	
	/**
	 * Set the arguments for kernel call in one pass.
	 * <p>
	 * Argument {@code i} of the kernel is bound to {@code params[i]} using
	 * {@link #setArgument(int, Object)}, in increasing index order.
	 * @param params parameters to use (MCMMem objects or primitive types)
	 * @throws CLException if one of the parameters is null or of an unsupported type
	 */
	public void setArguments(Object... params) {
		// Map arguments to function
		for (int i = 0; i < params.length; i++) {
			setArgument(i, params[i]);
		}
	}
	
	/**
	 * Set the argument at a particular position.
	 * <p>
	 * Supported argument types and the size passed to {@code clSetKernelArg}:
	 * <ul>
	 * <li>{@link Pointer}: used as-is, size {@code Sizeof.cl_mem}</li>
	 * <li>{@code MCMMem}: its {@code cl_mem}, or a NULL value for local buffers,
	 *     size {@code Sizeof.cl_mem}</li>
	 * <li>{@link cl_mem}: size {@code Sizeof.cl_mem}</li>
	 * <li>{@link Byte} ({@code char}), {@link Short}, {@link Integer}, {@link Long},
	 *     {@link Float}, {@link Double}: size of the matching OpenCL scalar type</li>
	 * </ul>
	 * @param index index of the argument to set
	 * @param param value for the argument (MCMMem object or primitive type)
	 * @throws CLException if {@code param} is null or of an unsupported type
	 */
	public void setArgument(int index, Object param) {
		Pointer p = null;
		long size;
		
		if (param == null) {
			throw new CLException("NULL cannot be passed to a kernel");
		} else if (param instanceof Pointer) {
			// A raw Pointer is assumed to reference a memory object handle.
			p = (Pointer) param;
			size = Sizeof.cl_mem;
		} else if (param instanceof MCMMem) {
			MCMMem buffer = (MCMMem) param;
			// Local buffers are bound with a NULL value.
			// NOTE(review): per the OpenCL spec, arg_size for a __local argument is
			// the number of bytes to allocate; Sizeof.cl_mem is used here — confirm
			// whether MCMMem exposes the intended local size.
			p = (buffer.isLocal() ? null : Pointer.to(buffer.getMem()));
			size = Sizeof.cl_mem;
		} else if (param instanceof cl_mem) {
			p = Pointer.to((cl_mem) param);
			size = Sizeof.cl_mem;
		} else if (param instanceof Byte) {
			// OpenCL 'char' is 8 bits. (Sizeof.cl_int8 is the int8 *vector* type, 32 bytes.)
			p = Pointer.to(new byte[] { (Byte) param });
			size = Sizeof.cl_char;
		} else if (param instanceof Short) {
			p = Pointer.to(new short[] { (Short) param });
			size = Sizeof.cl_short;
		} else if (param instanceof Integer) {
			p = Pointer.to(new int[] { (Integer) param });
			size = Sizeof.cl_int;
		} else if (param instanceof Long) {
			p = Pointer.to(new long[] { (Long) param });
			size = Sizeof.cl_long;
		} else if (param instanceof Float) {
			p = Pointer.to(new float[] { (Float) param });
			size = Sizeof.cl_float;
		} else if (param instanceof Double) {
			p = Pointer.to(new double[] { (Double) param });
			size = Sizeof.cl_double;
		} else {
			throw new CLException("Type " + param.getClass()
					+ " may not be passed to a function");
		}
		
		CL.clSetKernelArg(kernel, index, size, p);
	}
	
	/**
	 * Retrieve the number of arguments expected by the kernel.
	 * @return the number of arguments required (queried via CL_KERNEL_NUM_ARGS)
	 */
	public int getArgumentsNumber() {
		int[] size = new int[1];
		CL.clGetKernelInfo(kernel, CL.CL_KERNEL_NUM_ARGS, Sizeof.cl_uint, Pointer.to(size), null);
		return size[0];
	}
	
	/**
	 * Release the native memory associated with this object.
	 */
	@Override
	protected void releaseImpl() {
		CL.clReleaseKernel(kernel);
	}
	
}
