Precision and Recall for non-binary classification
By Eric Antoine Scuccimarra
I am working on classifying mammography scans with a TensorFlow ConvNet. The scans are classified into five classes:
- Normal
- Benign Calcification
- Malignant Calcification
- Benign Mass
- Malignant Mass
I was unsure how I wanted to classify the scans, so I built the model to work with any combination of classes. I started by training on the binary case (normal or abnormal), planning to expand the number of classes once I had a model that made decent binary predictions.
For the binary model I used precision, recall and a PR curve as metrics. Once I expanded to multiple classes, those binary metrics no longer applied directly. For precision and recall I don't really care which type of abnormality the scan shows; I only care whether it is abnormal at all. I also wanted the same metrics for every model so I could compare them, so I needed a way to compute precision and recall for every version of the model.
The solution I came to was to "squash" my multi-class labels and predictions down into binary labels and predictions and feed those into the precision and recall metrics. I set up the classes so that 0 is always normal, which lets me do the squashing as follows:
zero = tf.constant(0, dtype=tf.int64)
collapsed_predictions = tf.greater(predictions, zero)
collapsed_labels = tf.greater(y, zero)
The tensors collapsed_predictions and collapsed_labels then contain True where the prediction or label is not 0 (abnormal) and False where it is 0 (normal). Then I can feed these into my precision and recall metrics:
recall, rec_op = tf.metrics.recall(labels=collapsed_labels, predictions=collapsed_predictions)
precision, prec_op = tf.metrics.precision(labels=collapsed_labels, predictions=collapsed_predictions)
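To make the squash-then-measure idea concrete without depending on a particular TensorFlow version, here is a minimal pure-Python sketch. It assumes class ids are non-negative integers with 0 meaning normal, and that the multi-class prediction is the argmax of the logits (the post doesn't show how `predictions` is produced, so that part is an assumption). The helper names `argmax`, `collapse` and `StreamingPrecisionRecall` are hypothetical; the class only imitates how the streaming metrics accumulate counts across batches, it is not TensorFlow's implementation.

```python
# Pure-Python sketch (no TensorFlow) of squashing multi-class ids to binary
# and accumulating precision/recall across batches.
# Assumptions: class 0 = normal; prediction = argmax of the logits.


def argmax(row):
    """Index of the largest value in a list (stand-in for tf.argmax on one row)."""
    return max(range(len(row)), key=lambda k: row[k])


def collapse(class_ids):
    """Mirror tf.greater(x, 0): normal (0) -> False, any abnormal class -> True."""
    return [c > 0 for c in class_ids]


class StreamingPrecisionRecall:
    """Hypothetical stand-in for the tf.metrics.precision / tf.metrics.recall
    pair: counts accumulate over every batch passed to update()."""

    def __init__(self):
        self.tp = 0  # abnormal scans predicted abnormal
        self.fp = 0  # normal scans predicted abnormal
        self.fn = 0  # abnormal scans predicted normal

    def update(self, labels, predictions):
        """labels, predictions: collapsed booleans (True = abnormal)."""
        for label, pred in zip(labels, predictions):
            if pred and label:
                self.tp += 1
            elif pred:
                self.fp += 1
            elif label:
                self.fn += 1

    def precision(self):
        # Return 0.0 when nothing has been predicted abnormal yet
        denom = self.tp + self.fp
        return self.tp / denom if denom else 0.0

    def recall(self):
        # Return 0.0 when no abnormal labels have been seen yet
        denom = self.tp + self.fn
        return self.tp / denom if denom else 0.0


# Batch 1: made-up logits for three scans, five classes each
logits = [[2.0, 0.1, 0.3, 0.2, 0.1],   # argmax 0 (normal)
          [0.5, 0.2, 1.7, 0.1, 0.3],   # argmax 2
          [0.1, 0.2, 0.3, 0.4, 3.0]]   # argmax 4
predictions = [argmax(row) for row in logits]  # [0, 2, 4]
labels = [0, 3, 4]

metrics = StreamingPrecisionRecall()
metrics.update(collapse(labels), collapse(predictions))
# Batch 2: class ids given directly
metrics.update(collapse([1, 2, 4, 0]), collapse([3, 0, 0, 2]))
# tp=3, fp=1, fn=2 -> precision 3/4 = 0.75, recall 3/5 = 0.6
```

Note the second scan in batch 1: it is labelled 3 but predicted 2, and it still counts as a true positive, because after collapsing only "abnormal vs. normal" matters. That is exactly the behaviour the squashing is meant to give.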
I also created a PR curve metric to see how the choice of threshold affects the predictions. First I convert the logits to probabilities with a softmax, then feed those into pr_curve_streaming_op as the predictions. To make this work with multi-class classification, I squash the probabilities down to the probability that the scan is NOT normal. Since my labels are set up so that normal is always 0, the probability that a scan is not normal is just 1 minus the probability that it is normal:
probabilities = tf.nn.softmax(logits, name="probabilities")
_, update_op = summary_lib.pr_curve_streaming_op(
    name='pr_curve',
    predictions=(1 - probabilities[:, 0]),
    labels=collapsed_labels,
    updates_collections=tf.GraphKeys.UPDATE_OPS,
    num_thresholds=20)
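Here is a rough pure-Python sketch of what the PR curve is built from: the softmax, the "1 minus P(normal)" squash, and precision/recall swept over a set of thresholds. The function names are hypothetical, and the thresholding rule is an assumption for this sketch (evenly spaced thresholds over [0, 1], with a score strictly above the threshold counted as an abnormal prediction). It illustrates the idea behind pr_curve_streaming_op rather than reproducing its exact internals.

```python
import math

# Pure-Python sketch of the PR-curve inputs (no TensorFlow).
# Assumptions: class 0 is normal; thresholds are evenly spaced over [0, 1];
# a score strictly above a threshold counts as an "abnormal" prediction.


def softmax(logits):
    """Convert one row of logits to probabilities (max subtracted for stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def prob_abnormal(logits):
    """P(not normal) = 1 - P(class 0), the same squash as 1 - probabilities[:, 0]."""
    return 1.0 - softmax(logits)[0]


def pr_curve(labels, scores, num_thresholds=20):
    """Return a list of (threshold, precision, recall) tuples.

    labels: collapsed booleans (True = abnormal)
    scores: P(abnormal) for each scan
    """
    if num_thresholds < 2:
        raise ValueError("num_thresholds must be at least 2")
    curve = []
    for i in range(num_thresholds):
        t = i / (num_thresholds - 1)
        tp = sum(1 for y, s in zip(labels, scores) if y and s > t)
        fp = sum(1 for y, s in zip(labels, scores) if not y and s > t)
        fn = sum(1 for y, s in zip(labels, scores) if y and s <= t)
        # Define precision/recall as 0.0 when their denominators are zero
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        curve.append((t, precision, recall))
    return curve


# Four made-up scans: labels already collapsed, scores = P(abnormal)
example = pr_curve([False, True, True, False], [0.1, 0.9, 0.6, 0.4], num_thresholds=3)
# [(0.0, 0.5, 1.0), (0.5, 1.0, 1.0), (1.0, 0.0, 0.0)]
```

Because the softmax outputs sum to 1, 1 - P(class 0) equals the sum of the probabilities of the four abnormal classes, so the squashed score is a proper probability of "abnormal" no matter how many abnormal classes the model uses. Raising the threshold trades recall for precision, which is what the curve makes visible.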