Multiclass & Multilabel Classification with XGBoost
XGBoost is already very well known for its performance in various Kaggle competitions, where it often competes head-to-head with deep learning algorithms in terms of accuracy and scores.
Although XGBoost is a popular solution for many machine learning problems, it can be less than trivial to use its booster for multiclass or multilabel classification, since this is not directly implemented in the Python API's XGBClassifier. With that in mind, I'll walk through some case studies in this article.
To use XGBoost's main module for a multiclass classification problem, you need to set two parameters: objective and num_class. Let's see it in practice with the wine dataset.
We'll start by reading the data from the Scikit-Learn dataset API, then add column names to get a proper-looking dataset.
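The loading step could look like the following sketch (the variable names here are my own, not necessarily the article's originals):

```python
# Hypothetical snippet: load the wine dataset from scikit-learn
# and build a DataFrame with proper column names.
import pandas as pd
from sklearn.datasets import load_wine

wine = load_wine()
X = pd.DataFrame(wine.data, columns=wine.feature_names)
y = pd.Series(wine.target, name="target")

print(X.shape)  # (178, 13)
```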
Next, split the data into train and test sets and build our DMatrices.
If you're not familiar with DMatrix, it is the data interface commonly used as input to XGBoost models; it works well with Pandas dataframes, NumPy arrays, SciPy sparse matrices, CSV files, etc.
Time to set our XGBoost parameters to perform multiclass predictions!
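A minimal parameter dictionary could look like this; objective and num_class follow the article, while the max_depth value is just an illustrative choice:

```python
# Parameters for multiclass training with the softmax objective.
params = {
    "max_depth": 3,               # illustrative value
    "objective": "multi:softmax", # multiclass with softmax
    "num_class": 3,               # the wine dataset has 3 classes
}
```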
The parameters above mean (from the docs):
max_depth: Maximum depth of a tree. Increasing this value will make the model more complex and more likely to overfit. 0 is only accepted in the lossguide growing policy, when tree_method is set to hist, and it indicates no limit on depth. Beware that XGBoost aggressively consumes memory when training a deep tree.
objective: multi:softmax sets XGBoost to do multiclass classification using the softmax objective; you also need to set num_class (the number of classes).
num_class: not covered in much depth in XGBoost's docs, but it is simply the number of classes you want to predict (in our case, 3).
Now it’s time to train our model and see how it goes.
The result achieved through this step-by-step is:
              precision    recall  f1-score   support

           0       0.94      1.00      0.97        15
           1       1.00      0.94      0.97        18
           2       1.00      1.00      1.00        12

   micro avg       0.98      0.98      0.98        45
   macro avg       0.98      0.98      0.98        45
weighted avg       0.98      0.98      0.98        45
And we can also compute the confusion matrix for those predictions.
The notebook with all the source code presented above, along with another multiclass example using the Anuran Calls (MFCCs) Data Set, is available in my GitHub repo.
Great! Now we have an XGBoost classifier able to predict multiple classes. There is yet another way to accomplish this, using a different technique: the One-vs-Rest approach, which builds a multiclass classifier by combining a bunch of binary classifiers. Let's look at it in more depth.
OneVsRest
From the SkLearn docs:
One-vs-the-rest (OvR) multiclass/multilabel strategy
Also known as one-vs-all, this strategy consists in fitting one classifier per class. For each classifier, the class is fitted against all the other classes. In addition to its computational efficiency (only n_classes classifiers are needed), one advantage of this approach is its interpretability. Since each class is represented by one and only one classifier, it is possible to gain knowledge about the class by inspecting its corresponding classifier. This is the most commonly used strategy for multiclass classification and is a fair default choice.
This strategy can also be used for multilabel learning, where a classifier is used to predict multiple labels for instance, by fitting on a 2-d matrix in which cell [i, j] is 1 if sample i has label j and 0 otherwise.
In the multilabel learning literature, OvR is also known as the binary relevance method.
Great, so basically we're going to fit one classifier per class. Let's see this different approach. I'll use pseudo-code to illustrate how one could use XGBoost for multilabel classification; the details may vary a little from application to application, but the main idea is as follows.
This way you'll get multilabel predictions from your classifier. Note that MultiLabelBinarizer is very important for making the classifier correctly predict between classes: if the label transformation is not properly applied, the model may not be fitted correctly and can end up predicting the same label over and over, which commonly happens when a parameter is wrong.
Now, to validate your model, you should look at some different metrics. The confusion matrix is not going to tell you a lot for multilabel datasets, but metrics like Zero-One Loss and Hamming Loss could come in handy.
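For instance, on small made-up indicator matrices (assumed here purely to show the metric calls), the two losses behave like this:

```python
# Multilabel validation metrics on binary indicator matrices,
# as produced by MultiLabelBinarizer.
import numpy as np
from sklearn.metrics import hamming_loss, zero_one_loss

Y_true = np.array([[1, 0, 1], [0, 1, 0], [1, 1, 0]])
Y_pred = np.array([[1, 0, 0], [0, 1, 0], [1, 1, 0]])

# Hamming loss: fraction of individual label cells predicted wrongly
# (here 1 wrong cell out of 9).
print(hamming_loss(Y_true, Y_pred))   # ~0.111

# Zero-one loss: fraction of samples with at least one wrong label
# (here 1 imperfect sample out of 3).
print(zero_one_loss(Y_true, Y_pred))  # ~0.333
```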