from tensorflow.keras import Model
from tensorflow import GradientTape, cast, reduce_mean, reduce_sum, multiply, newaxis, reshape, transpose, squeeze
from tensorflow.image import resize
from tensorflow.math import multiply, reduce_min, reduce_max, divide, add, l2_normalize
import tensorflow as tf
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.metrics import CategoricalAccuracy, Mean
'''
Based on the default example from the Keras docs.
https://keras.io/guides/customizing_what_happens_in_fit/
'''
def custom_loss(y, y_pred, cam_loss):
    """Return the categorical cross-entropy of the predictions plus ``cam_loss``.

    Args:
        y: Target labels, forwarded unchanged to ``categorical_crossentropy``.
        y_pred: Model predictions, forwarded unchanged to
            ``categorical_crossentropy``.
        cam_loss: Extra penalty term added to the cross-entropy values.

    Returns:
        The cross-entropy values with ``cam_loss`` added to them.
    """
    cross_entropy = categorical_crossentropy(y, y_pred)
    return cross_entropy + cam_loss
# Module-level metric objects updated by CAMModel.train_step and
# CAMModel.test_step; being module-level, they are shared by every model
# instance created from this file.
# Running categorical accuracy over the (y, y_pred) pairs fed to it.
metric_tracker = CategoricalAccuracy()
# Running mean of the loss values fed to it; reported under the name 'loss'.
loss_tracker = Mean(name='loss')
class CAMModel(Model):
    """Keras model whose reported training loss includes a Grad-CAM-style term.

    The model's forward pass must return a ``(y_pred, conv_outputs)`` pair.
    ``conv_outputs`` is treated as a channels-last feature map, i.e.
    ``(batch, height, width, channels)``.

    During training a class activation map is built for each of the two
    classes, resized to the input image size, multiplied by the ``'mouth'``
    input and averaged; the mean of the two class values is added to the
    categorical cross-entropy.
    """

    class_index_dict = {'caries': 0, 'no_caries': 1}  # hardcoded classes

    @property
    def metrics(self):
        """Metric objects Keras resets at the start of each epoch/evaluation.

        Without this property the module-level trackers would keep
        accumulating across epochs and across fit/evaluate calls.
        """
        return [loss_tracker, metric_tracker]

    @staticmethod
    def _grad_cam(conv_outputs, gradients):
        """Build one class activation map per sample.

        Args:
            conv_outputs: Feature maps, ``(batch, height, width, channels)``.
            gradients: Gradient of a class score w.r.t. ``conv_outputs``
                (same shape).

        Returns:
            Tensor of shape ``(batch, height, width)``.
        """
        # Guided gradients: keep a gradient only where both the activation
        # and the gradient itself are positive.
        positive_activations = cast(conv_outputs > 0, "float32")
        positive_gradients = cast(gradients > 0, "float32")
        guided_grads = positive_activations * positive_gradients * gradients
        # Spatial average of the guided gradients -> one weight per
        # (sample, channel).
        weights = reduce_mean(guided_grads, axis=(1, 2))
        # Broadcast the weights over the spatial axes and sum over channels.
        # This keeps every sample's channels with that sample, which a
        # flatten/reshape round trip does not guarantee for batch sizes > 1.
        weighted_maps = multiply(conv_outputs, weights[:, newaxis, newaxis, :])
        return reduce_sum(weighted_maps, axis=-1)

    @staticmethod
    def _masked_heatmap_mean(cam, img, mouth_filter):
        """Resize, min-max scale and mask a batch of CAMs, then average it.

        Args:
            cam: Class activation maps, ``(batch, height, width)``.
            img: Input image batch; axes 1 and 2 give the resize target.
            mouth_filter: Tensor multiplied element-wise with the heatmaps
                (presumably a ``(batch, height, width)`` mask -- confirm
                against the data pipeline).

        Returns:
            Scalar tensor: mean of the masked heatmap values.
        """
        # tf.image.resize needs a channel axis and a [new_height, new_width]
        # size, i.e. axes 1 and 2 of a channels-last image batch.
        heatmap = resize(cam[..., newaxis], [img.shape[1], img.shape[2]])
        # Drop only the channel axis added above, so a batch of size one
        # keeps its batch axis.
        heatmap = squeeze(heatmap, axis=-1)
        # Spread the values between 0 and 1. Min and max are taken over the
        # whole batch. A constant heatmap (zero range) is left unscaled.
        lowest = reduce_min(heatmap)
        numer = heatmap - lowest
        denom = reduce_max(heatmap) - lowest
        scaled = tf.cond(
            denom > 0,
            lambda: divide(numer, denom),
            lambda: heatmap,
        )
        return reduce_mean(multiply(scaled, mouth_filter))

    def train_step(self, data):
        """Run one optimisation step.

        Args:
            data: ``(x, y)`` pair where ``x`` is a dict holding the image
                batch under ``'img'`` and the mask under ``'mouth'``.

        Returns:
            Dict with the running ``'loss'`` and ``'accuracy'`` values.
        """
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        img = x['img']
        mouth_filter = x['mouth']

        # Persistent tape: two separate gradients are taken from it.
        with GradientTape(persistent=True) as cam_tape:
            y_pred, conv_outputs = self(img, training=True)  # Forward pass
            # Per-class scores whose gradients drive the CAMs.
            score_caries = y_pred[:, self.class_index_dict['caries']]
            score_no_caries = y_pred[:, self.class_index_dict['no_caries']]

        # Compute the CAM gradients.
        cam_gradients_caries = cam_tape.gradient(score_caries, conv_outputs)
        cam_gradients_no_caries = cam_tape.gradient(
            score_no_caries, conv_outputs)
        del cam_tape  # release the resources held by the persistent tape

        loss_addition_caries = self._masked_heatmap_mean(
            self._grad_cam(conv_outputs, cam_gradients_caries),
            img, mouth_filter)
        loss_addition_no_caries = self._masked_heatmap_mean(
            self._grad_cam(conv_outputs, cam_gradients_no_caries),
            img, mouth_filter)
        loss_cam = tf.divide(
            tf.add(loss_addition_caries, loss_addition_no_caries), 2)

        # NOTE(review): loss_cam is computed outside the tape below, so it is
        # a constant for the gradients taken from that tape. It changes the
        # reported loss value but contributes nothing to the weight update.
        # Confirm whether that is intended.
        with GradientTape() as tape:
            y_pred, _ = self(img, training=True)  # Forward pass
            # Compute the loss value
            loss = custom_loss(y, y_pred, loss_cam)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        loss_tracker.update_state(loss)
        metric_tracker.update_state(y, y_pred)
        return {'loss': loss_tracker.result(),
                'accuracy': metric_tracker.result()}

    def test_step(self, data):
        """Run one evaluation step using plain categorical cross-entropy.

        Args:
            data: ``(x, y)`` pair; ``x`` is either the image batch itself or
                a dict holding it under ``'img'``.

        Returns:
            Dict with the running ``'loss'`` and ``'accuracy'`` values.
        """
        # Unpack the data
        x, y = data
        if isinstance(x, dict):
            x = x['img']
        # Compute predictions and skip the convolution output
        y_pred, _ = self(x, training=False)
        # Calculate the loss (no CAM term at evaluation time)
        loss = categorical_crossentropy(y, y_pred)
        # Update the metric tracking the loss
        loss_tracker.update_state(loss)
        # Update the metrics.
        metric_tracker.update_state(y, y_pred)
        # Return a dict mapping metric names to current value.
        return {'loss': loss_tracker.result(),
                'accuracy': metric_tracker.result()}