//
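//  PROGRAM: Matrix Multiply (loop-order comparison)
//
//  PURPOSE: This matrix multiply program times three orderings of
//           the triple loop (ijk, ikj, kij) so that their use of
//           the caches can be compared.  No tiling is implemented
//           in this file yet.  Each variant computes the product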
//
//                C  = A * B
//
//           A and B are set to constant matrices so we
//           can make a quick test of the multiplication.
//
//  USAGE:   Right now, I hardwire the matrix dimensions.
//           Later, I'll take them from the command line.
//
//  HISTORY: Written by Tim Mattson, Nov 1999.
//
#include <stdlib.h>
#include <stdio.h>

/* Matrix order: main() assigns this value to all three dimensions (N, P, M). */
#define ORDER 1000
/* Constant that initmat() stores into the elements of A. */
#define AVAL 3.0
/* Constant that initmat() stores into the elements of B. */
#define BVAL 5.0
/* Largest summed squared error that main() still reports as a success. */
#define TOL  0.001

/* Timer defined outside this file; main() subtracts two readings and prints
   the difference as seconds. */
extern double wtime();

/*
 * mat_mul_ijk: accumulate the product A * B into C using the i-j-k
 * (dot-product) loop order.
 *
 * All matrices are dense, row-major arrays:
 *   A is Ndim x Pdim, B is Pdim x Mdim, C is Ndim x Mdim.
 *
 * The product is ADDED to whatever C already holds, so the caller must
 * zero C beforehand to obtain the plain product.  Nothing is done when
 * any dimension is zero or negative.
 */
void mat_mul_ijk(int Mdim, int Ndim, int Pdim, double *A, double *B, double *C)
{
    int i, j, k;

    if (Mdim <= 0 || Ndim <= 0 || Pdim <= 0)
        return;

    for (i = 0; i < Ndim; i++) {
        /* The row stride of a row-major matrix is its column count:
           Pdim for A, Mdim for B and C.  Offsets are computed in size_t
           so large orders cannot overflow int arithmetic. */
        const double *a_row = A + (size_t)i * (size_t)Pdim;
        double *c_row = C + (size_t)i * (size_t)Mdim;

        for (j = 0; j < Mdim; j++) {
            /* C(i,j) += sum(over k) A(i,k) * B(k,j); start from the current
               C(i,j) so the summation order matches an in-place update. */
            double sum = c_row[j];

            for (k = 0; k < Pdim; k++) {
                sum += a_row[k] * B[(size_t)k * (size_t)Mdim + (size_t)j];
            }
            c_row[j] = sum;
        }
    }
}

/*
 * mat_mul_ikj: accumulate the product A * B into C using the i-k-j loop
 * order, so the innermost loop walks one row of B and one row of C with
 * unit stride.
 *
 * All matrices are dense, row-major arrays:
 *   A is Ndim x Pdim, B is Pdim x Mdim, C is Ndim x Mdim.
 *
 * The product is ADDED to whatever C already holds, so the caller must
 * zero C beforehand to obtain the plain product.  Nothing is done when
 * any dimension is zero or negative.
 */
void mat_mul_ikj(int Mdim, int Ndim, int Pdim, double *A, double *B, double *C)
{
    int i, j, k;

    if (Mdim <= 0 || Ndim <= 0 || Pdim <= 0)
        return;

    for (i = 0; i < Ndim; i++) {
        /* Row strides are the column counts: Pdim for A, Mdim for B and C. */
        const double *a_row = A + (size_t)i * (size_t)Pdim;
        double *c_row = C + (size_t)i * (size_t)Mdim;

        for (k = 0; k < Pdim; k++) {
            /* A(i,k) does not depend on j, so read it once per k. */
            const double a_ik = a_row[k];
            const double *b_row = B + (size_t)k * (size_t)Mdim;

            for (j = 0; j < Mdim; j++) {
                /* C(i,j) += A(i,k) * B(k,j) */
                c_row[j] += a_ik * b_row[j];
            }
        }
    }
}

/*
 * mat_mul_kij: accumulate the product A * B into C using the k-i-j loop
 * order: for each k, row k of B scaled by A(i,k) is added to row i of C.
 *
 * All matrices are dense, row-major arrays:
 *   A is Ndim x Pdim, B is Pdim x Mdim, C is Ndim x Mdim.
 *
 * The product is ADDED to whatever C already holds, so the caller must
 * zero C beforehand to obtain the plain product.  Nothing is done when
 * any dimension is zero or negative.
 */
void mat_mul_kij(int Mdim, int Ndim, int Pdim, double *A, double *B, double *C)
{
    int i, j, k;

    if (Mdim <= 0 || Ndim <= 0 || Pdim <= 0)
        return;

    for (k = 0; k < Pdim; k++) {
        /* Row stride of B (and of C below) is Mdim, its column count. */
        const double *b_row = B + (size_t)k * (size_t)Mdim;

        for (i = 0; i < Ndim; i++) {
            /* Row stride of A is Pdim; A(i,k) is invariant in the j loop. */
            const double a_ik = A[(size_t)i * (size_t)Pdim + (size_t)k];
            double *c_row = C + (size_t)i * (size_t)Mdim;

            for (j = 0; j < Mdim; j++) {
                /* C(i,j) += A(i,k) * B(k,j) */
                c_row[j] += a_ik * b_row[j];
            }
        }
    }
}

/*
 * initmat: reset the three matrices before a multiplication run.
 *
 *   A (Ndim x Pdim) is filled with AVAL,
 *   B (Pdim x Mdim) is filled with BVAL,
 *   C (Ndim x Mdim) is filled with 0.0.
 *
 * All matrices are dense, row-major arrays, so the row stride of each one
 * is its own column count (Pdim for A, Mdim for B and C).
 */
void initmat(int Mdim, int Ndim, int Pdim, double *A, double *B, double *C)
{
    int i, j;

    /* Indices are formed in size_t so large orders cannot overflow int. */
    for (i = 0; i < Ndim; i++) {
        for (j = 0; j < Pdim; j++) {
            A[(size_t)i * (size_t)Pdim + (size_t)j] = AVAL;
        }
    }

    for (i = 0; i < Pdim; i++) {
        for (j = 0; j < Mdim; j++) {
            B[(size_t)i * (size_t)Mdim + (size_t)j] = BVAL;
        }
    }

    for (i = 0; i < Ndim; i++) {
        for (j = 0; j < Mdim; j++) {
            C[(size_t)i * (size_t)Mdim + (size_t)j] = 0.0;
        }
    }
}

/*
 * error: return the sum of squared differences between every element of
 * C and the value expected when A is filled with AVAL and B with BVAL.
 *
 * Each element of the product is then a sum of Pdim terms AVAL * BVAL,
 * i.e. Pdim * AVAL * BVAL.  C is a dense, row-major Ndim x Mdim array,
 * so its row stride is Mdim.
 */
double error(int Mdim, int Ndim, int Pdim, double *C)
{
    int i, j;
    const double cval = Pdim * AVAL * BVAL;
    double errsq = 0.0;

    for (i = 0; i < Ndim; i++) {
        for (j = 0; j < Mdim; j++) {
            const double err = C[(size_t)i * (size_t)Mdim + (size_t)j] - cval;

            errsq += err * err;
        }
    }
    return errsq;
}


int main(int argc, char **argv)
{
	int Ndim, Pdim, Mdim;   /* A[N][P], B[P][M], C[N][M] */
	int i,j,k;
	double *A, *B, *C, cval, tmp, err, errsq;
        double dN, mflops;
	double start_time, run_time_seq;


	Ndim = ORDER;
	Pdim = ORDER;
	Mdim = ORDER;

   	A = (double *)malloc(Ndim*Pdim*sizeof(double));
        B = (double *)malloc(Pdim*Mdim*sizeof(double));
        C = (double *)malloc(Ndim*Mdim*sizeof(double));
 
	initmat(Mdim, Ndim, Pdim, A, B, C);

/* Do the matrix product */

	start_time = wtime(); 

        mat_mul_ijk(Mdim, Ndim, Pdim, A, B, C);        

	run_time_seq  = wtime() - start_time;
        dN = (double)ORDER;
        mflops = 2.0 * dN * dN * dN/(1000000.0* run_time_seq);
	printf(" \n Order %d ijk (dot prod) mat mult in %.2f seconds ", ORDER, run_time_seq);
	printf(" at %.1f mflops\n", mflops);

       /* Check the answer */
        errsq = error(Mdim, Ndim, Pdim, C);

	if (errsq > TOL) 
		printf("\n Errors in multiplication: %f",errsq);
	else
		printf("\n Hey, it worked");

/* Do the second  matrix product */
	initmat(Mdim, Ndim, Pdim, A, B, C);

	start_time = wtime(); 

        mat_mul_ikj(Mdim, Ndim, Pdim, A, B, C);        

	run_time_seq  = wtime() - start_time;
        dN = (double)ORDER;
        mflops = 2.0 * dN * dN * dN/(1000000.0* run_time_seq);
	printf(" \n Order %d ikj mat mult in %.2f seconds ", ORDER, run_time_seq);
	printf(" at %.1f mflops\n", mflops);

        /* Check the answer */
        errsq = error(Mdim, Ndim, Pdim, C);

	if (errsq > TOL) 
		printf("\n Errors in multiplication: %f",errsq);
	else
		printf("\n Hey, it worked");


/* Do the third matrix product */
	initmat(Mdim, Ndim, Pdim, A, B, C);
	start_time = wtime(); 

        mat_mul_kij(Mdim, Ndim, Pdim, A, B, C);        

	run_time_seq  = wtime() - start_time;
        dN = (double)ORDER;
        mflops = 2.0 * dN * dN * dN/(1000000.0* run_time_seq);
	printf(" \n Order %d kij mat mult in %.2f seconds ", ORDER, run_time_seq);
	printf(" at %.1f mflops\n", mflops);

        /* Check the answer */
        errsq = error(Mdim, Ndim, Pdim, C);

	if (errsq > TOL) 
		printf("\n Errors in multiplication: %f",errsq);
	else
		printf("\n Hey, it worked");


	printf("\n all done \n");

}
