What is text classification?

Text classification is a common task in natural language processing (NLP) that involves assigning a label to a piece of text based on its content. For example, you might want to classify news articles into different categories, such as politics, sports, or entertainment. Text classification can also be used for sentiment analysis, spam detection, topic labeling, and more.

In a previous post I showed you how to classify texts using TextBlob's Naive Bayes classifier and how easy and accurate it is. In this post, I will show you how to use scikit-learn, a popular Python library for machine learning, to perform text classification with four different algorithms: support vector machines (SVM), random forest, multinomial Naive Bayes (MNB), and K-nearest neighbors (KNN). I will use the 20 Newsgroups dataset, which (as distributed with scikit-learn) contains about 18,000 newsgroup posts on 20 different topics; to keep things quick, the code below uses only four of those topics. I will also compare the performance of the four algorithms and discuss their advantages and disadvantages.

The best text classification algorithm

Before going into the code, let's see which algorithm to choose for text classification. This can be challenging, as different algorithms may perform differently depending on the characteristics of the dataset, the size of the data, the complexity of the classification task, and other factors.

Here are some general guidelines that can help you choose the best algorithm for your text classification task:

  1. Start with a simple baseline model: It is always a good idea to start with a simple baseline model, such as a Naive Bayes classifier or a linear SVM classifier, to establish a baseline performance level. This can help you determine whether more complex models are needed.
  2. Experiment with different algorithms: Try different algorithms, such as SVM, Naive Bayes, random forest, or K-nearest neighbors, to see which ones perform better on your data.
  3. Consider the size and complexity of the data: Some algorithms may perform better on small datasets with simple features, while others may perform better on large datasets with complex features. For example, linear SVM and Naive Bayes classifiers are known to work well on small and medium-sized datasets with simple features (as the results in this post also suggest), while neural networks are often used for large datasets with complex features.
  4. Look at the performance metrics: Evaluate the performance of the models using appropriate metrics such as accuracy, precision, recall, and F1 score. Choose the algorithm that gives the best performance on the metrics that are most important for your task.
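To see how these metrics differ from plain accuracy, here is a small pure-Python sketch that computes them from lists of true and predicted labels. The labels are invented for illustration; in practice scikit-learn's classification_report gives the same per-class precision, recall and F1, so this is only meant to show what the numbers mean.

```python
def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def per_class_scores(y_true, y_pred):
    """Return {label: (precision, recall, f1)} computed from two label lists."""
    labels = sorted(set(y_true) | set(y_pred))
    scores = {}
    for label in labels:
        pairs = list(zip(y_true, y_pred))
        # True positives: predicted this label and it was correct
        tp = sum(1 for t, p in pairs if t == label and p == label)
        # False positives: predicted this label but the true label differs
        fp = sum(1 for t, p in pairs if t != label and p == label)
        # False negatives: the true label is this one but something else was predicted
        fn = sum(1 for t, p in pairs if t == label and p != label)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores[label] = (precision, recall, f1)
    return scores


# Hypothetical predictions for a two-class problem (0 = sports, 1 = politics)
y_true = [0, 0, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 1, 0]
scores = per_class_scores(y_true, y_pred)
print(f"accuracy: {accuracy(y_true, y_pred):.2f}")  # accuracy: 0.67
for label, (p, r, f1) in scores.items():
    print(f"class {label}: precision={p:.2f} recall={r:.2f} f1={f1:.2f}")
# class 0: precision=0.75 recall=0.75 f1=0.75
# class 1: precision=0.50 recall=0.50 f1=0.50
macro_f1 = sum(f1 for _, _, f1 in scores.values()) / len(scores)
print(f"macro F1: {macro_f1:.3f}")  # macro F1: 0.625
```

A single accuracy figure (0.67 here) hides the fact that the model is noticeably worse on class 1, which the per-class F1 scores make visible.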

In general, it is good practice to try several algorithms and compare their performance before choosing one for your text classification task. Here I have chosen SVM, Naive Bayes, random forest, and K-nearest neighbors to see which works best for my dataset.

SVM, Naive Bayes, random forest, and K-nearest neighbors to classify text

Before moving on to the code, let’s see how these different algorithms work:

Support Vector Machines

A support vector machine (SVM) looks for the hyperplane that separates the data points of different classes with the widest possible margin. The LinearSVC model used below is a linear SVM, so we can think of it as trying to find the best line or plane between the different classes in a dataset. For example, if we have a dataset of images of dogs and cats, SVM tries to find the line that separates the dog images from the cat images while staying as far as possible from the closest examples of each. Once the line is found, SVM predicts the class of a new image by checking which side of the line it falls on.
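To make "best line" a bit more concrete, here is a minimal sketch of a linear soft-margin SVM trained with plain full-batch subgradient descent on the hinge loss plus an L2 penalty. The two made-up 2-D clusters stand in for "cat" and "dog" features, and the learning rate, penalty and epoch count are arbitrary assumptions. This is only an illustration of the idea: scikit-learn's LinearSVC solves a similar problem with the liblinear library and, by default, a squared hinge loss.

```python
def train_linear_svm(X, y, lam=0.01, lr=0.1, epochs=100):
    """Full-batch subgradient descent on the soft-margin SVM objective
    lam/2 * ||w||^2 + mean(max(0, 1 - y_i * (w . x_i + b))), labels y_i in {-1, +1}."""
    n, d = len(X), len(X[0])
    w, b = [0.0] * d, 0.0
    for _ in range(epochs):
        grad_w = [lam * wj for wj in w]  # gradient of the L2 penalty
        grad_b = 0.0
        for xi, yi in zip(X, y):
            margin = yi * (sum(wj * xj for wj, xj in zip(w, xi)) + b)
            if margin < 1:  # point is inside the margin or on the wrong side
                for j in range(d):
                    grad_w[j] -= yi * xi[j] / n
                grad_b -= yi / n
        w = [wj - lr * gj for wj, gj in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b


def predict(w, b, x):
    # The sign of the decision function tells us which side of the hyperplane x is on
    return 1 if sum(wj * xj for wj, xj in zip(w, x)) + b >= 0 else -1


# Hypothetical 2-D features: -1 = cat, +1 = dog
X = [(-1, -2), (-2, -1), (-2, -2), (1, 2), (2, 1), (2, 2)]
y = [-1, -1, -1, 1, 1, 1]
w, b = train_linear_svm(X, y)
print(predict(w, b, (3, 3)), predict(w, b, (-3, -2)))  # 1 -1
```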

Multinomial Naive Bayes

Multinomial Naive Bayes (MNB) is a probabilistic classifier that applies Bayes' theorem and assumes that the features are independent of each other given the class. As an analogy, Naive Bayes looks at how often certain characteristics (such as size and color) appear in each category (say, apples and oranges), and then uses this information to guess which category a new fruit belongs to. In text classification, the characteristics are the words in a document and how often they occur.
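Here is a minimal pure-Python sketch of that idea for text: count how often each word appears in each class, turn the counts into smoothed probabilities, and pick the class with the highest log-probability. The four training sentences are made up, and the sketch works on raw word counts, whereas the post's code feeds TF-IDF weights into scikit-learn's MultinomialNB. The add-one (Laplace) smoothing matches MultinomialNB's default alpha=1.0.

```python
import math
from collections import Counter, defaultdict


def train_mnb(docs, labels, alpha=1.0):
    """Fit multinomial Naive Bayes on tokenized docs with additive (Laplace) smoothing."""
    vocab = {tok for doc in docs for tok in doc}
    class_counts = Counter(labels)
    word_counts = defaultdict(Counter)  # class -> how often each token occurs in it
    for doc, label in zip(docs, labels):
        word_counts[label].update(doc)
    log_prior = {c: math.log(n / len(labels)) for c, n in class_counts.items()}
    log_likelihood = {}
    for c in class_counts:
        total = sum(word_counts[c].values())
        log_likelihood[c] = {
            tok: math.log((word_counts[c][tok] + alpha) / (total + alpha * len(vocab)))
            for tok in vocab
        }
    return log_prior, log_likelihood


def predict_mnb(model, doc):
    log_prior, log_likelihood = model
    scores = {}
    for c in log_prior:
        # log P(c) + sum of log P(token | c); tokens never seen in training are skipped
        scores[c] = log_prior[c] + sum(
            log_likelihood[c][tok] for tok in doc if tok in log_likelihood[c]
        )
    return max(scores, key=scores.get)


docs = [
    "the pitcher threw a fast ball".split(),
    "home run in the ninth inning".split(),
    "the doctor prescribed a new drug".split(),
    "patients with the flu saw a doctor".split(),
]
labels = ["baseball", "baseball", "med", "med"]
model = train_mnb(docs, labels)
print(predict_mnb(model, "the pitcher hit a home run".split()))  # baseball
print(predict_mnb(model, "the doctor treated the flu".split()))  # med
```

Working in log space avoids multiplying many small probabilities together, which would underflow on long documents.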

Random forest

Random forest is an ensemble method that builds multiple decision trees and combines their predictions. It makes the trees different from each other by training each one on a random sample of the data (drawn with replacement) and, in scikit-learn's implementation, by considering only a random subset of the characteristics (such as size and color) at each split. Each tree makes a prediction about which category the data belongs to, and the algorithm then combines these predictions by majority vote. This final prediction is often more accurate than the prediction of any single decision tree, because the mistakes of individual trees tend to cancel out when many different trees vote.
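The sketch below shows the two sources of randomness (bootstrap samples and random feature subsets) and the final majority vote on a deliberately tiny, made-up dataset. To keep it short, each "tree" is a decision stump, a tree with a single split on whether a word is present; that is a simplifying assumption for illustration, since scikit-learn's RandomForestClassifier grows full trees by default. For a one-split tree, drawing a random feature subset per tree or per split amounts to the same thing.

```python
import math
import random
from collections import Counter


def majority(labels):
    # Most common label (Counter.most_common breaks ties by first occurrence)
    return Counter(labels).most_common(1)[0][0]


def fit_stump(docs, labels, features):
    """Pick the word whose presence/absence best splits the labels (fewest errors)."""
    fallback = majority(labels)
    best = None
    for f in features:
        present = [y for d, y in zip(docs, labels) if f in d]
        absent = [y for d, y in zip(docs, labels) if f not in d]
        pred_present = majority(present) if present else fallback
        pred_absent = majority(absent) if absent else fallback
        errors = sum(y != pred_present for y in present) + sum(y != pred_absent for y in absent)
        if best is None or errors < best[0]:
            best = (errors, f, pred_present, pred_absent)
    _, f, pred_present, pred_absent = best
    return f, pred_present, pred_absent


def fit_forest(docs, labels, n_trees=25, seed=0):
    rng = random.Random(seed)
    vocab = sorted({w for d in docs for w in d})
    k = max(1, int(math.sqrt(len(vocab))))  # number of candidate words per tree
    forest = []
    for _ in range(n_trees):
        # Bootstrap: draw len(docs) documents with replacement
        idx = [rng.randrange(len(docs)) for _ in docs]
        feats = rng.sample(vocab, k)
        forest.append(fit_stump([docs[i] for i in idx], [labels[i] for i in idx], feats))
    return forest


def predict_forest(forest, doc):
    # Each stump votes; the forest returns the majority label
    votes = [p if f in doc else a for f, p, a in forest]
    return majority(votes)


# Toy data: each topic has its own small vocabulary (documents as sets of words)
docs = [{"pitcher", "inning", "bat"}] * 3 + [{"doctor", "drug", "flu"}] * 3
labels = ["baseball"] * 3 + ["med"] * 3
forest = fit_forest(docs, labels)
print(predict_forest(forest, {"pitcher", "inning", "bat", "tonight"}))  # baseball
print(predict_forest(forest, {"doctor", "drug", "flu"}))  # med
```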

K-nearest neighbor or KNN

The KNN algorithm classifies a new data point by finding the K closest data points in the training set and assigning the class that wins the majority vote among these K nearest neighbors (for regression, it uses the average of their values instead).
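A minimal pure-Python sketch of the majority-vote rule, using Euclidean distance on made-up 2-D points (imagine the two coordinates as the weights of the words "pitcher" and "doctor" in each document). Euclidean distance is also what scikit-learn's KNeighborsClassifier uses by default, and since TfidfVectorizer L2-normalizes each row by default, ranking TF-IDF vectors by Euclidean distance gives the same order as ranking them by cosine similarity.

```python
import math
from collections import Counter


def knn_predict(train_X, train_y, x, k=3):
    """Classify x by majority vote among its k nearest training points (Euclidean distance)."""
    # Sort (distance, label) pairs so the nearest training points come first
    dists = sorted((math.dist(xi, x), yi) for xi, yi in zip(train_X, train_y))
    k_labels = [yi for _, yi in dists[:k]]
    return Counter(k_labels).most_common(1)[0][0]


# Hypothetical (pitcher weight, doctor weight) pairs for six documents
train_X = [(0.9, 0.1), (0.8, 0.2), (0.7, 0.1), (0.1, 0.9), (0.2, 0.8), (0.1, 0.7)]
train_y = ["baseball"] * 3 + ["med"] * 3
print(knn_predict(train_X, train_y, (0.75, 0.15)))  # baseball
print(knn_predict(train_X, train_y, (0.2, 0.75)))  # med
```

Note that KNN has no real training step: all the work happens at prediction time, when every training point is compared with the new one. With an even K, two classes can tie in the vote, which is one reason odd values are a common choice for two-class problems.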

Now, let’s code!

First let’s install the required libraries:

pip install scikit-learn

Before we start, we need to import some libraries and load the dataset. We will use the fetch_20newsgroups function from scikit-learn to download the posts from four categories (subset='all' returns them in a single shuffled list), and then split them ourselves: the first 3,000 posts for training and the rest for testing. We will also use the TfidfVectorizer class to transform the raw text into numerical features that represent the term frequency-inverse document frequency (TF-IDF) of each word in each document, dropping common English stop words and keeping only the 1,000 most frequent terms. TF-IDF is a common way to measure how important a word is in a document relative to the whole corpus.
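To show what TfidfVectorizer actually computes, here is a from-scratch sketch that mirrors its default formula: raw term counts, a smoothed idf = ln((1 + n) / (1 + df)) + 1, and L2-normalized rows. I'm assuming simple lowercase whitespace tokenization here; the real vectorizer uses a regular-expression tokenizer and, in the post's code, also removes stop words and caps the vocabulary at 1,000 terms.

```python
import math
from collections import Counter


def tfidf(docs):
    """TF-IDF weights per document, mirroring scikit-learn's default formula:
    raw counts, smoothed idf = ln((1 + n) / (1 + df)) + 1, then L2 normalization."""
    n = len(docs)
    tokenized = [doc.lower().split() for doc in docs]  # simplified tokenization
    df = Counter(tok for toks in tokenized for tok in set(toks))  # document frequency
    idf = {tok: math.log((1 + n) / (1 + d)) + 1 for tok, d in df.items()}
    vectors = []
    for toks in tokenized:
        weights = {tok: count * idf[tok] for tok, count in Counter(toks).items()}
        norm = math.sqrt(sum(v * v for v in weights.values()))
        vectors.append({tok: v / norm for tok, v in weights.items()})
    return vectors


vectors = tfidf(["the pitcher threw the ball", "the doctor saw the patient"])
print(round(vectors[0]["the"], 3), round(vectors[0]["pitcher"], 3))  # 0.635 0.446
```

Although "the" appears in every document and therefore gets the smallest idf (1.0), it still ends up with the largest weight in the first document because it occurs twice. This is why removing stop words with stop_words='english' helps.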

# Import libraries

from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.svm import LinearSVC
from sklearn.naive_bayes import MultinomialNB
from sklearn.ensemble import RandomForestClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

# Load the dataset and split into train and test sets

categories = ['alt.atheism', 'comp.graphics', 'rec.sport.baseball', 'sci.med']
newsgroups = fetch_20newsgroups(subset='all', categories=categories, shuffle=True, random_state=42)
X_train = newsgroups.data[:3000]
y_train = newsgroups.target[:3000]
X_test = newsgroups.data[3000:]
y_test = newsgroups.target[3000:]

# Transform the text into TF-IDF features

vectorizer = TfidfVectorizer(stop_words='english', max_features=1000)
X_train_tfidf = vectorizer.fit_transform(X_train)
X_test_tfidf = vectorizer.transform(X_test)

Now we are ready to train and evaluate our models. We will use four different algorithms: SVM, MNB, random forest, and k-nearest neighbor.

The code below shows how to do this:

# Train and evaluate SVM

svm = LinearSVC(random_state=42)
svm.fit(X_train_tfidf, y_train)
y_pred_svm = svm.predict(X_test_tfidf)
acc_svm = accuracy_score(y_test, y_pred_svm)
print(f'SVM accuracy: {acc_svm:.2f}')

# Train and evaluate MNB

mnb = MultinomialNB()
mnb.fit(X_train_tfidf, y_train)
y_pred_mnb = mnb.predict(X_test_tfidf)
acc_mnb = accuracy_score(y_test, y_pred_mnb)
print(f'MNB accuracy: {acc_mnb:.2f}')

# Train and evaluate random forest

rf = RandomForestClassifier(random_state=42)
rf.fit(X_train_tfidf, y_train)
y_pred_rf = rf.predict(X_test_tfidf)
acc_rf = accuracy_score(y_test, y_pred_rf)
print(f'Random forest accuracy: {acc_rf:.2f}')

# Train and evaluate KNN

knn = KNeighborsClassifier(n_neighbors=6)
knn.fit(X_train_tfidf, y_train)
y_pred_knn = knn.predict(X_test_tfidf)
acc_knn = accuracy_score(y_test, y_pred_knn)
print(f'KNN accuracy: {acc_knn:.2f}')

Comparing the accuracy

The output of the code is:

SVM accuracy: 0.94
MNB accuracy: 0.93
Random forest accuracy: 0.92
KNN accuracy: 0.89


We can see that SVM has the highest accuracy of the four algorithms, followed by MNB, random forest, and K-nearest neighbors (for KNN, I set the number of neighbors to 6). This suggests that SVM works best for our small-to-medium-sized text dataset, although the differences between the top three models are only one to two percentage points. However, accuracy is not the only metric we should look at when evaluating a classifier. We should also consider precision, recall, and F1 score, and inspect the confusion matrix.
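On the real predictions you can simply call confusion_matrix(y_test, y_pred_svm) or print(classification_report(y_test, y_pred_svm, target_names=newsgroups.target_names)), since both functions are already imported in the code above. To show how to read a confusion matrix, here is a small pure-Python sketch on six invented predictions (the short category names are just labels for illustration). Rows are true classes and columns are predicted classes, the same layout scikit-learn uses.

```python
def count_confusions(y_true, y_pred, labels):
    """Return a matrix where entry [i][j] counts posts whose true class is labels[i]
    and whose predicted class is labels[j]."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix


labels = ["atheism", "graphics", "baseball", "med"]
# Hypothetical predictions for six test posts
y_true = ["atheism", "atheism", "graphics", "baseball", "med", "med"]
y_pred = ["atheism", "med", "graphics", "baseball", "med", "graphics"]
for label, row in zip(labels, count_confusions(y_true, y_pred, labels)):
    print(f"{label:>9}", row)
```

The diagonal holds the correct predictions; here the "atheism" row reads [1, 0, 0, 1], meaning one atheism post was classified correctly and one was mistaken for a medical post. Off-diagonal cells like this show which pairs of topics a model confuses, something a single accuracy number cannot tell you.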
