package demos;

import ucs.*;

/** A demo using n-fold cross-validation.
 *  @author Gavin Brown and Tim Kovacs
 *
 *  This demo generates train and test sets using cross-validation. 
 *  The number of folds must be 2 or more. (With
 *  1 fold all the data ends up in the test set and there's none in
 *  the train set.) So you cannot set folds to 1 here and have it work.
 *  <p>
 *  If instead of n-fold cross-validation you want to split the data
 *  into train and test sets by percentage use the approach in BasicUCS.java
 *  <p>
 *  Note that onlinelearning must be false here: when it is true, all data
 *  is used for both the train and test sets.
 *  <p>
 *  The AUCAccuracy statistic is the Area Under the Curve for Accuracy. 
 *  It ranges from 0 (accuracy was always 0 in all runs) to 1 (accuracy was always 1).
 *  It's computed in UCS.java's test() method. Each time test() is called 
 *  the accuracy (percentage correct) on the test set is computed and AUCAccuracy 
 *  records the average of these values over the run. 
 *  The AUCAccuracy is then summed over all folds and divided by the number
 *  of folds to give the final average AUCAccuracy.
 *  <p>
 *  ****Normal cross-validation would create n folds and use each as the test set
 *  once and all others as the training set. However, this implementation
 *  re-folds the data each time a new UCS is created (within the for-loop below).
 *  This isn't cross-validation, it's just using a certain proportion of 
 *  train / test split and repeating the process n times.****
 */

public class CrossValidate11Mux {

    /**
     * Entry point of the demo. Builds one shared {@code UCSconfig}, creates and
     * runs a fresh {@code UCS} once per fold, accumulates three statistics from
     * each run, and prints their averages over the folds to standard output.
     *
     * @param args command-line arguments; not read by this demo
     */
    public static void main(String[] args) {
        // Arguments handed to UCS.run(): the run length and the interval at
        // which statistics are generated and printed.
        final int iterations = 50000;
        final int statsInterval = 500;

        // Shared configuration. The assignments are kept in this exact order
        // because setProblem() may read fields that were set before it.
        UCSconfig config = new UCSconfig();
        config.onlinelearning = false;          // separate train and test sets
        config.setProblem("./data/11mux.csv");
        config.folds = 2;
        // NOTE(review): the meaning of v and noise is defined in UCSconfig,
        // which is not visible from this file.
        config.v = 10;
        config.noise = 0.15;

        // Running totals of the per-fold statistics.
        double aucTotal = 0.0;
        double accuracyTotal = 0.0;
        double marginTotal = 0.0;

        int foldIndex = 0;
        while (foldIndex < config.folds) {
            // A new UCS is constructed from the same configuration for every fold.
            UCS learner = new UCS(config);

            // With separate train and test sets, the statistics produced
            // during the run are over the entire test set.
            learner.run(iterations, statsInterval);

            aucTotal += learner.getAUCAccuracy();
            accuracyTotal += learner.getAccuracy();
            marginTotal += learner.getAverageMargin();

            foldIndex++;
        }

        // Average each accumulated statistic over the number of folds.
        double meanAuc = aucTotal / config.folds;
        double meanAccuracy = accuracyTotal / config.folds;
        double meanMargin = marginTotal / config.folds;

        // Mean of the AUCAccuracy values reported by each fold.
        System.out.println("Average AUCAccuracy over folds: " + meanAuc);

        // Mean accuracy taken from the final test phase of each fold.
        System.out.println("Average final accuracy over folds: " + meanAccuracy);

        // Mean of the average margins from the final test phase of each fold.
        System.out.println("Average final margin over folds: " + meanMargin);
    }
}
