
Automatic Parameter Tuning


Introduction

Most machine learning models have one or more parameters that should be tuned to get the best performance in practice. While a few algorithms, such as Random Forests, work well even without tuning, in many cases tuning is the difference between performance almost as bad as random guessing and an accurate, useful model. Unfortunately, knowing which parameters to tune, and what value ranges to try, can be daunting. Even knowledgeable practitioners may not have a good idea of what ranges to try when testing a new algorithm or a method they aren't familiar with. JSAT makes this easier with automatic parameter inference when using GridSearch or RandomSearch to tune parameters.
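Both searches follow the same basic workflow: wrap the model, infer parameter ranges from the data, then train. As a minimal sketch of the GridSearch variant (assuming a Classifier named model and a ClassificationDataSet named train; RandomSearch, used in the examples below, is set up the same way):

import jsat.parameters.GridSearch;

//the 3 is the number of cross validation folds used to score each parameter combination
GridSearch search = new GridSearch(model, 3);
search.autoAddParameters(train);//infer which parameters to tune and what ranges to try
search.trainC(train);//trains using the best combination found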

Example 1, Automatic Parameter Tuning a single model

In the first example, we will tune an SVM with an RBF kernel to show the difference that parameter tuning can make. In JSAT, reflection is used to find "guess" methods that correspond to a specific parameter and determine a search range based on the dataset being fit. This allows for data-sensitive tuning without the user having to be aware of it. In the example below we use the unscaled version of a dataset, which normally breaks the standard parameter search for the RBF kernel. JSAT can instead widen the search range for the RBF kernel's width based on the data given, resulting in better accuracy - especially compared to the model without tuning.
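You can also call these "guess" methods directly to see what search ranges JSAT will infer before running any search. Below is a minimal sketch of that idea; it assumes the diabetes dataset used in the example that follows, and that RBFKernel.guessSigma and PlattSMO.guessC are the static guess methods in your JSAT version (check the javadoc if they differ).

import java.io.File;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.svm.PlattSMO;
import jsat.distributions.Distribution;
import jsat.distributions.kernels.RBFKernel;
import jsat.io.LIBSVMLoader;

public class GuessInspection
{
    public static void main(String[] args) throws Exception
    {
        ClassificationDataSet data = LIBSVMLoader.loadC(new File("diabetes.libsvm"));
        //these distributions are what RandomSearch will sample parameter values from
        Distribution sigmaGuess = RBFKernel.guessSigma(data);//range for the kernel width
        Distribution cGuess = PlattSMO.guessC(data);//range for the SVM's C parameter
        System.out.println("sigma search range: " + sigmaGuess);
        System.out.println("C search range: " + cGuess);
    }
}

The full tuning example follows.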

import java.io.File;
import java.io.IOException;
import java.util.List;
import jsat.classifiers.*;
import jsat.classifiers.svm.PlattSMO;
import jsat.classifiers.svm.SupportVectorLearner.CacheMode;
import jsat.distributions.kernels.RBFKernel;
import jsat.io.LIBSVMLoader;
import jsat.parameters.RandomSearch;

/**
 *
 * @author Edward Raff
 */
public class EasyParameterSearch
{
    public static void main(String[] args) throws IOException
    {
        //Download dataset from 
        //https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/diabetes
        ClassificationDataSet dataset = LIBSVMLoader.loadC(new File("diabetes.libsvm"));
        
        ///////First, the code someone new would use////////
        PlattSMO model = new PlattSMO(new RBFKernel());
        model.setCacheMode(CacheMode.FULL);//Small dataset, so we can do this
        
        ClassificationModelEvaluation cme = new ClassificationModelEvaluation(model, dataset);
        cme.evaluateCrossValidation(10);
        
        System.out.println("Error rate: " + cme.getErrorRate());
        
        /*
         * Now some easy code to tune the model. Because the parameter values
         * can be impacted by the dataset, we should split the data into a train 
         * and test set to avoid overfitting. 
         */

        List<ClassificationDataSet> splits = dataset.randomSplit(0.75, 0.25);
        ClassificationDataSet train = splits.get(0), test = splits.get(1);
        
        //The 3 in the constructor is the number of CV folds to evaluate each parameter combination
        RandomSearch search = new RandomSearch((Classifier)model, 3);
        search.setTrials(100);
        if(search.autoAddParameters(train) > 0)//this method adds parameters, and returns the number of parameters added
        {
            //that way we only do the search if there are any parameters to actually tune
            cme = new ClassificationModelEvaluation(search, train);
            cme.evaluateTestSet(test);
            System.out.println("Tuned Error rate: " + cme.getErrorRate());
        }
        else//otherwise we will just have to trust our original CV error rate
            System.out.println("This model doesn't seem to have any easy-to-tune parameters");
    }
}
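If you want the tuned model itself, and not just its error rate, note that RandomSearch implements Classifier. A short sketch continuing from the variables in the example above (getTrainedClassifier comes from JSAT's ModelSearch base class; treat this as illustrative rather than the only way):

//continuing from the example above: train the search directly, then grab the winner
search.trainC(train);
Classifier tuned = search.getTrainedClassifier();//the best model found, already trained
//it can now be used like any other classifier
CategoricalResults result = tuned.classify(test.getDataPoint(0));
System.out.println("Predicted class: " + result.mostLikely());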

Example 2, Automatic Parameter Tuning Multiple Models

Automatic parameter inference becomes particularly useful when you want to run a bunch of models on a dataset. This is a common scenario in practice, especially when taking a first cut at a new dataset. In most libraries you would have to copy-paste most of the code for each model you wanted to try and adjust the parameters separately for each one, or keep a data structure paired with the models to track each combination. In JSAT this can be done in a simple loop. For this example we will use MNIST, so it may take a good while longer to run depending on your hardware.

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import jsat.classifiers.*;
import jsat.classifiers.linear.LogisticRegressionDCD;
import jsat.classifiers.linear.kernelized.KernelSGD;
import jsat.classifiers.svm.extended.AMM;
import jsat.classifiers.trees.RandomForest;
import jsat.io.LIBSVMLoader;
import jsat.parameters.RandomSearch;
import jsat.utils.SystemInfo;

/**
 *
 * @author Edward Raff
 */
public class EasyParameterSearch2
{
    public static void main(String[] args) throws IOException
    {
        //Let's use the slightly larger MNIST dataset; download it from
        //https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass.html#mnist
        ClassificationDataSet train = LIBSVMLoader.loadC(new File("mnist.scale"));
        //the LIBSVM format doesn't store the total number of features, which is awkward for sparse datasets
        //the extra arguments handle features that appear in the training set but never in the test set
        ClassificationDataSet test = LIBSVMLoader.loadC(new File("mnist.scale.t"), 0.5, train.getNumNumericalVars());
        
        /*
 * Now, we create a list of models we would like to try on our dataset.
         * Below I've picked 4 of my favorite models to use. They tend to work
         * well on most datasets most of the time, so I like to use them as 
         * general starting points.
         */
        List<Classifier> models = new ArrayList<>();
        models.add(new OneVSAll(new LogisticRegressionDCD(), true));//a fast exact LR algorithm
        models.add(new AMM());//A non-linear model with linear-like efficiency
        models.add(new RandomForest());//everyone's favorite tree ensemble 
        models.add(new KernelSGD());//A faster approximate version of an SVM
        
        //For this example, we will make the search parallel 
        ExecutorService exec = Executors.newFixedThreadPool(SystemInfo.LogicalCores);
        
        for(Classifier model : models)//loop over each model we want to try
        {
            System.out.println("Testing model: " + model.getClass().getSimpleName());
            RandomSearch search = new RandomSearch((Classifier)model, 3);
            search.setTrials(50);
            if(search.autoAddParameters(train) > 0)//this method adds parameters, and returns the number of parameters added
            {
                //adding the exec to the constructor makes the evaluation use multiple threads when possible
                ClassificationModelEvaluation cme = new ClassificationModelEvaluation(search, train, exec);
                cme.evaluateTestSet(test);
                System.out.println("\tTuned Error rate: " + cme.getErrorRate());
            }
            else//otherwise we evaluate the untuned model directly
            {
                ClassificationModelEvaluation cme = new ClassificationModelEvaluation(model, train, exec);
                cme.evaluateTestSet(test);
                System.out.println("\tError rate: " + cme.getErrorRate());
            }
        }
        
        /*
         * Sample output:
         * Testing model: OneVSAll
         * 	Tuned Error rate: 0.08109999999999995
         * Testing model: AMM
         * 	Tuned Error rate: 0.03939999999999999
         * Testing model: RandomForest
         * 	Error rate: 0.042300000000000004
         * Testing model: KernelSGD
         * 	Tuned Error rate: 0.04920000000000002
         */
        
        exec.shutdownNow();
    }
}
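When comparing models like this, you often want to keep the overall winner. One simple way, as a hedged sketch extending the loop above (bestModel and bestError are new variables introduced here, not part of the original example):

//before the loop
Classifier bestModel = null;
double bestError = Double.MAX_VALUE;
//... then inside the loop, after each call to cme.evaluateTestSet(test):
if(cme.getErrorRate() < bestError)
{
    bestError = cme.getErrorRate();
    bestModel = model;//remember which base model won
}
//after the loop
System.out.println("Best model: " + bestModel.getClass().getSimpleName()
        + " with error rate " + bestError);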

Review

You now know how to set up your code to automatically tune a model for your dataset. In most cases this basic setup will be all you need if you just want to evaluate models on a new dataset. Not all parameters have a "guess" method in JSAT, so you shouldn't take the results as the best possible - but they will likely be close. For many algorithms, the parameter values also have a large impact on runtime, so the search may take longer than you expect.
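If a parameter you care about has no guess method, you can still add it to the search by hand. The sketch below assumes PlattSMO exposes its regularization parameter under the name "C" and uses JSAT's LogUniform distribution to sample it on a log scale; verify the parameter names with model.getParameters() in your version before relying on this.

import jsat.distributions.LogUniform;
import jsat.parameters.DoubleParameter;
import jsat.parameters.RandomSearch;

//given a model and training set as in Example 1:
RandomSearch search = new RandomSearch((Classifier)model, 3);
//look the parameter up by name (assumed to be "C" here) and
//tell the search what distribution to sample candidate values from
DoubleParameter cParam = (DoubleParameter) model.getParameter("C");
search.addParameter(cParam, new LogUniform(1e-3, 1e3));//try C values across six orders of magnitude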