//------------------------------------------------------------------------------
//
//  PROGRAM: C_elem 
//
//  PURPOSE: Compute matrix product, 
//
//                C  = A * B
//
//           each work item does a dot product to compute an element of the C matrix.
//
//  HISTORY: Written by Tim Mattson, August 2010 
//
//------------------------------------------------------------------------------

#include "mult.h"

//------------------------------------------------------------------------------
//
// kernel:  mmul
//
// Purpose: Compute matrix product C = A*B, one element of C per work item.
//
// input:  A, float matrix with Ndim rows and Pdim columns
//         B, float matrix with Pdim rows and Mdim columns
//
// output: C, float matrix with Ndim rows and Mdim columns
//
// The kernel indexes all three matrices in row-major order, so the row
// stride of each matrix is its own column count (Pdim for A, Mdim for B
// and C).  Work item (i,j) computes C(i,j); work items that fall outside
// the Ndim x Mdim range do nothing.
//

const char *C_elem_KernelSource = "\n"
"__kernel void mmul(                                                    \n"
"   const int Mdim,                                                     \n"
"   const int Ndim,                                                     \n"
"   const int Pdim,                                                     \n"
"   __global float* A,                                                  \n"
"   __global float* B,                                                  \n"
"   __global float* C)                                                  \n"
"{                                                                      \n"
"   int i = get_global_id(0);                                           \n"
"   int j = get_global_id(1);                                           \n"
"   if( (i < Ndim) && (j < Mdim) )                                      \n"
"   {                                                                   \n"
"       int   k;                                                        \n"
"       float tmp = 0.0f;                                               \n"
"       for(k=0;k<Pdim;k++)                                             \n"
"           tmp += A[i*Pdim+k] * B[k*Mdim+j];                           \n"
"       C[i*Mdim+j] = tmp;                                              \n"
"   }                                                                   \n"
"}                                                                      \n"
"\n";

//------------------------------------------------------------------------------
//
// Build the C_elem program, create its "mmul" kernel, bind the kernel
// arguments and fill in the NDRange description (global, ndim).
//
// Returns SUCCESS when everything is set up, or FAILURE after printing a
// diagnostic.  OpenCL objects already stored in *program or *kernel are not
// released here when a later step fails.
//
int setup_kern_c_elem(
   cl_device_id   device_id,  // compute device id (used to fetch the build log)
   cl_context     context,    // compute context
   int           Ndim,       // number of rows in A and C
   int           Pdim,       // number of rows for B and columns for A
   int           Mdim,       // number of columns in B and C  
   cl_mem        a_in,       // device memory used for the input  a matrix
   cl_mem        b_in,       // device memory used for the input  b matrix
   cl_mem        c_out,      // device memory used for the output c matrix
   cl_program    *program,    // [out] compute program
   cl_kernel     *kernel,     // [out] compute kernel
   size_t        *global,     // [out] global domain size, two entries written
   cl_uint       *ndim)       // [out] number of dimensions in NDRange
{
    cl_int           err;       // error code returned from OpenCL calls
    cl_uint          arg;       // index of the kernel argument being set

    // Kernel arguments, listed in the order of the mmul parameter list.
    const struct {
        size_t      size;
        const void *value;
    } args[] = {
        { sizeof(int),    &Mdim  },
        { sizeof(int),    &Ndim  },
        { sizeof(int),    &Pdim  },
        { sizeof(cl_mem), &a_in  },
        { sizeof(cl_mem), &b_in  },
        { sizeof(cl_mem), &c_out },
    };

    // Create the compute program from the source buffer.  The handle that
    // was just created (*program) is what has to be tested, not the
    // caller's out-pointer.
    *program = clCreateProgramWithSource(context, 1,
                                         (const char **) &C_elem_KernelSource,
                                         NULL, &err);
    if (*program == NULL || err != CL_SUCCESS)
    {
        printf("Error: Failed to create compute program!\n");
        return FAILURE;
    }

    // Build the program
    err = clBuildProgram(*program, 0, NULL, NULL, NULL, NULL);
    if (err != CL_SUCCESS)
    {
        char buffer[2048] = "";

        printf("Error: Failed to build program executable!\n");

        // The query can itself fail (for instance when the log does not fit
        // in the buffer); only print the buffer when it was actually filled.
        if (clGetProgramBuildInfo(*program, device_id, CL_PROGRAM_BUILD_LOG,
                                  sizeof(buffer), buffer, NULL) == CL_SUCCESS)
        {
            buffer[sizeof(buffer) - 1] = '\0';
            printf("%s\n", buffer);
        }
        else
        {
            printf("(build log could not be retrieved)\n");
        }
        return FAILURE;
    }

    // Create the compute kernel from the program
    *kernel = clCreateKernel(*program, "mmul", &err);
    if (*kernel == NULL || err != CL_SUCCESS)
    {
        printf("Error: Failed to create compute kernel!\n");
        return FAILURE;
    }

    // Set the arguments to our compute kernel.  Each call is checked on its
    // own so that the reported value is a genuine OpenCL error code rather
    // than several codes OR-ed together.
    for (arg = 0; arg < (cl_uint) (sizeof(args) / sizeof(args[0])); arg++)
    {
        err = clSetKernelArg(*kernel, arg, args[arg].size, args[arg].value);
        if (err != CL_SUCCESS)
        {
            printf("Error: Failed to set kernel argument %u! (error %d)\n",
                   (unsigned) arg, (int) err);
            // err_code() is declared outside this file and its return type
            // is not visible here, so its result is not formatted; the call
            // is kept in case it reports the error itself.
            // NOTE(review): if it returns a string, print it with %s instead.
            (void) err_code(err);
            return FAILURE;
        }
    }

    // One work item per element of C: dimension 0 spans the rows (Ndim),
    // dimension 1 spans the columns (Mdim).
    global[0] = (size_t) Ndim;
    global[1] = (size_t) Mdim;
    *ndim     = 2;

    // NOTE(review): the banner prints Ndim for both dimensions — confirm
    // whether Ndim x Mdim was intended for non-square problems.
    printf("\n===== OpenCL, matrix mult, C(i,j) per work item, order %dx%d ======\n",Ndim,Ndim);
    return SUCCESS;
}

