
package ucs;

import java.util.ArrayList;
import java.util.Iterator;

import dataprocessing.*;

/**
 * Main UCS class. Make an instance of this (providing a <b>UCSconfig</b> object)
 * and call <i>run(int)</i> or <i>run(int, int)</i> to use UCS on your data.
 * <b>Note: this implementation (version 0.1b) can handle only DISCRETE inputs (any arity), and only TWO CLASS problems.</b>
 * 
 * @author Gavin Brown
 */
public class UCS {

	/** Number of most recent training examples kept for evaluation when online learning is enabled. */
	private static final int ONLINE_WINDOW = 50;

	/** Number of times the GA has been invoked by this instance. */
	protected long GA_counter = 0;

	/**
	 * Flag indicating whether UCS is in testing mode at the moment or not.
	 * Static, hence shared by every UCS instance in the JVM.
	 */
	protected static boolean TESTMODE = false;

	/**
	 * Counter indicating the current number of iterations (aka generations) processed.
	 * Static, hence shared by every UCS instance in the JVM.
	 */
	protected static int currentIteration;

	/** Mean of the accuracies produced by all calls to test() during the current run. */
	protected double aucAccuracy;

	/** Running sum of test accuracies; used to compute aucAccuracy. */
	protected double sumSystemCorrect;

	/** Counts the number of times test statistics have been generated during this run. */
	protected int numTests;

	/** Stores the accuracy from the most recent call to test(). */
	protected double accuracy;

	/** Stores the average voting margin from the most recent call to test(). */
	protected double averageMargin;

	/** Training and testing examples (elements are cast to Example when read). */
	protected ArrayList traindata = null, testdata = null;

	/** Recently sampled training examples; the evaluation set when online learning is enabled. */
	protected ArrayList lastFewExamples = null;

	/** Rules matching the example currently being processed (elements are Indiv). */
	protected ArrayList matchSet = null;

	/** Subset of the match set that passes Indiv.test() for the current example's target. */
	protected ArrayList correctSet = null;

	/** The population in this instance of UCS */
	public Population pop = new Population();

	/** The configuration for this instance of UCS */
	public UCSconfig params = null;

	/** @return the mean accuracy over all tests performed during the current run. */
	public double getAUCAccuracy() { return aucAccuracy; }

	/** @return the accuracy measured by the most recent test. */
	public double getAccuracy() { return accuracy; }

	/** @return the average voting margin measured by the most recent test. */
	public double getAverageMargin() { return averageMargin; }

	/**
	 * Constructor for the main UCS class. Loads the problem data named in the configuration.
	 * If no training file is configured, a message is printed and the JVM exits with status 1.
	 *
	 * @param p A UCSconfig object specifying how this instance of UCS should function.
	 */
	public UCS( UCSconfig p )
	{
		params = p;
		if (params.trainingFile == null)
		{
			System.out.println("No problem data specified.");
			System.exit(1);
		}
		else loadData(params.trainingFile, params.folds);
	}

	/**
	 * Set the problem data for this UCS instance. Depending on the configuration this
	 * <ul>
	 * <li>uses all data for both training and testing (online learning),</li>
	 * <li>uses a separately specified testing file,</li>
	 * <li>selects the next cross-validation fold as the test set (folds != 1), or</li>
	 * <li>asks the data source for its own train/test split (folds == 1).</li>
	 * </ul>
	 *
	 * @param filename The datafile to use.
	 * @param folds The number of cross-validation folds; must be 1 when online learning is enabled.
	 * @throws IllegalArgumentException if online learning is enabled and folds != 1.
	 */
	protected void loadData( String filename, int folds )
	{
		// Read (and shuffle) the datafile.
		DataSource reader = new DataSource(filename);
		if (params.verbosity >= 2) System.out.println("Data loaded: "+filename);
		reader.shuffle();

		// Make a blank RuleCondition so it initializes its necessary variables from the datafile.
		// The instance itself is deliberately unused.
		RuleCondition x = new RuleCondition( reader.numInputs );

		if (params.onlinelearning && folds != 1)
			throw new IllegalArgumentException("folds must be 1 when onlinelearning == true");

		// Note: it is not known what would happen if there were missing values
		// in either the training or testing data.

		// Online learning: use all data for both train and test sets.
		if (params.onlinelearning)
		{
			traindata = reader.getData();
			testdata = traindata;
			return;
		}

		// Separate train and test sets: use the specified testing data file.
		if (params.testingFile != null)
		{
			traindata = reader.getData();
			DataSource testReader = new DataSource(params.testingFile);
			testdata = testReader.getData();
			return;
		}

		// No test data was specified: if folds != 1 we do cross-validation.
		if (params.folds != 1)
		{
			if (params.currentFold == -1)
			{
				// Initial division into folds has not been done yet:
				// perform it now and use the 0th fold as the test fold.
				reader.addTargetNoise(params.noise, reader.exampleList);
				reader.initialFolding(folds, params);
				params.currentFold = 0;
			}
			else
			{
				// Not the first run: progress to the next fold, wrapping around at the end.
				params.currentFold++;
				if (params.currentFold == params.folds)
					params.currentFold = 0;
			}
			assignFoldsToTrainAndTest(params);
			return;
		}

		// No test data specified and folds == 1: split into train/test as the reader sees fit.
		traindata = reader.getTrainingData();
		reader.addTargetNoise(params.noise, traindata);
		testdata = reader.getTestingData();
		if (params.verbosity >= 2)
			System.out.println("loaded " + traindata.size() + " train examples and " + testdata.size() + " test examples");
	}

	/**
	 * Uses the fold indexed by params.currentFold as the test set and the
	 * concatenation of all other folds as the training set.
	 *
	 * @param params The configuration holding the folds and the current fold index.
	 */
	private void assignFoldsToTrainAndTest(UCSconfig params) {
		if (params.verbosity >= 2)
			System.out.println("Using fold " + params.currentFold + " for testing");
		testdata = params.allFolds.get(params.currentFold);

		// Add all other folds to traindata.
		traindata = new ArrayList();
		for (int fold = 0; fold < params.folds; fold++) {
			if (fold == params.currentFold) continue;
			if (params.verbosity >= 2)
				System.out.println("Adding fold " + fold + " to training set");
			traindata.addAll(params.allFolds.get(fold));
		}
	}

	/**
	 *  Main method to be run. Tests every 100 iterations.
	 *  @param maxiterations The number of iterations to execute.
	 */
	public void run( int maxiterations ) { run(maxiterations, 100); }

	/**
	 *  Main method to be run. Resets the per-run statistics, then repeatedly samples a random
	 *  training example, builds its match/correct sets, covers if the correct set is empty,
	 *  and otherwise updates the correct set and possibly invokes the GA.
	 *
	 *  NOTE(review): the iteration counter starts at 1 and the loop bound is inclusive, so
	 *  maxiterations + 1 training steps are actually performed. This is kept as-is because
	 *  tightening the bound would change which calls trigger a final test.
	 *
	 *  @param maxiterations The number of iterations to execute.
	 *  @param TEST_INTERVAL The regular number of iterations when output information is printed.
	 */
	public void run( int maxiterations, final int TEST_INTERVAL )
	{
		// Initialise for a new run.
		currentIteration = 1; // if set to 0 we run a test before any training
		sumSystemCorrect = 0.0;
		aucAccuracy = 0.0;
		averageMargin = 0.0;
		numTests = 0;
		lastFewExamples = new ArrayList(); // stores the first window of examples

		int stoppoint = currentIteration + maxiterations;
		for (; currentIteration <= stoppoint; currentIteration++)
		{
			// Test on the test set (or on the recent-example window when learning online).
			if (currentIteration % TEST_INTERVAL == 0) test();

			// Online learning: start a fresh window of recent examples at regular intervals.
			if (params.onlinelearning && currentIteration % ONLINE_WINDOW == 0)
				lastFewExamples = new ArrayList();

			// Pick a random data item.
			int randomIndex = UCSconfig.generator.nextInt(traindata.size());
			Example e = (Example)traindata.get(randomIndex);

			// Store it so accuracy can be calculated over the recent examples.
			if (params.onlinelearning)
				lastFewExamples.add(e);

			// Identify the match set and the correct set.
			getMatchAndCorrectSets(e);

			int correctSetSize = correctSet.size();
			if (correctSetSize == 0)
			{
				// Nothing advocates the right class: cover. The GA is never invoked on an
				// iteration that started with an empty correct set.
				covering(e);
				continue;
			}

			// Update the average correct set sizes (used later in deletion probabilities)
			// and accumulate the GA timestamps. A long accumulator avoids int overflow when
			// large correct sets meet large iteration counts.
			long timestampSum = 0;
			for (int i = 0; i < correctSetSize; i++)
			{
				Indiv ind = (Indiv)correctSet.get(i);
				ind.updateCorrectSetSize(correctSetSize);
				timestampSum += ind.lastTimeThisWasInTheGA;
			}
			int averageLastTimeInTheGA = (int)(timestampSum / correctSetSize);
			int gaRecency = currentIteration - averageLastTimeInTheGA;

			// Invoke the GA if the correct set has not taken part in it recently.
			if (gaRecency > params.gaThreshold && currentIteration > params.gaThreshold)
				invokeGA();
		}

		if (params.verbosity >= 2) {
			System.out.println("Done "+maxiterations+" iterations.");
			System.out.println("GA occured " + GA_counter + " times");
		}
	}

	/**
	 * Evaluates the current population and updates accuracy, aucAccuracy and averageMargin.
	 * When online learning is enabled the evaluation set is the window of recent training
	 * examples; otherwise it is the test data. Examples with an empty correct set trigger
	 * covering, exactly as during training. TESTMODE is set for the duration of the call
	 * and is always cleared again, even if an exception escapes.
	 *
	 * @throws IllegalStateException if the average margin evaluates to NaN
	 *         (for instance when the test set is empty).
	 */
	protected void test()
	{
		TESTMODE = true;
		try
		{
			// How are we doing testing?
			ArrayList evaluationSet = params.onlinelearning ? lastFewExamples : testdata;
			int numTesting = evaluationSet.size();

			double systemCorrect = 0;
			double sumMargin = 0.0;

			Iterator testDataIterator = evaluationSet.iterator();
			while (testDataIterator.hasNext())
			{
				Example e = (Example)testDataIterator.next();

				// Get the match and correct sets for this example.
				getMatchAndCorrectSets(e);
				if (correctSet.size() == 0)
					covering(e);

				// Form the system prediction.
				int guess = params.systemPredictor.predict( matchSet, params.fitfunc );

				// Check it and update the summed (signed) voting margin.
				if (guess == e.target) {
					systemCorrect++;
					sumMargin += params.systemPredictor.margin;
				}
				else
					sumMargin -= params.systemPredictor.margin;
			}

			// Finalise the statistics.
			systemCorrect /= numTesting;
			if (Double.isNaN(systemCorrect)) systemCorrect = 0;

			int macroclassifiers = pop.numMacroClassifiers();

			// Update aucAccuracy.
			numTests++;
			sumSystemCorrect += systemCorrect;
			aucAccuracy = sumSystemCorrect / numTests;

			// Store accuracy so it can be accessed externally if so desired.
			accuracy = systemCorrect;

			// Average the margin over the examples that were actually evaluated. An empty
			// online window simply has no margin to report; an empty test set is still
			// flagged by the NaN check below.
			averageMargin = (params.onlinelearning && numTesting == 0)
					? 0.0
					: sumMargin / numTesting;

			if (Double.isNaN(averageMargin))
				throw new IllegalStateException("averageMargin is NaN. sumMargin="+sumMargin
								+ " numTesting="+numTesting);

			int onlinelearn = params.onlinelearning ? 1 : 0;

			// Print them all out.
			if (params.verbosity >= 1)
				System.out.println("RESULTS iteration "+currentIteration+
						   " accuracy "+systemCorrect+
						   " v "+params.v+
						   " macro "+macroclassifiers+
						   " aucacc "+aucAccuracy+
						   " online "+onlinelearn+
						   " avgmargin "+averageMargin
						   );
		}
		finally
		{
			TESTMODE = false;
		}
	}

	/**
	 * Creates a rule that matches the given example and advocates its target class, adds it
	 * to the population, match set and correct set, then enforces the population size limit.
	 * UCS only covers the correct set (since it's a supervised learner - UCS paper p.215).
	 * If deletion leaves the correct set empty (e.g. the new rule itself was removed), the
	 * process is repeated. This is done iteratively so that repeated bad luck cannot
	 * exhaust the call stack.
	 *
	 * @param e The example to cover.
	 */
	protected void covering( Example e )
	{
		do
		{
			// Make an individual that matches this input exactly.
			Indiv ind = new Indiv( e.inputs.length, params );
			ind.condition.set( e.inputs );

			// Generalise: turn each position into a don't-care symbol with the covering probability.
			for (int i = 0; i < ind.condition.values.length; i++)
			{
				if (UCSconfig.generator.nextDouble() < params.coveringProbability)
					ind.condition.values[i] = RuleCondition.HASH;
			}

			// Initialise the classifier.
			ind.action = e.target;
			ind.correctSetSize = 1; // since we're covering

			// Add it.
			pop.add( ind );
			matchSet.add(ind);
			correctSet.add(ind);

			// Delete if necessary.
			deleteIfNecessary();
		}
		while (correctSet.size() == 0); // we've just deleted the covering rule: try again
	}

	/**
	 * Rebuilds matchSet and correctSet for the given example by scanning the whole population.
	 *
	 * @param e The example whose match and correct sets are required.
	 */
	protected void getMatchAndCorrectSets( Example e )
	{
		// Blank the previous sets.
		matchSet = new ArrayList();
		correctSet = new ArrayList();

		for (Iterator popIterator = pop.iterator(); popIterator.hasNext(); )
		{
			Indiv ind = (Indiv)popIterator.next();

			// Per the original author's note, matches() also updates the rule's numMatches and accuracy.
			if (!ind.matches(e.inputs)) continue;

			matchSet.add(ind);
			if (ind.test(e.target)) correctSet.add(ind);
		}
	}

	/**
	 * Removes rules from the population until it no longer exceeds params.POPMAXSIZE.
	 * Each removal is a roulette-wheel draw in which a rule's vote is its correctSetSize,
	 * scaled up by averageFitness/fitness for rules that have matched more than ThetaDel
	 * times and whose fitness is below ThetaDelFrac of the population average (see Kovacs 1999).
	 * A removed rule is also removed from the current match and correct sets.
	 */
	protected void deleteIfNecessary()
	{
		while (pop.size() > params.POPMAXSIZE)
		{
			// First pass: total of the base votes and the average fitness.
			int sum = 0;
			double averageFitness = 0;
			for (Iterator popit = pop.iterator(); popit.hasNext(); )
			{
				Indiv ind = (Indiv)popit.next();
				sum += ind.correctSetSize;
				averageFitness += params.GAfitfunc.evaluate(ind);
			}
			averageFitness /= pop.size();

			int choicePoint = UCSconfig.generator.nextInt(sum);

			// Second pass: walk the wheel until the running total passes the choice point.
			// Scaled votes are never smaller than base votes when ThetaDelFrac <= 1, so the
			// wheel is then guaranteed to reach the choice point.
			sum = 0;
			for (Iterator popit = pop.iterator(); popit.hasNext(); )
			{
				Indiv ind = (Indiv)popit.next();
				double deletionVote = ind.correctSetSize;
				double thisfitness = params.GAfitfunc.evaluate(ind);

				// See Kovacs 1999 for this bit.
				if (ind.numMatches > params.ThetaDel && thisfitness < params.ThetaDelFrac*averageFitness)
					deletionVote *= (averageFitness/thisfitness);

				sum += deletionVote;
				if (sum > choicePoint)
				{
					// Safe despite the live iterator: we stop iterating immediately.
					pop.remove(ind);
					correctSet.remove(ind);
					matchSet.remove(ind);
					break;
				}
			}
		}
	}

	/**
	 * Runs the niche GA on the current correct set: stamps every member with the current
	 * iteration, picks two parents by fitness-proportionate (roulette wheel) selection,
	 * produces two children by crossover and mutation, and inserts them subject to
	 * subsumption. For each child exactly one individual is inserted: a copy of the first
	 * parent that subsumes it, or the child itself if neither parent does. Finally the
	 * population size limit is enforced.
	 *
	 * @throws IllegalStateException if no parent could be selected (e.g. NaN fitness values).
	 */
	protected void invokeGA()
	{
		GA_counter++;

		// Total fitness of the niche; also update the timestamps to say they are in the GA.
		double sumOfFitnesses = 0;
		for (Object member : correctSet)
		{
			Indiv ind = (Indiv)member;
			sumOfFitnesses += params.GAfitfunc.evaluate(ind);
			ind.lastTimeThisWasInTheGA = currentIteration;
		}

		// Two independent spins of the wheel, resolved in a single pass over the niche.
		double rnd1 = UCSconfig.generator.nextDouble() * sumOfFitnesses;
		double rnd2 = UCSconfig.generator.nextDouble() * sumOfFitnesses;

		Indiv mommy = null, daddy = null;
		double cumulative = 0;
		for (Object member : correctSet)
		{
			Indiv ind = (Indiv)member;
			cumulative += params.GAfitfunc.evaluate(ind);

			if (mommy == null && cumulative >= rnd1) mommy = (Indiv)ind.clone();
			if (daddy == null && cumulative >= rnd2) daddy = (Indiv)ind.clone();
		}

		if (mommy == null || daddy == null)
			throw new IllegalStateException("GA parent selection failed. sumOfFitnesses="+sumOfFitnesses
							+ " correctSet.size()="+correctSet.size());

		Indiv[] children = mommy.crossAndMutate(daddy);

		// Check for subsumption: at most one parent copy replaces each child.
		for (int n = 0; n < 2; n++)
		{
			if (mommy.subsumes(children[n]))
				pop.add( (Indiv)mommy.clone() );
			else if (daddy.subsumes(children[n]))
				pop.add( (Indiv)daddy.clone() );
			else
				pop.add(children[n]);
		}

		// deleteIfNecessary() already loops until the population fits.
		deleteIfNecessary();
	}

}//end class
