Multiclass classification in machine learning
This article was originally published on Algorithmia’s website. The company was acquired by DataRobot in 2021. This article may not be entirely up to date, or may refer to products and offerings that no longer exist. Find out more about DataRobot MLOps here.
Outside of regression, multi-class classification is probably the most common machine learning task. Learn how to properly use it in building machine learning models.
What is multiclass classification?
Whether it’s spelled multi-class or multiclass, the science is the same. Multiclass classification is a machine learning classification task with more than two classes, or possible outputs. For example, using a model to identify animal types in images from an encyclopedia is a multiclass classification task because each image can be assigned to one of many different animal classes. Multiclass classification also requires that each sample belong to exactly one class (i.e., an elephant is only an elephant; it is not also a lemur).
In classification, we are presented with a number of training examples divided into K separate classes, and we build a machine learning model to predict which of those classes some previously unseen data belongs to (such as the animal types from the previous example). From the training dataset, the model learns patterns specific to each class and uses those patterns to predict the class membership of future data.
For instance, images of cats may all follow a pattern of pointed ears and whiskers, helping the model to identify future images of cats as compared to other animals without whiskers or pointed ears.
Multiclass classification use cases
Classification shows up in many practical settings. For example, a cybersecurity company might want to monitor a user’s email inbox and classify incoming emails as either potential phishing attempts or not. To do so, it might train a classification model on the text, sender addresses, and URLs of previous phishing emails to teach the model which patterns tend to accompany threatening messages.
As another example, a marketing company might serve an online ad and want to predict whether a given customer will click on it. Both of these are binary classification problems, the two-class special case. They become multiclass problems as soon as there are more than two categories, for instance sorting emails into phishing, spam, and legitimate.
Multiclass classification models are just one of many examples of the importance of machine learning.
How classification machine learning works
Hundreds of models exist for classification. In fact, it’s often possible to take a model that works for regression and make it into a classification model. This is basically how logistic regression works. We model a linear response WX + b to an input and turn it into a probability value between 0 and 1 by feeding that response into a sigmoid function. By convention, that value is read as the probability of class 1, so we predict that an input belongs to class 1 if the model outputs a probability greater than 0.5 and to class 0 otherwise.
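The logistic regression decision rule can be sketched in a few lines of Python. This is a minimal illustration, not a trained model: the weights `w` and bias `b` are hypothetical values chosen by hand, and the sketch assumes the usual convention that the sigmoid output is the probability of class 1.

```python
import math

def sigmoid(z):
    """Map a real-valued response to a value in (0, 1), avoiding overflow."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)

def predict_proba(w, b, x):
    """Probability of class 1: sigmoid of the linear response w.x + b."""
    z = sum(wi * xi for wi, xi in zip(w, x)) + b
    return sigmoid(z)

def predict(w, b, x, threshold=0.5):
    """Class 1 if the probability exceeds the threshold, else class 0."""
    return 1 if predict_proba(w, b, x) > threshold else 0

# Hypothetical parameters; in practice they are learned from training data.
w, b = [2.0, -1.0], -0.5
print(predict_proba(w, b, [1.0, 0.5]), predict(w, b, [1.0, 0.5]))  # response is 1.0
print(predict_proba(w, b, [0.0, 1.0]), predict(w, b, [0.0, 1.0]))  # response is -1.5
```

The first input has a positive linear response, so its probability is above 0.5 and it is assigned to class 1; the second has a negative response and is assigned to class 0.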
Can SVM do multiclass classification?
Another common model for classification is the support vector machine (SVM). An SVM works by finding the hyperplane that separates two classes with the largest margin, optionally after mapping the data into a higher-dimensional space with a kernel. A single SVM does binary classification and can differentiate between two classes. To differentiate between K classes, one can train K SVMs in a one-vs-rest scheme, where each one predicts membership in one of the K classes against all the others and the class with the highest score wins. An alternative is one-vs-one, which trains one SVM for every pair of classes, K(K - 1)/2 in total.
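One common way to extend binary SVMs to K classes is one-vs-rest, which trains one binary classifier per class and picks the class whose classifier returns the highest score. The sketch below shows only the prediction step and the classifier counts. It does not train anything: the hyperplanes in `scorers` are hypothetical hand-picked values standing in for three already-trained linear SVMs, and all function names are illustrative.

```python
def one_vs_rest_count(k):
    """One binary classifier per class."""
    return k

def one_vs_one_count(k):
    """One binary classifier per unordered pair of classes."""
    return k * (k - 1) // 2

def decision_value(w, b, x):
    """Signed score of x relative to the hyperplane w.x + b = 0."""
    return sum(wi * xi for wi, xi in zip(w, x)) + b

def predict_one_vs_rest(scorers, x):
    """Return the index of the class whose classifier scores x highest."""
    scores = [decision_value(w, b, x) for w, b in scorers]
    return max(range(len(scores)), key=scores.__getitem__)

# Hypothetical hyperplanes for three classes clustered near
# (0, 0), (5, 0), and (0, 5); a real system would learn these.
scorers = [
    ((-1.0, -1.0), 2.5),   # class 0 vs rest
    ((1.0, 0.0), -2.5),    # class 1 vs rest
    ((0.0, 1.0), -2.5),    # class 2 vs rest
]

for point in [(0.5, 0.5), (5.0, 0.5), (0.2, 4.8)]:
    print(point, predict_one_vs_rest(scorers, point))
```

For K = 3 both schemes happen to need three classifiers, but the counts diverge quickly: with K = 10, one-vs-rest needs 10 classifiers and one-vs-one needs 45.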
Which model is used for multiclass classification algorithms?
Within the realm of natural language processing and multiclass text classification, the Naive Bayes model is quite popular. Its popularity arises in large part from how simple it is and how quickly it trains. In the Naive Bayes classifier, we use Bayes’ Theorem to express the probability of membership in a class, given the input features, in terms of the prior probability of the class and the probability of the features given that class.
The model makes the naive assumption (hence Naive Bayes) that all the input features to the model are independent of one another given the class. While this is rarely true, it’s often a good enough approximation to get the results we want. The probability of class membership then breaks down into a product of the class prior and one conditional probability per feature, and we just classify an input X as class k if k maximizes this product. To learn more, check out What is Bayesian machine learning?
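As a rough illustration of that product rule, here is a small multinomial Naive Bayes text classifier in plain Python. Several choices are assumptions made for the sketch rather than anything the article prescribes: the toy documents and labels are invented, Laplace smoothing (`alpha`) is added so unseen words do not zero out the product, words outside the training vocabulary are ignored, and log-probabilities are summed instead of multiplying raw probabilities to avoid numerical underflow.

```python
import math
from collections import Counter, defaultdict

def train_naive_bayes(docs):
    """Count documents per class and words per class from (label, text) pairs."""
    class_doc_counts = Counter()
    word_counts = defaultdict(Counter)
    vocab = set()
    for label, text in docs:
        class_doc_counts[label] += 1
        for word in text.split():
            word_counts[label][word] += 1
            vocab.add(word)
    return class_doc_counts, word_counts, vocab

def predict_naive_bayes(model, text, alpha=1.0):
    """Return the class maximizing log prior + sum of log word likelihoods."""
    class_doc_counts, word_counts, vocab = model
    total_docs = sum(class_doc_counts.values())
    words = [w for w in text.split() if w in vocab]  # skip unknown words
    best_label, best_score = None, -math.inf
    for label in sorted(class_doc_counts):
        score = math.log(class_doc_counts[label] / total_docs)
        total_words = sum(word_counts[label].values())
        for w in words:
            # Laplace-smoothed estimate of P(word | class)
            p = (word_counts[label][w] + alpha) / (total_words + alpha * len(vocab))
            score += math.log(p)
        if score > best_score:
            best_label, best_score = label, score
    return best_label

# Invented toy corpus with three classes.
docs = [
    ("sports", "goal match team"),
    ("sports", "team win match"),
    ("tech", "code python bug"),
    ("tech", "python code release"),
    ("food", "recipe pasta sauce"),
    ("food", "pasta recipe dinner"),
]
model = train_naive_bayes(docs)
print(predict_naive_bayes(model, "python bug"))
print(predict_naive_bayes(model, "pasta dinner"))
```

Because the three classes have equal priors here, the prediction is driven entirely by which class saw the query words most often during training.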
Deep learning multiclass classification examples
There also exist plenty of deep learning models for classification. Almost any neural network can be made into a classifier by simply tacking a softmax function onto the last layer. The softmax function creates a probability distribution over K classes, and produces an output vector of length K. Each element of the vector is the probability that the input belongs to the corresponding class. The most likely class is chosen by selecting the index of that vector having the highest probability.
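The softmax step itself is short enough to write out. The sketch below takes a vector of raw last-layer outputs (often called logits; the values used here are made up), converts it to a probability distribution over K classes, and picks the most likely class. Subtracting the maximum logit before exponentiating is a standard numerical-stability trick and does not change the result.

```python
import math

def softmax(logits):
    """Turn K raw scores into K probabilities that sum to 1."""
    m = max(logits)  # shift for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict_class(logits):
    """Index of the class with the highest softmax probability."""
    probs = softmax(logits)
    return max(range(len(probs)), key=probs.__getitem__)

# Made-up outputs from the last layer of a 3-class network.
logits = [2.0, 1.0, 0.1]
probs = softmax(logits)
print(probs, sum(probs))
print(predict_class(logits))
```

Since softmax preserves the ordering of its inputs, the predicted class is always the one with the largest logit; the probabilities matter when a confidence estimate or a training loss is needed.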
While many neural network architectures can be used, some work better than others. Convolutional Neural Networks (CNNs) typically fare very well on multiclass classification tasks, especially for images and text. A CNN extracts useful features from data, particularly ones that are robust to translation and, to a lesser degree and depending on the training data, to scaling and rotation. This helps it recognize objects that may be rotated, shrunken, or off-center, allowing it to achieve higher accuracy in multiclass image classification tasks.
While nearly all typical classification models are supervised, you can think of unsupervised classification as a clustering problem. In this setting, we want to assign each data point to one of K groups without having labeled examples ahead of time. Classic clustering algorithms such as k-means, k-medoids, or hierarchical clustering are commonly used for this task.
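To make the clustering view concrete, here is a bare-bones version of k-means (Lloyd's algorithm) in plain Python. It is a sketch under simplifying assumptions: the data points are invented, and the initial centroids are passed in explicitly to keep the example deterministic, whereas practical implementations usually pick them randomly or with a scheme such as k-means++.

```python
def sq_dist(p, q):
    """Squared Euclidean distance between two points."""
    return sum((a - b) ** 2 for a, b in zip(p, q))

def kmeans(points, centroids, max_iters=100):
    """Alternate between assigning points and updating centroids."""
    centroids = [tuple(c) for c in centroids]
    labels = []
    for _ in range(max_iters):
        # Assignment step: each point joins its nearest centroid.
        new_labels = [
            min(range(len(centroids)), key=lambda j: sq_dist(p, centroids[j]))
            for p in points
        ]
        if new_labels == labels:
            break  # assignments stopped changing, so we have converged
        labels = new_labels
        # Update step: move each centroid to the mean of its members.
        for j in range(len(centroids)):
            members = [p for p, l in zip(points, labels) if l == j]
            if members:
                centroids[j] = tuple(sum(c) / len(members) for c in zip(*members))
    return labels, centroids

# Invented 2-D data forming two well-separated groups.
points = [(1.0, 1.0), (1.5, 2.0), (1.0, 0.5), (8.0, 8.0), (9.0, 9.5), (8.5, 9.0)]
labels, centroids = kmeans(points, centroids=[points[0], points[3]])
print(labels)
print(centroids)
```

Note that the group indices carry no meaning of their own: unlike supervised class labels, cluster 0 and cluster 1 are only names for "points that ended up together".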