
Predicting customer churn with Python: Logistic regression, decision trees and random forests

Customer churn is when a company’s customers stop doing business with that company. Businesses are very keen on measuring churn because keeping an existing customer is far less expensive than acquiring a new customer. New business involves working leads through a sales funnel, using marketing and sales budgets to gain additional customers. Existing customers will often have a higher volume of service consumption and can generate additional customer referrals.

Customer retention can be achieved with good customer service and products. But the most effective way for a company to prevent attrition of customers is to truly know them. The vast volumes of data collected about customers can be used to build churn prediction models. Knowing who is most likely to defect means that a company can prioritise focused marketing efforts on that subset of their customer base.

Preventing customer churn is critically important to the telecommunications sector, as the barriers to switching providers are so low. In this post, we will examine customer data from IBM Sample Data Sets with the aim of building and comparing several customer churn prediction models. You can download the dataset here.

Data preparation

What does the data look like?

Each row gives details for one of the 7,043 individual customers, e.g. the length of their tenure, internet service type, contract type and monthly charges. The target for prediction is the ‘Churn’ column, indicating whether or not the customer cancelled their service.

Pandas didn’t parse the ‘TotalCharges’ column as float64, so the column probably contains some non-numeric values.
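A minimal sketch of how to locate the non-numeric entries, assuming the IBM Telco column names (`tenure`, `TotalCharges`). A small made-up DataFrame stands in for the CSV, and `find_non_numeric` is a hypothetical helper rather than a pandas function.

```python
import pandas as pd


def find_non_numeric(df: pd.DataFrame, column: str) -> pd.DataFrame:
    """Return the rows whose value in `column` cannot be parsed as a number."""
    # errors="coerce" turns unparseable strings (such as a blank " ") into NaN
    parsed = pd.to_numeric(df[column], errors="coerce")
    return df[parsed.isna()]


# Made-up stand-in for the real data: a single blank string is enough
# to stop pandas from reading the column as float64
toy = pd.DataFrame({
    "tenure": [1, 0, 34, 0],
    "TotalCharges": ["29.85", " ", "1889.5", " "],
})
print(toy["TotalCharges"].dtype)  # not float64
print(find_non_numeric(toy, "TotalCharges"))
```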

Looks like the blank ‘TotalCharges’ values were for customers with 0 months tenure. We will adjust these values to $0.
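One way to apply the $0 fix, again sketched on made-up rows with assumed column names. The hypothetical helper `fill_blank_total_charges` also checks the observation above, that every blank belongs to a zero-tenure customer, before filling.

```python
import pandas as pd


def fill_blank_total_charges(df: pd.DataFrame) -> pd.DataFrame:
    """Convert TotalCharges to float, replacing blank entries with 0.0."""
    out = df.copy()
    total = pd.to_numeric(out["TotalCharges"], errors="coerce")
    # Sanity check: every blank should belong to a customer with 0 months tenure
    assert (out.loc[total.isna(), "tenure"] == 0).all(), "blank charge with non-zero tenure"
    out["TotalCharges"] = total.fillna(0.0)
    return out


toy = pd.DataFrame({"tenure": [1, 0, 34], "TotalCharges": ["29.85", " ", "1889.5"]})
clean = fill_blank_total_charges(toy)
print(clean["TotalCharges"].tolist())  # [29.85, 0.0, 1889.5]
```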

Next we’ll convert the categorical values into numeric values, so our ML algorithms can process the data. We will also remove the columns not used in the predictive model.
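A sketch of one common encoding scheme, assuming the IBM Telco column names: Yes/No columns (including `Churn`) become 1/0, the remaining text columns are one-hot encoded with `pd.get_dummies`, and the `customerID` identifier is dropped. The helper `encode_features` and the three-row DataFrame are hypothetical.

```python
import pandas as pd


def encode_features(df: pd.DataFrame, drop_cols=("customerID",)) -> pd.DataFrame:
    """Map Yes/No columns to 1/0, one-hot encode other text columns, drop unused columns."""
    out = df.drop(columns=list(drop_cols))
    # Binary Yes/No columns (including the target) become 1/0
    for col in out.columns:
        if set(out[col].unique()) <= {"Yes", "No"}:
            out[col] = (out[col] == "Yes").astype(int)
    # Any remaining non-numeric column has more than two categories: one-hot encode it
    text_cols = [c for c in out.columns if not pd.api.types.is_numeric_dtype(out[c])]
    return pd.get_dummies(out, columns=text_cols, dtype=int)


toy = pd.DataFrame({
    "customerID": ["a", "b", "c"],
    "Partner": ["Yes", "No", "Yes"],
    "Contract": ["Month-to-month", "One year", "Two year"],
    "MonthlyCharges": [29.85, 56.95, 53.85],
    "Churn": ["No", "No", "Yes"],
})
encoded = encode_features(toy)
print(encoded.columns.tolist())
```

One-hot encoding avoids implying an order between categories such as contract types, which a single integer code per category would.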

Correlations between customer data features and customer churn

To decide which features of the data to include in our predictive churn model, we’ll examine the correlation between churn and each customer feature.
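A minimal sketch of the correlation check, run on synthetic data rather than the real data set. `churn_correlations` is a hypothetical helper that assumes an already numerically encoded DataFrame with a 0/1 `Churn` column. With a 0/1 target, Pearson correlation is the point-biserial correlation, so it only captures linear association.

```python
import numpy as np
import pandas as pd


def churn_correlations(encoded: pd.DataFrame, target: str = "Churn") -> pd.Series:
    """Pearson correlation of each numeric column with the 0/1 target, strongest first."""
    corr = encoded.corr()[target].drop(target)
    return corr.reindex(corr.abs().sort_values(ascending=False).index)


# Synthetic data: customers in their first year churn more often; 'noise' is unrelated
rng = np.random.default_rng(0)
tenure = rng.integers(0, 72, size=500)
churn = (rng.random(500) < np.where(tenure < 12, 0.5, 0.1)).astype(int)
toy = pd.DataFrame({"tenure": tenure, "noise": rng.normal(size=500), "Churn": churn})

corr = churn_correlations(toy)
print(corr)  # tenure comes first, with a negative sign
```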

Avoiding multicollinearity

Total charges and monthly charges are highly correlated. We try to avoid strongly correlated explanatory variables in regression models. Correlation of explanatory variables is known as multicollinearity, and perfect multicollinearity occurs when one explanatory variable is an exact linear function of the others; for two variables, this means their correlation is exactly 1 or -1.

To more intuitively understand why multicollinearity is a problem for estimating regression coefficients, this post provides the following explanation:

Two people are pushing a boulder up a hill. You want to know how hard each of them is pushing. Suppose that you watch them push together for ten minutes and the boulder moves 10 feet. Did the first guy do all the work and the second just fake it? Or vice versa? Or 50-50? Since both forces are working at the exact same time, you can’t separate the strength of either one separately. All that you can say is that their combined force is 1 foot per minute.

Now imagine that the first guy pushes for a minute himself, then nine minutes with the second guy, and a final minute is just the second guy pushing. Now you can use estimates of forces in the first and last minutes to figure out each person’s force separately. Even though they are still largely working at the same time, the fact that there is a bit of difference lets you get estimates of the force for each.

If you saw each man pushing independently for a full ten minutes, that would give you more precise estimates of the forces than if there is a large overlap in the forces.

To avoid unstable estimates of coefficients in our models, we will drop the ‘TotalCharges’ variable, as it is highly correlated with both ‘Tenure’ and ‘MonthlyCharges’.
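The variance inflation factor (VIF) is a standard way to quantify multicollinearity; it is swapped in here for illustration and may not be what the original analysis used. The sketch computes it with plain NumPy least squares on synthetic data in which `TotalCharges` is assumed to be roughly `tenure * MonthlyCharges`; `vif` is a hypothetical helper.

```python
import numpy as np
import pandas as pd


def vif(X: pd.DataFrame) -> pd.Series:
    """Variance inflation factor per column: 1 / (1 - R^2) from regressing it on the other columns."""
    out = {}
    for col in X.columns:
        y = X[col].to_numpy(dtype=float)
        # Design matrix: an intercept plus every other column
        A = np.column_stack([np.ones(len(X)), X.drop(columns=col).to_numpy(dtype=float)])
        coef, *_ = np.linalg.lstsq(A, y, rcond=None)
        r2 = 1.0 - np.var(y - A @ coef) / np.var(y)
        out[col] = 1.0 / (1.0 - r2)
    return pd.Series(out)


# Synthetic data; TotalCharges is assumed to be roughly tenure * MonthlyCharges
rng = np.random.default_rng(1)
tenure = rng.integers(1, 72, size=1000).astype(float)
monthly = rng.uniform(18, 120, size=1000)
total = tenure * monthly * rng.normal(1.0, 0.05, size=1000)
X = pd.DataFrame({"tenure": tenure, "MonthlyCharges": monthly, "TotalCharges": total})

print(X.corr().round(2))
print(vif(X).round(1))                               # large VIF for TotalCharges
print(vif(X.drop(columns="TotalCharges")).round(1))  # both close to 1
```

A commonly quoted rule of thumb treats a VIF above roughly 5 to 10 as a sign of problematic multicollinearity.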

Predictive modelling

We will consider several different models to predict customer churn. To check that we are not over-fitting to our data, we will split the 7,043 customer records into a training and test set, with the test set being 25% of the total records.
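A sketch of the split with scikit-learn's `train_test_split`, using random placeholder arrays with the same 7,043 rows. Passing `stratify=y` is an assumption, not necessarily what the post did; it keeps the churn rate equal in the training and test sets.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholder data with the same number of rows as the Telco data set
rng = np.random.default_rng(42)
X = rng.normal(size=(7043, 5))
y = (rng.random(7043) < 0.25).astype(int)

X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.25,   # hold out 25% of customers for testing
    random_state=42,  # make the split reproducible
    stratify=y,       # keep the churn rate the same in both splits
)
print(X_train.shape, X_test.shape)  # (5282, 5) (1761, 5)
```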

Logistic regression

Logistic regression is one of the more basic classification algorithms in a data scientist’s toolkit. It is used to predict a category or group based on an observation. Logistic regression is usually used for binary classification (1 or 0, win or lose, true or false). The output of logistic regression is a probability, which will always be a value between 0 and 1. While the output value does not give a classification directly, we can choose a cutoff value so that inputs with probability greater than the cutoff belong to one class, and those with less than the cutoff belong to the other.

For example, if the classifier predicts a probability of customer attrition being 70%, and our cutoff value is 50%, then we predict that the customer will churn. Similarly, if the model outputs a 30% chance of attrition for a customer, then we predict that the customer won’t churn. See here for a more detailed look into the mathematics and assumptions of logistic regression.
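A minimal sketch of applying a cutoff to logistic regression probabilities with scikit-learn, on synthetic data. `predict_with_cutoff` is a hypothetical helper; scikit-learn's own `predict` corresponds to a cutoff of 0.5.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic stand-in data: class 1 ("churn") is about 25% of samples
X, y = make_classification(n_samples=2000, n_features=8, weights=[0.75], random_state=0)
model = LogisticRegression(max_iter=1000).fit(X, y)


def predict_with_cutoff(model, X, cutoff=0.5):
    """Label a sample as churn (1) when its predicted churn probability exceeds `cutoff`."""
    # predict_proba returns one column per class; column 1 is P(class 1)
    p_churn = model.predict_proba(X)[:, 1]
    return (p_churn > cutoff).astype(int)


default_preds = predict_with_cutoff(model, X, cutoff=0.5)
cautious_preds = predict_with_cutoff(model, X, cutoff=0.3)  # flags more customers as churn risks
print(default_preds.sum(), cautious_preds.sum())
```

Lowering the cutoff trades precision for recall: more churners are caught, but more loyal customers are flagged too.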

We got 81% classification accuracy from our logistic regression classifier. But the precision and recall for predictions in the positive class (churn) are relatively low, which is a common symptom of an imbalanced data set.
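Per-class precision and recall can be read off scikit-learn's `classification_report`. This sketch runs on synthetic, roughly 75/25 imbalanced data, so its numbers will not match the 81% above.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split

# Synthetic stand-in: about 25% positives, with some label noise
X, y = make_classification(n_samples=4000, n_features=10, n_informative=4,
                           weights=[0.75], flip_y=0.05, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)

clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
y_pred = clf.predict(X_test)
print(f"accuracy: {accuracy_score(y_test, y_pred):.2f}")
# One row of precision/recall/F1 per class; class 1 plays the role of "churn"
print(classification_report(y_test, y_pred, target_names=["no churn", "churn"]))
```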

How to handle imbalanced classes

It is also important to look at the distribution of how many customers churn. If 95% of customers don’t churn, we can achieve 95% accuracy by building a model that simply predicts that all customers won’t churn. But this isn’t a very useful model, because it will never tell us when a customer will churn, which is what we are really interested in.
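scikit-learn's `DummyClassifier` makes the 95% example concrete: a baseline that always predicts the majority class scores high accuracy while never identifying a churner.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, recall_score

y = np.array([0] * 95 + [1] * 5)  # 95 customers stay, 5 churn
X = np.zeros((len(y), 1))         # features are ignored by this baseline

# Always predict the most common class, i.e. "won't churn"
baseline = DummyClassifier(strategy="most_frequent").fit(X, y)
y_pred = baseline.predict(X)
print(accuracy_score(y, y_pred))  # 0.95
print(recall_score(y, y_pred))    # 0.0, since no churner is ever flagged
```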

The class for churn is only around 25% of the total population of samples. There is a real risk that a model trained on this data will make too many predictions in favour of the majority class. There are a number of techniques for handling imbalanced classes:

Up-sampling the minority class

To balance the data set, we can randomly duplicate observations from the minority class. This is also known as resampling with replacement:
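A minimal sketch using `sklearn.utils.resample` on a toy DataFrame, assuming a 0/1 `Churn` column in which 1 is the minority class; `upsample_minority` is a hypothetical helper.

```python
import numpy as np
import pandas as pd
from sklearn.utils import resample


def upsample_minority(df: pd.DataFrame, target: str = "Churn", random_state: int = 123) -> pd.DataFrame:
    """Duplicate random minority-class rows (label 1) until both classes are the same size."""
    majority = df[df[target] == 0]
    minority = df[df[target] == 1]
    # replace=True: rows are drawn with replacement, so some appear more than once
    minority_up = resample(minority, replace=True, n_samples=len(majority),
                           random_state=random_state)
    return pd.concat([majority, minority_up])


rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "tenure": rng.integers(0, 72, size=400),
    "Churn": (rng.random(400) < 0.25).astype(int),
})
balanced = upsample_minority(toy)
print(balanced["Churn"].value_counts())
```

Resample the training split only: duplicating rows before the train/test split can put copies of the same customer in both sets and inflate test scores.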

Now that we have a 1:1 ratio for our classes, let’s train another logistic regression model:
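A sketch of the retraining step on synthetic data, comparing a model fitted on the original training split with one fitted on an upsampled copy of it; both are scored on the same untouched test split. The exact numbers depend on the data and will differ from the post's.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

# Synthetic stand-in: roughly 25% positives ("churn"), with some label noise
X, y = make_classification(n_samples=4000, n_features=10, n_informative=4,
                           weights=[0.75], flip_y=0.05, random_state=1)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=1)

# Upsample the minority class in the training split only; the test split keeps its real class mix
X_min_up, y_min_up = resample(X_train[y_train == 1], y_train[y_train == 1], replace=True,
                              n_samples=int((y_train == 0).sum()), random_state=1)
X_bal = np.vstack([X_train[y_train == 0], X_min_up])
y_bal = np.concatenate([y_train[y_train == 0], y_min_up])

results = {}
for name, (X_fit, y_fit) in {"original": (X_train, y_train), "upsampled": (X_bal, y_bal)}.items():
    y_pred = LogisticRegression(max_iter=1000).fit(X_fit, y_fit).predict(X_test)
    results[name] = (accuracy_score(y_test, y_pred),
                     precision_score(y_test, y_pred),
                     recall_score(y_test, y_pred))
    print("%-9s accuracy=%.2f precision=%.2f recall=%.2f" % (name, *results[name]))
```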

The overall accuracy of the model has decreased, but the precision and recall scores for predicting a churn have improved. There are a number of other ways to deal with imbalanced classes, including:

Down-sampling the majority class

Similar to the above method, we reduce the number of samples in the majority class to be equal to the number of samples in the minority class.
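The same `resample` utility covers down-sampling by drawing without replacement from the majority class. As before, the toy DataFrame and the `downsample_majority` helper are hypothetical.

```python
import numpy as np
import pandas as pd
from sklearn.utils import resample


def downsample_majority(df: pd.DataFrame, target: str = "Churn", random_state: int = 123) -> pd.DataFrame:
    """Randomly drop majority-class rows (label 0) until both classes are the same size."""
    majority = df[df[target] == 0]
    minority = df[df[target] == 1]
    # replace=False: keep a random subset of existing rows, with no duplicates
    majority_down = resample(majority, replace=False, n_samples=len(minority),
                             random_state=random_state)
    return pd.concat([majority_down, minority])


rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "tenure": rng.integers(0, 72, size=400),
    "Churn": (rng.random(400) < 0.25).astype(int),
})
balanced = downsample_majority(toy)
print(balanced["Churn"].value_counts())
```

Because it throws away majority-class rows, down-sampling tends to be more attractive when the data set is large.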

Using a different performance metric

Area Under ROC Curve (AUROC) is the probability that a model ranks a randomly chosen positive observation above a randomly chosen negative one. In very simple terms, AUROC gives a single measure of how a model’s true positive rate and false positive rate change with different threshold values. The closer a model’s AUROC score is to 1, the better it is; a score of 0.5 is no better than random guessing. To calculate AUROC, we need the predicted class probabilities:
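A sketch of computing AUROC from predicted probabilities with scikit-learn's `roc_auc_score`, on synthetic data. It also checks the ranking interpretation by brute force: the share of (churner, non-churner) pairs in which the churner gets the higher score, with ties counted as half.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2000, n_features=10, weights=[0.75], random_state=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=2)
clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# AUROC needs scores, not hard labels: use the predicted probability of the positive class
p_churn = clf.predict_proba(X_test)[:, 1]
auc = roc_auc_score(y_test, p_churn)

# Same number by brute force over every (positive, negative) pair; ties count as half
pos, neg = p_churn[y_test == 1], p_churn[y_test == 0]
pairwise = (pos[:, None] > neg[None, :]).mean() + 0.5 * (pos[:, None] == neg[None, :]).mean()
print(round(auc, 3), round(pairwise, 3))
```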

Interestingly, the AUROC scores are very similar between the two models. Both are above 0.5, however, which suggests that both models have some ability to distinguish between observations from each class.

Tree-based algorithms

Using tree-based algorithms such as decision trees or random forests can result in good models for imbalanced datasets. If the minority class is concentrated in one region of the feature space, a tree may be able to isolate it in a single node. For example, if 99% of customers who stream movies churn, then a tree-based algorithm will likely pick this up. We will look at the results for two of these algorithms in the next section.

Decision Trees

A decision tree is a supervised learning method that makes a prediction by learning simple decision rules from the explanatory variables. Decision trees have the following advantages:

  • Trees can be visualised, which makes them easy to interpret
  • They can handle numerical and categorical data
  • We can easily validate the model using statistical tests

The downsides to decision trees:

  • Decision trees are very prone to overfitting the training data, and often do not generalise well
  • Small variations in the training data can cause a completely different tree to be generated
  • Decision tree learning algorithms are based on heuristic algorithms like the greedy algorithm, which make locally optimal decisions at each node. These algorithms cannot guarantee a globally optimal decision tree

Despite their downsides, decision trees can be a good starting point for developing predictive models that generalise better, like random forests.
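A minimal sketch of fitting a depth-limited tree with scikit-learn and printing its rules as text with `export_text`, on synthetic data with generic feature names. `sklearn.tree.plot_tree` draws the same tree graphically if matplotlib is installed.

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier, export_text

# Synthetic stand-in for the encoded churn features (class 1 = churn, about 25%)
X, y = make_classification(n_samples=1000, n_features=4, n_informative=3,
                           n_redundant=0, weights=[0.75], random_state=3)

# max_depth=4 matches the depth used for the tree in this post
tree = DecisionTreeClassifier(max_depth=4, random_state=3).fit(X, y)

# A plain-text view of the learned rules, one line per split or leaf
print(export_text(tree, feature_names=["x0", "x1", "x2", "x3"]))
```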

Figure: the decision tree with a maximum depth of 4, showing the details in each node

Trimming the tree

We have set the maximum depth of the tree to 4 in the above example. The other parameter controlling the size of the tree is ‘min_samples_leaf’, which specifies the minimum number of samples required at each leaf node. By default, scikit-learn places no limit on depth (max_depth=None) and allows leaves containing a single sample (min_samples_leaf=1), which leads to fully grown and unpruned trees, like this:

The fully unpruned tree was actually too large to render on a single PDF page

An unpruned tree is effectively trying to sort every training example ‘perfectly’ into its own leaf. We will get very good ‘accuracy’ when testing against the training set, but it is likely that the model is overfitted. Let’s see what kind of accuracy each of the trees gets on the test and training sets:
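A sketch of the comparison on synthetic, noisy data. The pruning settings (`max_depth=4`, plus `min_samples_leaf=20`, which is an illustrative assumption) are not necessarily those used in the post.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Synthetic, noisy, imbalanced stand-in for the churn data (flip_y adds label noise)
X, y = make_classification(n_samples=3000, n_features=12, n_informative=4,
                           weights=[0.75], flip_y=0.1, random_state=4)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=4)

trees = {
    # Defaults: max_depth=None, min_samples_leaf=1, so the tree grows until leaves are pure
    "unpruned": DecisionTreeClassifier(random_state=4),
    # Illustrative limits; min_samples_leaf=20 is an assumed value
    "pruned": DecisionTreeClassifier(max_depth=4, min_samples_leaf=20, random_state=4),
}
scores = {}
for name, tree in trees.items():
    tree.fit(X_train, y_train)
    scores[name] = (tree.score(X_train, y_train), tree.score(X_test, y_test))
    print(f"{name}: train accuracy={scores[name][0]:.2f}, test accuracy={scores[name][1]:.2f}")
```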

Exactly as we suspected! The unpruned tree gets a perfect score on the training set, but a relatively lower score (73%) on the test set. Our pruned tree is less accurate on the training set, but performs better when presented with the out-of-sample test data.

Random forests

Random forests are an ensemble learning method, where the results from multiple decision trees are combined to make a final prediction. For example, a random forest may be made up of 10 decision trees, 7 of which make a prediction for ‘churn’ and 3 of which make a prediction for ‘no churn’. The final prediction for the forest will be ‘churn’.
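A sketch that makes the voting visible by querying each tree in a fitted scikit-learn forest (`forest.estimators_`), on synthetic data. Note that scikit-learn's `RandomForestClassifier` combines its trees by averaging their predicted class probabilities rather than by a strict majority vote; the code checks that the forest's probability equals that average.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

X, y = make_classification(n_samples=1000, n_features=8, weights=[0.75], random_state=5)
forest = RandomForestClassifier(n_estimators=10, random_state=5).fit(X, y)

x_new = X[:1]  # one (synthetic) customer
# Each fitted tree makes its own prediction (0.0 or 1.0 here)
votes = np.array([t.predict(x_new)[0] for t in forest.estimators_])
print(f"{int(votes.sum())} of {len(votes)} trees predict 'churn'")

# The forest's probability is the mean of the trees' probabilities
avg_proba = np.mean([t.predict_proba(x_new)[0, 1] for t in forest.estimators_])
print("forest P(churn):", forest.predict_proba(x_new)[0, 1], "average over trees:", avg_proba)
print("forest prediction:", forest.predict(x_new)[0])
```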

Tree ensembles have become very popular due to their impressive performance on many real-world problems.

Looks like we get similar performance to our pruned decision tree with a random forest. The next step would be to run several more rounds of cross-validation using different training and testing sets to measure the performance of each of the models. We would then average the results from all rounds of cross-validation to estimate the accuracy of each machine learning model.
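A sketch of that next step using scikit-learn's `cross_val_score` with stratified 5-fold cross-validation on synthetic data. The model settings are illustrative, and `scoring="roc_auc"` could be swapped in to compare AUROC instead of accuracy.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=2000, n_features=10, weights=[0.75], random_state=6)

# 5 folds; stratification keeps the churn rate similar in every fold
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=6)
models = {
    "logistic regression": LogisticRegression(max_iter=1000),
    "pruned decision tree": DecisionTreeClassifier(max_depth=4, random_state=6),
    "random forest": RandomForestClassifier(n_estimators=100, random_state=6),
}
mean_accuracy = {}
for name, model in models.items():
    scores = cross_val_score(model, X, y, cv=cv, scoring="accuracy")
    mean_accuracy[name] = scores.mean()
    print(f"{name}: mean accuracy {scores.mean():.3f} (std {scores.std():.3f})")
```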


Predicting customer churn with machine learning presents many interesting challenges. Building the best predictive model means having a good understanding of the underlying data. Different models can be implemented and tested relatively quickly using the Python sklearn package.