The problem is simple. How to find the best threshold from an ROC and PR curve that maximise a certain binary classification metric?
To make it clearer, let’s take the approach that is commonly used in scikit-learn. Scikit-learn provides a module called
metrics that consists of some evaluation metrics, such as
precision_recall_curve, and so forth.
To get the ROC and PR curve, we can apply
precision_recall_curve respectively. Both methods require two main parameters, such as the true label and the prediction probability. Take a look at the following code snippet.
from sklearn.metrics import roc_curve, precision_recall_curve fpr, tpr, thresholds = roc_curve(true_label, pred_proba) precision, recall, thresholds = precision_recall_curve(true_label, pred_proba)
As you can see, both methods return the parameters needed for the curve creation. Based on the return values, how would one determine the best threshold that maximize a certain evaluation metric? In this post, I’m going to use accuracy as the metric. I think the approach could be adjusted in accordance with the formula of other metrics.
The solution should be simple.
The following shows one of the threshold finding approaches for ROC curve.
def get_metric_and_best_threshold_from_roc_curve( tpr, fpr, thresholds, num_pos_class, num_neg_class ): tp = tpr * num_pos_class tn = (1 - fpr) * num_neg_class acc = (tp + tn) / (num_pos_class + num_neg_class) best_threshold = thresholds[np.argmax(acc)] return np.amax(acc), best_threshold
Below are the steps used to achieve the above formula. Starts from the below basic formulas.
TPR = recall = TP / TP + FN FPR = FP / FP + TN Accuracy = TP + TN / TP + TN + FP + FN
The denominator of the accuracy should be the total number of data. So it’s simply
num_pos_class + num_neg_class.
Now, we only need to find the value of
TPR formula, we can easily find
TP = TPR * (TP + FN). Since
TP + FN equals to the number of positive class, then the formula becomes
TP = TPR * num_pos_class.
Finding the value of
TN is more challenging. We might apply the following approach:
TN = (FP - (FPR * FP)) / FPR
However, the above approach requires us to know the value of
FP which is pretty troublesome to calculate.
Alternatively, we could use its neighbour called
TNR which equals to
TNR = TN / (FP + TN). In such a case, we can calculate
TNR * (FP + TN) or
TNR * num_neg_class. To make the formula complete, just replace
1 - FPR.
Knowing all the required parameters, we can now compute the accuracy by
accuracy = (TP + TN) / (num_pos_class + num_neg_class).
One thing to note here is that all the variables are in the form of numpy array. Please check the documentation for more details.
To find the best threshold that maximises accuracy, we just need to find the index of the maximum accuracy, then use that index to locate the corresponding threshold.
One curve down. Next, the precision-recall curve.
PR curve consists of precision and recall as its parameters. Below is the code I used to search for the best threshold.
def get_metric_and_best_threshold_from_pr_curve( precision, recall, thresholds, num_pos_class, num_neg_class ): tp = recall * num_pos_class fp = (tp / precision) - tp tn = num_neg_class - fp acc = (tp + tn) / (num_pos_class + num_neg_class) best_threshold = thresholds[np.argmax(acc)] return np.amax(acc), best_threshold
Only needs a few step to reach the final formula. Here’s how it’s built.
Using the following basic formulas,
Precision = TP / TP + FP Recall = TP / TP + FN Accuracy = TP + TN / TP + TN + FP + TN
TP is easy. We can just use the recall’s formula, which results in
TP = recall * (TP + FN) or
TP = recall * num_pos_class.
TN, we need to apply a bit tricks since the above precision and recall formulas don’t include
TN basically equals to
num_neg_class - FP. In this case, we just need to find
FP. It should become easier since
FP can be calculated by leveraging the precision formula. So it becomes
FP = (TP / precision) - TP.
We can then use the calculated variables to find the accuracy. The best threshold is generated using the previous approach.