Understanding XGBoost in five minutes
In this first post we will go over the theory and a basic worked example of XGBoost. The next post will look at preparing data for XGBoost models, visualising the trees and plotting feature importance.
What is XGBoost?
XGBoost is an algorithm that has shown high performance in regression, classification and ranking problems in data science competitions and industry. The name stands for eXtreme Gradient Boosting. While XGBoost is widely used, some non-technical stakeholders consider it something of a data science ‘black box’. The models are not quite as easy to interpret as a decision tree, or a set of manually coded rules. The aim of this post is to demystify XGBoost, and give a mostly ‘maths-free’ explanation of how it works.
XGBoost models are generally made up of gradient boosted decision trees. To better understand what this means, let's go over some of the background:
Decision Trees
A decision tree is a simple rule-based predictor made up of branching true and false statements. For example, 'is it raining today? If not, how windy is it? If less than 10 km/h, then I will go for a walk'. Decision trees are built by finding the best split for the entire dataset (e.g. raining or not raining), then the best split for each resulting subset, and so on until splitting the data no longer adds predictive value.
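To make the idea concrete, here is a minimal hand-coded sketch of the 'walk' rule above (the function name and inputs are purely illustrative):

```python
# A decision tree is just nested true/false questions.
# Hand-coded sketch of the 'go for a walk' rule from the example above.
def will_go_for_walk(raining: bool, wind_speed_kmh: float) -> bool:
    if raining:                     # first split: is it raining today?
        return False
    return wind_speed_kmh < 10      # second split: how windy is it?

print(will_go_for_walk(raining=False, wind_speed_kmh=5))   # True
print(will_go_for_walk(raining=False, wind_speed_kmh=25))  # False
```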
The main advantage of decision trees is how easy they are to explain. The downsides are that they are very prone to overfitting the training data, and often don’t generalise well. The model may work well on the data it was trained on, but fail completely when given new data.
Combining multiple decision trees
Ensemble methods work by combining multiple decision trees to generate a consensus prediction. The aim is to create a strong classifier by combining several 'weak' classifiers. A weak classifier is one that has some ability to predict the target variable, but not enough to use as a model on its own.
Bagging and boosting are two commonly used ensemble methods.
Bagging takes samples of data (with replacement) from a population and generates a weak learner decision tree for each. Each weak learner is trained independently, so it is possible to train them in parallel. The predictions from each of the decision trees are averaged to generate a final prediction. A random forest is made up of several decision trees and is one example of the bagging technique.
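As a rough sketch of bagging in practice, here is a random forest fitted with scikit-learn on some made-up data (the dataset and parameter values are purely illustrative, and scikit-learn is assumed to be installed):

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor

# Made-up regression data: a noisy sine wave.
rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(200, 1))
y = np.sin(X).ravel() + rng.normal(scale=0.2, size=200)

# 100 decision trees, each trained on a bootstrap sample of the data;
# the forest's prediction is the average of the individual tree predictions.
forest = RandomForestRegressor(n_estimators=100, bootstrap=True, random_state=0)
forest.fit(X, y)

print(forest.predict([[3.0]]))  # consensus ('strong') prediction at x = 3
```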
Boosting is a method that generates a model, then calculates the predictive error for each data point. A weak learner is then trained to ‘correct’ the error of the model, and the ‘corrected’ value is compared to the actual value of the data point. Another weak learner is added to further correct the remaining error, and the process repeats, adding more weak learners, until the model reaches a desired level of predictive performance.
A simple analogy for bagging would be taking a room full of people and asking each one to predict how many beans are in a jar. We would average each of the ‘weak’ predictions to form a consensus single ‘strong’ prediction.
Boosting is slightly harder to put into a simple analogy, so we’ll walk through an example.
A worked example of Boosting
Consider a dataset where years of experience is the input variable and salary (in £'000, i.e. thousands of pounds) is the variable we are trying to predict. We want to build a model that uses years of experience to predict salary.
Years | Salary (£'000) |
---|---|
2 | 40 |
5 | 49 |
7 | 75 |
12 | 60 |
13 | 64 |
24 | 110 |
26 | 100 |
Before we start training, we need to define a loss function. The loss function gives us the difference between the correct salary values and the values predicted by our model. To get a good model, we want to minimise the values coming out of the loss function.
For this example, we’ll choose the squared loss, which is calculated by squaring the difference between the predicted value and the correct value. Averaged over all of the data points, this is the mean squared error (MSE).
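In code, the squared loss for a single prediction and the MSE over a dataset look like this (a minimal sketch; the function names are our own):

```python
def squared_error(y_true: float, y_pred: float) -> float:
    # Squared loss for a single data point.
    return (y_true - y_pred) ** 2

def mean_squared_error(y_true, y_pred) -> float:
    # Average of the squared losses over the whole dataset.
    return sum(squared_error(t, p) for t, p in zip(y_true, y_pred)) / len(y_true)

print(mean_squared_error([40, 49], [45, 45]))  # ((40-45)**2 + (49-45)**2) / 2 = 20.5
```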
We will initialise our boosting model with a function that minimises the squared loss - simply the mean of the salaries in this case. We will also calculate the error of our function for each instance:
x (Years) | y (Salary, £'000) | F₀ | y - F₀ |
---|---|---|---|
2 | 40 | 71 | -31 |
5 | 49 | 71 | -22 |
7 | 75 | 71 | 4 |
12 | 60 | 71 | -11 |
13 | 64 | 71 | -7 |
24 | 110 | 71 | 39 |
26 | 100 | 71 | 29 |
The third column, F₀, is our initial function - the average salary (just over 71, rounded to 71 in the table). Our (dumb) F₀ model predicts that everyone makes £71,000 per year.
The fourth column, y - F₀, is the error of our initial function. We now use this error to create our first weak learner regression tree h₁(x). The new tree is not trying to predict the salary; instead it will try to predict the error of our dumb F₀ model at each data point (y - F₀), using the data we have on years of experience (x).
The new tree h₁(x) is a regression stump: a single decision node at the root with two predictor leaves. Each leaf predicts the average of the target values that fall into it, and the split on x is chosen as the value that minimises the variance of the values within the leaves. Let’s add the values to our table:
x (Years) | y (Salary, £'000) | F₀ | y - F₀ | h₁(x) | F₁ |
---|---|---|---|---|---|
2 | 40 | 71 | -31 | -13.4 | 57.6 |
5 | 49 | 71 | -22 | -13.4 | 57.6 |
7 | 75 | 71 | 4 | -13.4 | 57.6 |
12 | 60 | 71 | -11 | -13.4 | 57.6 |
13 | 64 | 71 | -7 | -13.4 | 57.6 |
24 | 110 | 71 | 39 | 34 | 105 |
26 | 100 | 71 | 29 | 34 | 105 |
F₁ is calculated by adding F₀ and h₁(x). F₁ is the new predictor for y ‘boosted’ by the weak error predictor h₁(x). F₁ looks a bit smarter than our initial model - it predicts that people with 24+ years of experience make more than those with less experience.
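This single boosting step is easy to reproduce with a depth-1 regression tree from scikit-learn (a minimal sketch assuming scikit-learn is installed; it matches the rounded figures in the table above):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

x = np.array([[2], [5], [7], [12], [13], [24], [26]])  # years of experience
y = np.array([40, 49, 75, 60, 64, 110, 100])           # salary (£'000)

f0 = np.full(len(y), 71.0)  # initial model F0: the (rounded) mean salary
residuals = y - f0          # errors of the initial model, y - F0

# h1: a regression stump (depth-1 tree) trained to predict those errors
h1 = DecisionTreeRegressor(max_depth=1).fit(x, residuals)

f1 = f0 + h1.predict(x)     # F1 = F0 + h1(x)
print(f1)                   # [ 57.6  57.6  57.6  57.6  57.6 105.  105. ]
```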
The key idea behind boosting is to insert more models to correct the errors of the previous model.
We can repeat the boosting process another 2 times to see if the performance increases:
x (Years) | y (Salary, £'000) | F₀ | y - F₀ | h₁(x) | F₁ | y - F₁ | h₂(x) | F₂ | y - F₂ | h₃(x) | F₃ |
---|---|---|---|---|---|---|---|---|---|---|---|
2 | 40 | 71 | -31 | -13.4 | 57.6 | -17.6 | -13.1 | 44.5 | -4.5 | 1.7 | 46.2 |
5 | 49 | 71 | -22 | -13.4 | 57.6 | -8.6 | -13.1 | 44.5 | 4.5 | 1.7 | 46.2 |
7 | 75 | 71 | 4 | -13.4 | 57.6 | 17.4 | 5.2 | 62.8 | 12.2 | 1.7 | 64.5 |
12 | 60 | 71 | -11 | -13.4 | 57.6 | 2.4 | 5.2 | 62.8 | -2.8 | 1.7 | 64.5 |
13 | 64 | 71 | -7 | -13.4 | 57.6 | 6.4 | 5.2 | 62.8 | 1.2 | 1.7 | 64.5 |
24 | 110 | 71 | 39 | 34 | 105 | 5 | 5.2 | 110.2 | -0.2 | 1.7 | 111.9 |
26 | 100 | 71 | 29 | 34 | 105 | -5 | 5.2 | 110.2 | -10.2 | -10.2 | 100 |
The predictions in the F₃ column look a lot closer to our true salary values. We’ll plot the mean squared error for each training iteration to see if the predictive performance of the model has improved:
Impressively, adding simple regression tree stumps results in a significant reduction in prediction error.
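If you want to recreate the experiment, here is a minimal sketch of the full loop using scikit-learn stumps (the exact MSE values will differ slightly from the tables above because of rounding):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

x = np.array([[2], [5], [7], [12], [13], [24], [26]])
y = np.array([40.0, 49.0, 75.0, 60.0, 64.0, 110.0, 100.0])

prediction = np.full_like(y, y.mean())  # F0: the mean salary
print(f"Iteration 0: MSE = {np.mean((y - prediction) ** 2):.1f}")

for m in range(1, 4):                   # boosting rounds giving F1, F2 and F3
    residuals = y - prediction          # errors of the current model
    stump = DecisionTreeRegressor(max_depth=1).fit(x, residuals)
    prediction = prediction + stump.predict(x)  # F_m = F_(m-1) + h_m(x)
    print(f"Iteration {m}: MSE = {np.mean((y - prediction) ** 2):.1f}")
```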
The ‘gradient’ part of gradient boosted trees
In the boosting example above, our loss function was the mean squared error (MSE). The MSE gave us a measure of how ‘good’ our model was at predicting salary based on years of experience. To improve the model, we wanted to make the MSE as small as possible. Calculating the mean of the values in each of the tree’s leaves minimised the MSE.
Gradient boosting generalises the idea to loss functions other than squared error. Things start to get a bit more complicated and mathematical here, but the general ideas are the same as our simple example. For our salary predictor, the errors y - Fₘ(x) are the negative gradients of the squared error loss function. Gradient boosting lets you plug in any differentiable loss function and train the trees on the gradient of that loss function. Essentially, we are still building trees to predict and correct the errors of our previously constructed model. For those interested, the Wikipedia page on gradient descent gives an excellent analogy.
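To tie this back to XGBoost itself, here is a minimal sketch of fitting an XGBoost model to the same toy salary data (it assumes the xgboost package is installed; the parameters are chosen purely to mirror the worked example, and the predictions won't match the hand-worked tables exactly because XGBoost adds regularisation and its own initialisation):

```python
import numpy as np
from xgboost import XGBRegressor

x = np.array([[2], [5], [7], [12], [13], [24], [26]])
y = np.array([40, 49, 75, 60, 64, 110, 100])

model = XGBRegressor(
    n_estimators=3,                # three boosting rounds, as in the worked example
    max_depth=1,                   # each weak learner is a stump
    learning_rate=1.0,             # no shrinkage, to mirror the hand-worked steps
    objective="reg:squarederror",  # the squared error loss used above
)
model.fit(x, y)
print(model.predict(x))
```

The `objective` parameter is exactly where a different (differentiable) loss function can be plugged in.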
If you’ve followed along so far, well done, you now understand the underlying theory and ideas for the XGBoost algorithm! In the next post, we’ll look at using Python to prepare data for XGBoost models, visualise the trees and show which features are most predictive in the model.