Is My Model Trained Well? Model Training Analysis

Machine learning offers a range of models that can produce remarkable results when applied in the right setting and wrong predictions when chosen carelessly. The cycle of developing an accurate machine learning model therefore does not end with fitting the selected model to the dataset. We have to analyze the trained model against several metrics to verify that its predictions are valid. Python offers several diagnostic plots that help us understand a trained model. In particular, the Scikit-plot library provides functions that graph the performance of a trained model in a few lines of code. In this blog, we are going to discuss some of the most commonly used functions from this library.


Every data scientist knows that visualization is an important component of the data science process and should not be ignored. However, while training a model, most data scientists are more concerned with model accuracy, hyperparameter tuning, or getting the desired outputs. In this setup, a library like Scikit-plot helps a lot by providing functions that create informative graphs in a few lines of code. Let's see how to use this library to create important machine learning graphs.


The latest version of Scikit-plot can be installed from PyPI with the following command. Note that the package is published as scikit-plot, while the module you import is scikitplot.

pip install scikit-plot

Prerequisites for Scikit-plot

Before creating graphs with Scikit-plot, the required libraries need to be imported. Scikit-plot is built on matplotlib, so we need to import that as well. The code below imports most of the libraries we require while using Scikit-plot in our experiments.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import scikitplot as skplt
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_predict
from sklearn.datasets import load_breast_cancer
from sklearn.datasets import load_iris
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import LinearSVC

Confusion Matrix

One of the most important outputs after training a model is its confusion matrix, from which we can derive measures such as True Positive Rate, False Positive Rate, Precision, Recall, and F1. Normally we would print a plain matrix; Scikit-plot instead renders the confusion matrix as a heatmap. The heatmap draws high values in dark colors so that we cannot accidentally overlook an important cell.

Let's apply RandomForestClassifier to the digits dataset available in sklearn.datasets and then make predictions using cross_val_predict with its default parameters. The code is below.

X, y= load_digits(return_X_y=True)
random_forest_clf = RandomForestClassifier(n_estimators=5, max_depth=5, random_state=1)
predictions = cross_val_predict(random_forest_clf, X, y)

Now let's plot the confusion matrix using Scikit-plot; the code is below. With normalize=True each row is scaled to sum to 1, so a cell shows the fraction of a true class that was assigned to each predicted class. The resulting plot shows darker cells on the diagonal of the confusion matrix, which is desired because it means most predictions match the actual class. Further details on the confusion matrix can be found in various blogs and textbooks.

plt.rcParams['figure.figsize'] = 10,10
skplt.metrics.plot_confusion_matrix(y, predictions, normalize=True)
plt.savefig("confusion_matrix.png", dpi=300)
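
To make the numbers behind the heatmap concrete, here is a minimal pure-Python sketch of a row-normalized confusion matrix and of the precision and recall read from it. The toy labels and the helper names (confusion_matrix, normalize_rows, precision_recall) are made up for illustration; this is not Scikit-plot's own code.

```python
from collections import Counter


def confusion_matrix(y_true, y_pred, labels):
    """Count matrix: rows are true classes, columns are predicted classes."""
    counts = Counter(zip(y_true, y_pred))
    return [[counts[(t, p)] for p in labels] for t in labels]


def normalize_rows(matrix):
    """Scale each row to sum to 1, so a cell is a fraction of that true class."""
    normalized = []
    for row in matrix:
        total = sum(row)
        normalized.append([c / total if total else 0.0 for c in row])
    return normalized


def precision_recall(matrix, k):
    """Precision and recall of class index k, read off the count matrix."""
    tp = matrix[k][k]
    predicted_k = sum(row[k] for row in matrix)  # column sum
    actual_k = sum(matrix[k])                    # row sum
    precision = tp / predicted_k if predicted_k else 0.0
    recall = tp / actual_k if actual_k else 0.0
    return precision, recall


# Hypothetical toy labels, not output from the digits model
y_true = [0, 0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 0, 1, 1, 1, 2, 0, 2]

cm = confusion_matrix(y_true, y_pred, labels=[0, 1, 2])
for row in normalize_rows(cm):
    print([round(v, 2) for v in row])
print(precision_recall(cm, 1))
```

The diagonal of the normalized matrix is the per-class recall, which is why a dark diagonal in the heatmap is a good sign.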

ROC Curve

The ROC curve plots the true positive rate against the false positive rate. This graph is especially useful when the dataset has a class imbalance problem and accuracy alone is misleading, but it is worth plotting for balanced datasets too. The area under the ROC curve (AUC) is the main quantity to observe: the larger the area, the better the model. For a multi-class problem, one ROC curve is generated for each output class of the labeled dataset. Scikit-plot plots these per-class curves together with the micro-average and macro-average curves in a single function call, shown in the code below. Details about the ROC curve can be studied from This Blog.

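plt.rcParams['figure.figsize'] = 10,6
# For ROC we need prediction probabilities rather than hard class labels
# Assumption: hold out 30% of the digits data as the test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=1)
y_prob = random_forest_clf.fit(X_train, y_train).predict_proba(X_test)
skplt.metrics.plot_roc(y_test, y_prob)
plt.legend(bbox_to_anchor=(1, 1), loc=2)
plt.savefig("roc_curve.png", dpi=300)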
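
For intuition, the sketch below computes the points of a ROC curve and its area by hand for a binary problem. In the multi-class plot, each per-class curve treats that class as positive and all other classes as negative. The labels, scores, and function names here are illustrative assumptions, and this is a simplified sketch rather than the library's implementation.

```python
def roc_points(y_true, scores):
    """Return (fpr, tpr) lists obtained by lowering the threshold step by step."""
    positives = sum(y_true)
    negatives = len(y_true) - positives
    ranked = sorted(zip(scores, y_true), reverse=True)
    fpr, tpr = [0.0], [0.0]
    tp = fp = 0
    for i, (score, label) in enumerate(ranked):
        if label == 1:
            tp += 1
        else:
            fp += 1
        # Emit a point only after the last sample of a group of tied scores
        if i == len(ranked) - 1 or ranked[i + 1][0] != score:
            fpr.append(fp / negatives)
            tpr.append(tp / positives)
    return fpr, tpr


def auc(x, y):
    """Area under the curve by the trapezoidal rule."""
    return sum((x[i] - x[i - 1]) * (y[i] + y[i - 1]) / 2 for i in range(1, len(x)))


# Hypothetical binary labels and predicted scores for the positive class
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]

fpr, tpr = roc_points(y_true, scores)
print(list(zip(fpr, tpr)))
print(auc(fpr, tpr))
```

An area of 0.5 corresponds to random guessing and 1.0 to a perfect ranking of positives above negatives.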

Precision-Recall Curve

The precision-recall curve shows the trade-off between precision and recall as the decision threshold varies, which tells you how well your model predicts the positive class. Scikit-plot provides a function that creates the precision-recall curve in a single line of code. More details about the precision-recall curve can be studied from this blog. The Scikit-plot code is below.

plt.rcParams['figure.figsize'] = 10,6
skplt.metrics.plot_precision_recall(y_test, y_prob)
plt.legend(bbox_to_anchor=(1, 1), loc=2)
plt.savefig('p_r.png', dpi=300)
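
Each point on a precision-recall curve comes from one decision threshold. The following is a minimal sketch, with made-up labels and scores and a hypothetical helper named precision_recall_at, of how precision and recall change as that threshold moves.

```python
def precision_recall_at(y_true, scores, threshold):
    """Precision and recall when every score >= threshold is predicted positive."""
    predicted = [s >= threshold for s in scores]
    tp = sum(1 for p, t in zip(predicted, y_true) if p and t == 1)
    fp = sum(1 for p, t in zip(predicted, y_true) if p and t == 0)
    fn = sum(1 for p, t in zip(predicted, y_true) if not p and t == 1)
    # Convention assumed here: precision is 1.0 when nothing is predicted positive
    precision = tp / (tp + fp) if tp + fp else 1.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Hypothetical binary labels and predicted scores for the positive class
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]

# One (threshold, precision, recall) triple per distinct score
curve = [(t, *precision_recall_at(y_true, scores, t)) for t in sorted(set(scores))]
for threshold, precision, recall in curve:
    print(threshold, round(precision, 2), round(recall, 2))
```

Raising the threshold generally trades recall for precision, which is the shape the plotted curve makes visible.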

Model Calibration (Calibration Curve)

An accuracy value does not reflect how confident the model is when it predicts a label. In mission-critical applications, the model's confidence in its predicted class should also be considered when comparing multiple machine learning models. For example, suppose a trained model predicts that a person does not have a medical condition (say, Corona) with a confidence of only 0.60. This person should go for further tests, because 0.60 is not a good level of confidence for predicting the negative class. A calibration curve plots the fraction of positives against the mean predicted value, so it shows how well predicted probabilities match observed frequencies. The Scikit-plot code below generates calibration curves for four different machine learning algorithms on a synthetically generated dataset and presents them along with the perfectly calibrated curve. Details about model calibration can be further studied from This Blog.

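from sklearn import datasets
# Generate a random classification dataset with 20 features
X, y = datasets.make_classification(n_samples=100000, n_features=20, n_informative=7, n_redundant=10, random_state=42)
# Generate the training and testing split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.50, random_state=42)
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import LinearSVC
# Four classifiers to compare; default settings are an assumption
rf = RandomForestClassifier()
lr = LogisticRegression()
gnb = GaussianNB()
svm = LinearSVC()
# Fit each model and collect its scores on the test set
rf_probas = rf.fit(X_train, y_train).predict_proba(X_test)
lr_probas = lr.fit(X_train, y_train).predict_proba(X_test)
gnb_probas = gnb.fit(X_train, y_train).predict_proba(X_test)
# LinearSVC has no predict_proba, so its decision scores are passed instead
svm_scores = svm.fit(X_train, y_train).decision_function(X_test)
proba_list = [rf_probas, lr_probas, gnb_probas, svm_scores]
names = ['Random Forest', 'Logistic Regression', 'GaussianNB', 'svm']
skplt.metrics.plot_calibration_curve(y_test, proba_list, names)
plt.savefig('calibration_curve.png', dpi=300)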
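
The points of a calibration curve come from grouping predictions into probability bins. Below is a minimal sketch of that binning, assuming equal-width bins and made-up predictions; calibration_points is a hypothetical helper, not a Scikit-plot function.

```python
def calibration_points(y_true, probs, n_bins=10):
    """Return (mean predicted probability, fraction of positives) per non-empty bin."""
    bins = [[] for _ in range(n_bins)]
    for label, p in zip(y_true, probs):
        index = min(int(p * n_bins), n_bins - 1)  # p == 1.0 goes into the last bin
        bins[index].append((p, label))
    points = []
    for members in bins:
        if members:
            mean_predicted = sum(p for p, _ in members) / len(members)
            fraction_positive = sum(label for _, label in members) / len(members)
            points.append((mean_predicted, fraction_positive))
    return points


# Hypothetical labels and predicted probabilities of the positive class
y_true = [0, 0, 1, 0, 1, 1, 1, 1]
probs = [0.1, 0.2, 0.2, 0.6, 0.6, 0.9, 0.9, 1.0]

points = calibration_points(y_true, probs, n_bins=2)
for mean_predicted, fraction_positive in points:
    print(round(mean_predicted, 2), round(fraction_positive, 2))
```

A bin lies on the perfectly calibrated diagonal when its two values are equal. In this toy data the upper bin does, while the lower bin predicts about 0.17 on average for a group that is one third positive.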

Learning Curve

A learning curve compares a model's score on the training data with its score on validation data. Both scores are plotted against the number of training examples used to compute them. Scikit-plot provides a function that plots this curve; you can find the code below. This graph also tells you about the bias-variance tradeoff: a large, persistent gap between the two curves suggests high variance (overfitting), while two curves that converge at a low score suggest high bias (underfitting).

skplt.estimators.plot_learning_curve(rf, X, y)
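
To show what goes into such a plot, here is a small self-contained sketch that trains on growing subsets and records training and validation accuracy. Everything in it is an assumption made for illustration: the synthetic one-dimensional data, the nearest-centroid classifier, and the single hold-out split (the library function works with cross-validation instead).

```python
import random


def fit_centroids(points):
    """Nearest-centroid 'training': the mean x value of each class."""
    means = {}
    for label in {lab for _, lab in points}:
        xs = [x for x, lab in points if lab == label]
        means[label] = sum(xs) / len(xs)
    return means


def accuracy(means, points):
    """Share of points whose nearest class centroid matches their label."""
    correct = sum(
        1 for x, lab in points
        if min(means, key=lambda m: abs(x - means[m])) == lab
    )
    return correct / len(points)


def learning_curve(train, valid, sizes):
    """Train on the first n examples and score on that subset and on valid."""
    rows = []
    for n in sizes:
        subset = train[:n]
        model = fit_centroids(subset)
        rows.append((n, accuracy(model, subset), accuracy(model, valid)))
    return rows


def sample(n):
    """Synthetic data: alternating labels, class means two units apart."""
    return [(random.gauss(2.0 * (i % 2), 1.0), i % 2) for i in range(n)]


random.seed(0)
train, valid = sample(200), sample(200)
rows = learning_curve(train, valid, [10, 50, 100, 200])
for n, train_acc, valid_acc in rows:
    print(n, round(train_acc, 2), round(valid_acc, 2))
```

Plotting the second and third columns against the first gives the two lines of a learning curve.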

Feature Importance

Reducing the number of features in a dataset is one of the primary tasks of data preprocessing, and it can contribute a lot to the success of a machine learning model. PCA and many other techniques create new features that combine the original ones. Alternatively, we can compute an importance value for every feature after training the model, which reflects how much each feature contributes to the output decision. Scikit-plot has a function that plots the importance of each feature along with its standard deviation. The code below uses the iris dataset available in the sklearn package and finds the importance of its four features for classifying instances into the different classes.

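from sklearn.datasets import load_iris
# Fit on iris first (importances need a fitted model); names follow load_iris column order
X, y = load_iris(return_X_y=True)
rf = RandomForestClassifier().fit(X, y)
skplt.estimators.plot_feature_importances(rf, feature_names=['sepal length', 'sepal width', 'petal length', 'petal width'])
plt.savefig('feature_importance.png', dpi=300)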
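
The bars and error bars of such a plot can be thought of as a mean and a spread of importance values across the trees of the forest. The sketch below illustrates that idea with made-up per-tree numbers. It assumes a plain average and a population standard deviation, and it is not taken from the library's source.

```python
from statistics import mean, pstdev

# Hypothetical per-tree importances: three trees, four iris features
per_tree = [
    [0.10, 0.02, 0.45, 0.43],
    [0.08, 0.04, 0.40, 0.48],
    [0.12, 0.03, 0.50, 0.35],
]
names = ['sepal length', 'sepal width', 'petal length', 'petal width']

# Transpose so each column holds one feature's importance in every tree
columns = list(zip(*per_tree))
summary = [(name, mean(col), pstdev(col)) for name, col in zip(names, columns)]

# Most important feature first, as in a sorted bar chart
summary.sort(key=lambda row: row[1], reverse=True)
for name, importance, spread in summary:
    print(name, round(importance, 3), round(spread, 3))
```

A large spread relative to the mean signals that the trees disagree about a feature, so its rank should be read with caution.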


Multiple models, and different versions of the same model, should be compared on several metrics to find which works best for a given dataset. A graph is also usually a better way than raw numbers to present these results to bosses. The Python package Scikit-plot helps machine learning practitioners generate these graphs with minimal code.

In this blog, we have seen some of the important functions for classification tasks, along with their code and outputs. You can use this code after training your model to analyze the training results.
