I have been implementing machine-learning models in Java. As you may know, there are not many libraries that support analytics concepts in Java, which means I often have to write the raw code, function after function. Where libraries do exist, such as Tablesaw, Smile and Weka, there is no large community to offer support compared to Python and MATLAB. We will delve into how I have been coping later, but for now let us focus on a common challenge that cuts across all of this: splitting a data set into training and testing sets, using common data sets such as the Iris data set, the housing data set, the customer data set and the adult data set.
Most of these data sets are publicly available, and anyone can use them to create and train a model, depending on the category. When creating models, whether in supervised or unsupervised learning, the most important thing is to split the data into two different sets, where one is used for training and the other for testing. Often, we split the data in a ratio of 70:30.
To give some background, training data is used to train the model. It takes into account the independent variables and the outcome, which is defined as the dependent variable. The model learns from the data given and is later used to predict on a new set of data. The model's accuracy in predicting the new set depends heavily on the training data provided. The testing data set is used to test the model developed. Often this data is labelled, but one uses only the independent variables to test the model and then compares the outcome with the actual values; from there one can determine the accuracy of the model, or how well the model predicts the unknown and how it will perform on generalized data.
In splitting the data set, there are key things to factor in:
- Is the data arranged in a certain order that might affect the splitting?
- Does the dependent variable fluctuate?
- Is there a large imbalance in the response variable, e.g. house prices?
- In the case of a classification data set, do positive examples outnumber negative ones, or vice versa?
Considering the above scenarios, cross validation is the best solution to ensure that we achieve the train-test split in the most efficient manner. However, there are different types of cross validation, depending on the data set:
- K-fold cross validation
- Stratified cross validation
- Hold-out method
K-Fold Cross Validation
This involves dividing the data into small groups (folds), referred to as K folds, where K is the number of folds you divide the data set into. The commonly recommended value of K is between 5 and 10. For a K of five, the data set will be:
- Divided into 5 different groups
- The model will be trained 5 times and tested 5 times
- In each round, four groups are used to train the model, which is then tested on the remaining group. This is repeated five times, using a different group for testing each time.
- Each time, the model error is recorded, and the average of the five errors forms the overall error of the model
Using K-fold cross validation, the model is trained on familiar data and asked to predict unfamiliar data points in every round. This way it is evaluated on data it has never seen, which helps it generalize and avoid overfitting and underfitting.
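To make the procedure concrete, here is a minimal sketch of the fold bookkeeping in plain Java. The `loadData` and `trainAndEvaluate` methods are hypothetical placeholders (they are not part of Tablesaw, Smile or Weka); only the splitting and error-averaging logic is the point.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class KFoldCrossValidation {

    // Hypothetical stand-in for a real model: fit on `train`,
    // predict on `test`, and return the prediction error.
    static double trainAndEvaluate(List<double[]> train, List<double[]> test) {
        return 0.0; // placeholder: plug in your own model here
    }

    // Hypothetical loader: read the rows of the data set from disk.
    static List<double[]> loadData() {
        return new ArrayList<>(); // placeholder: parse CSV rows, etc.
    }

    public static void main(String[] args) {
        List<double[]> data = loadData();
        int k = 5;

        // Shuffle once so the folds are not biased by the file's row order
        Collections.shuffle(data, new Random(42));

        double totalError = 0.0;
        for (int fold = 0; fold < k; fold++) {
            List<double[]> train = new ArrayList<>();
            List<double[]> test = new ArrayList<>();

            // Every k-th row (offset by the fold index) becomes test data;
            // the remaining rows form the training set.
            for (int i = 0; i < data.size(); i++) {
                (i % k == fold ? test : train).add(data.get(i));
            }

            double error = trainAndEvaluate(train, test);
            totalError += error;
            System.out.printf("Fold %d error: %.4f%n", fold + 1, error);
        }

        // The model's reported error is the average over the k folds
        System.out.printf("Average error over %d folds: %.4f%n", k, totalError / k);
    }
}
```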
Stratified Cross Validation
Stratified cross validation is an extension of K-fold cross validation that takes care of imbalance in the dependent variable. This means that when using the housing data set and splitting it into K folds, one has to ensure that houses with high prices and houses with low prices are evenly spread across the different folds.
What I used in my project:
In the project, I used the Iris data set, which has 150 samples of three different flowers: the first set of 50 is Setosa, the second Virginica and the third Versicolor. The flowers were not shuffled, which meant that if I did a plain train-test split I could have ended up with an unequal distribution of flowers across the folds, hence I chose to use stratified K folds. I used a K of five and split the data into the folds while ensuring that each fold had an equal number of Setosa, Virginica and Versicolor. This means I ended up with five folds, each with 10 flowers of each type, built roughly as in the sketch below.
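Here is a minimal sketch of how such a stratified split can be built in plain Java. The `IrisRow` type is a hypothetical container for the four measurements and the species label; grouping by label and dealing each group round-robin into the folds is what guarantees the 10-of-each-type balance described above.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

public class StratifiedFolds {

    // One row of the Iris data set: four measurements plus the species label.
    static class IrisRow {
        final double[] features;
        final String species;
        IrisRow(double[] features, String species) {
            this.features = features;
            this.species = species;
        }
    }

    // Group rows by species, then deal each group round-robin into k folds,
    // so every fold ends up with the same number of Setosa, Virginica and
    // Versicolor (150 rows / 3 species / 5 folds = 10 of each per fold).
    static List<List<IrisRow>> stratifiedFolds(List<IrisRow> data, int k) {
        Map<String, List<IrisRow>> bySpecies = new HashMap<>();
        for (IrisRow row : data) {
            bySpecies.computeIfAbsent(row.species, s -> new ArrayList<>()).add(row);
        }

        List<List<IrisRow>> folds = new ArrayList<>();
        for (int i = 0; i < k; i++) {
            folds.add(new ArrayList<>());
        }

        for (List<IrisRow> group : bySpecies.values()) {
            Collections.shuffle(group, new Random(42)); // shuffle within each class
            for (int i = 0; i < group.size(); i++) {
                folds.get(i % k).add(group.get(i));
            }
        }
        return folds;
    }
}
```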
Through the training, I realized that the errors differed between rounds but were much smaller than with unstratified folds. This means that ensuring balance in the training and testing data sets trains the model well, without overfitting or underfitting, hence improving the accuracy of the model on generalized data.
Why choose Cross Validation instead of normal Train-Test Split:
- Reduces bias in the results
- Ensures that the model can be applied to more generalized data and give quality predictions
- Reduces the variance of the performance estimate
The disadvantage of implementing stratified cross validation is that it is hard to implement, mostly because I was using Java, which meant I had to write around 200 lines of code to split the data, create the model and test it on each and every fold. It is also time consuming.
Hold-Out Method of Cross Validation:
In the hold-out method of cross validation, a sample of the data set is left out for testing purposes and the rest of the data is used for training. This is the same idea as the normal 70:30 train-test split. The advantage is that it is easy to implement, but it can give high variance in the results, since everything depends on which rows happen to land in the held-out sample. It therefore works best when the data set is large and evenly balanced.
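A minimal sketch of the hold-out split in plain Java follows; `holdOut` is a hypothetical helper, and the shuffle before the cut matters whenever the rows are ordered, as they are in the Iris file.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class HoldOutSplit {

    // Shuffle the rows, then keep the first `trainRatio` share for training
    // and hold the rest out for testing (e.g. 0.7 gives a 70:30 split).
    static List<List<double[]>> holdOut(List<double[]> data, double trainRatio) {
        List<double[]> shuffled = new ArrayList<>(data);
        Collections.shuffle(shuffled, new Random(42));

        int cut = (int) (shuffled.size() * trainRatio);
        List<double[]> train = new ArrayList<>(shuffled.subList(0, cut));
        List<double[]> test = new ArrayList<>(shuffled.subList(cut, shuffled.size()));
        return List.of(train, test);
    }
}
```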