CAMModel.py 5.76 KB
Newer Older
sjjsmuel's avatar
sjjsmuel committed
1 2 3 4 5 6
from tensorflow.keras import Model
from tensorflow import GradientTape, cast, reduce_mean, reduce_sum, multiply, newaxis, reshape, transpose, squeeze
from tensorflow.image import resize
from tensorflow.math import multiply, reduce_min, reduce_max, divide, add, l2_normalize
import tensorflow as tf
from tensorflow.keras.losses import categorical_crossentropy
sjjsmuel's avatar
sjjsmuel committed
7
from tensorflow.keras.metrics import CategoricalAccuracy, Mean
sjjsmuel's avatar
sjjsmuel committed
8 9 10 11 12 13 14 15 16 17 18 19 20

'''
    Based on default example from Keras Docs.
    https://keras.io/guides/customizing_what_happens_in_fit/
'''

def custom_loss(y, y_pred, cam_loss):
    """Return the categorical cross-entropy of ``y_pred`` w.r.t. ``y`` plus ``cam_loss``.

    ``cam_loss`` is added unchanged to the cross-entropy result, so a scalar
    penalty is broadcast onto every value that ``categorical_crossentropy``
    returns.
    """
    return categorical_crossentropy(y, y_pred) + cam_loss

sjjsmuel's avatar
sjjsmuel committed
21
# Module-level running metrics, updated by both CAMModel.train_step and
# CAMModel.test_step. Keras only resets metric objects that a model returns
# from its `metrics` property.

# Running mean of the batch loss values, reported under the key 'loss'.
loss_tracker = Mean(name='loss')
# Running categorical accuracy, reported under the key 'accuracy'.
metric_tracker = CategoricalAccuracy()

class CAMModel(Model):
    """Keras ``Model`` whose training loss adds a guided Grad-CAM penalty.

    Calling the model must return ``(class_probabilities, conv_feature_maps)``;
    both steps unpack ``self(...)`` that way. The feature maps are assumed to be
    channels-last ``(batch, h, w, filters)`` -- TODO confirm against the network
    definition.

    Training batches are ``(x, y)`` where ``x`` is a dict holding ``'img'`` (the
    image batch) and ``'mouth'`` (a mask multiplied into the heatmaps).
    Evaluation accepts either that dict or the bare image tensor.
    """

    class_index_dict = {'caries': 0, 'no_caries': 1}  # hardcoded classes: column of y_pred per class

    @property
    def metrics(self):
        """Metric objects Keras resets at the start of each epoch and of evaluate().

        The trackers live at module level. If they are not listed here, Keras
        never resets them, so the reported values would be running averages over
        all epochs with training and validation batches mixed together.
        """
        return [loss_tracker, metric_tracker]

    @staticmethod
    def _cam_heatmap(conv_outputs, cam_gradients, target_size):
        """Build one class's guided Grad-CAM heatmap for every sample in a batch.

        Args:
            conv_outputs: conv feature maps, assumed ``(batch, h, w, filters)``.
            cam_gradients: gradient of the class score w.r.t. ``conv_outputs``
                (same shape).
            target_size: ``[height, width]`` to resize the heatmaps to.

        Returns:
            Tensor of shape ``(batch, height, width)``, min-max scaled to [0, 1]
            over the whole batch unless the heatmap is constant.
        """
        # Guided gradients: keep positive gradients at positive activations only.
        dtype = conv_outputs.dtype
        guided_grads = (cast(conv_outputs > 0, dtype)
                        * cast(cam_gradients > 0, dtype)
                        * cam_gradients)

        # One weight per (sample, filter): spatial mean of the guided gradients.
        weights = reduce_mean(guided_grads, axis=(1, 2))

        # Weight each sample's OWN feature maps and sum over the filter axis.
        # (Flattening batch*filters and reshaping back in a different order, as
        # before, mixed feature maps of different samples whenever batch > 1.)
        cam = reduce_sum(conv_outputs * weights[:, newaxis, newaxis, :], axis=-1)

        # tf.image.resize needs a channel axis and takes size as [height, width].
        # Squeeze only that axis so a batch of one keeps its batch dimension.
        heatmap = squeeze(resize(cam[..., newaxis], target_size), axis=-1)

        # Min-max scale to [0, 1]; a constant heatmap (zero range) is left as is.
        # NOTE(review): the range is taken over the whole batch, not per sample;
        # kept that way so the penalty scale matches the original design.
        low = reduce_min(heatmap)
        value_range = reduce_max(heatmap) - low
        if value_range > 0:
            heatmap = divide(heatmap - low, value_range)
        return heatmap

    def train_step(self, data):
        """Run one optimisation step on cross-entropy plus the CAM penalty.

        A single forward pass is recorded on two tapes. The inner, persistent
        tape yields each class score's gradient w.r.t. the conv feature maps.
        Those gradient calls run while the outer tape is still recording, so the
        CAM penalty stays differentiable w.r.t. the trainable weights. If the
        penalty were computed outside the outer tape, it would be a constant for
        the weight update and would not influence training at all.

        Args:
            data: ``(x, y)`` with ``x = {'img': ..., 'mouth': ...}``.

        Returns:
            dict with the running ``'loss'`` and ``'accuracy'``.
        """
        x, y = data
        img = x['img']
        mouth_filter = x['mouth']
        # [height, width] of the (assumed NHWC) image batch; works with dynamic shapes.
        target_size = tf.shape(img)[1:3]

        with GradientTape() as tape:
            with GradientTape(persistent=True) as cam_tape:
                y_pred, conv_outputs = self(img, training=True)  # single forward pass
                # Class scores whose gradients drive the class activation maps.
                score_caries = y_pred[:, self.class_index_dict['caries']]
                score_no_caries = y_pred[:, self.class_index_dict['no_caries']]
            cam_gradients_caries = cam_tape.gradient(score_caries, conv_outputs)
            cam_gradients_no_caries = cam_tape.gradient(score_no_caries, conv_outputs)
            del cam_tape  # persistent tapes hold resources until released

            heatmap_caries = self._cam_heatmap(
                conv_outputs, cam_gradients_caries, target_size)
            heatmap_no_caries = self._cam_heatmap(
                conv_outputs, cam_gradients_no_caries, target_size)

            # NOTE(review): assumes mouth_filter broadcasts against
            # (batch, height, width); its meaning comes from the input pipeline.
            mouth_filter = cast(mouth_filter, heatmap_caries.dtype)
            loss_addition_caries = reduce_mean(multiply(heatmap_caries, mouth_filter))
            loss_addition_no_caries = reduce_mean(multiply(heatmap_no_caries, mouth_filter))
            loss_cam = (loss_addition_caries + loss_addition_no_caries) / 2.0

            loss = custom_loss(y, y_pred, loss_cam)

        # Compute gradients and update weights
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        loss_tracker.update_state(loss)
        metric_tracker.update_state(y, y_pred)
        return {'loss': loss_tracker.result(), 'accuracy': metric_tracker.result()}

    def test_step(self, data):
        """Evaluate one batch with plain categorical cross-entropy.

        Args:
            data: ``(x, y)`` where ``x`` is the training-style dict (only
                ``'img'`` is used) or the image tensor itself.

        Returns:
            dict with the running ``'loss'`` and ``'accuracy'``.
        """
        x, y = data
        if isinstance(x, dict):
            x = x['img']
        # Compute predictions and skip the convolution output
        y_pred, _ = self(x, training=False)
        # NOTE(review): no CAM penalty is added here, so the validation loss is
        # not directly comparable to the training loss.
        loss = categorical_crossentropy(y, y_pred)
        loss_tracker.update_state(loss)
        metric_tracker.update_state(y, y_pred)
        return {'loss': loss_tracker.result(), 'accuracy': metric_tracker.result()}