Commit 90da0f19 authored by sjjsmuel's avatar sjjsmuel
Browse files

Revert "remove background class from CAM Training"

This reverts commit e46e7c1f.
parent e46e7c1f
......@@ -32,7 +32,7 @@ class CAMModel(Model):
img = x['img']
mouth_filter = x['mouth']
with GradientTape() as tape:
with GradientTape(persistent=True) as tape:
y_pred, conv_outputs = self(img, training=True) # Forward pass
# Compute the loss for the class_indes
loss_caries = y_pred[:, self.class_index_dict['caries']]
......@@ -40,36 +40,48 @@ class CAMModel(Model):
#compute CAM grads
cam_gradients_caries = tape.gradient(loss_caries, conv_outputs)
cam_gradients_no_caries = tape.gradient(loss_no_caries, conv_outputs)
del tape
# compute the guided gradients
cast_conv_outputs = cast(conv_outputs > 0, "float32")
cast_grads_caries = cast(cam_gradients_caries > 0, "float32")
cast_grads_no_caries = cast(cam_gradients_no_caries > 0, "float32")
guided_grads_caries = cast_conv_outputs * cast_grads_caries * cam_gradients_caries
guided_grads_no_caries = cast_conv_outputs * cast_grads_no_caries * cam_gradients_no_caries
#save the shape of the convolution to reshape later
conv_shape = conv_outputs.shape[1:]
# compute the average of the gradient values, and using them as weights
weights_caries = reduce_mean(guided_grads_caries, axis=(1, 2))
weights_no_caries = reduce_mean(guided_grads_no_caries, axis=(1, 2))
#flaten out the batch to the filter count dimension
conv_outputs = transpose(conv_outputs, [0,3, 1,2])
conv_outputs = reshape(conv_outputs, [-1,conv_shape[0], conv_shape[1]])
conv_outputs = transpose(conv_outputs, [1,2,0])
weights_caries = reshape(weights_caries, [-1, ])
weights_no_caries = reshape(weights_no_caries, [-1, ])
cam_caries = multiply(weights_caries, conv_outputs)
cam_no_caries = multiply(weights_no_caries, conv_outputs)
#rebatch
cam_caries = reshape(cam_caries, [conv_shape[0],conv_shape[1],conv_shape[2], -1])
cam_no_caries = reshape(cam_no_caries, [conv_shape[0],conv_shape[1],conv_shape[2], -1])
cam_caries = reduce_sum(cam_caries, axis=-2)
cam_no_caries = reduce_sum(cam_no_caries, axis=-2)
cam_caries = transpose(cam_caries, [2,0,1])
cam_no_caries = transpose(cam_no_caries, [2,0,1])
#ad axis for using the tf.image.resize function
cam_caries = cam_caries[..., newaxis]
cam_no_caries = cam_no_caries[..., newaxis]
heatmap_caries = resize(cam_caries, [img.shape[2], img.shape[1]])
heatmap_no_caries = resize(cam_no_caries, [img.shape[2], img.shape[1]])
#remove now unnecessary axis
heatmap_caries = squeeze(heatmap_caries)
heatmap_no_caries = squeeze(heatmap_no_caries)
#spread the values between 0 and 1 for caries
numer = heatmap_caries - reduce_min(heatmap_caries)
......@@ -77,10 +89,18 @@ class CAMModel(Model):
if not denom <= 0:
heatmap_caries = divide(numer, denom)
# spread the values between 0 and 1 for no_caries
numer = heatmap_no_caries - reduce_min(heatmap_no_caries)
denom = reduce_max(heatmap_no_caries) - reduce_min(heatmap_no_caries)
if not denom <= 0:
heatmap_no_caries = divide(numer, denom)
heatmap_caries = multiply(heatmap_caries, mouth_filter)
heatmap_no_caries = multiply(heatmap_no_caries, mouth_filter)
loss_addition_caries = reduce_mean(heatmap_caries)
loss_cam = loss_addition_caries
loss_addition_no_caries = reduce_mean(heatmap_no_caries)
loss_cam = tf.divide(tf.add(loss_addition_caries, loss_addition_no_caries), 2)
with GradientTape() as tape:
y_pred, conv_out = self(img, training=True) # Forward pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment