Commit 85f31fcb authored by sjjsmuel's avatar sjjsmuel

change loss and augmentation

add discrimination for no_caries activation outside of the are of the mouth; add rotation back to the augmentation
parent 450a4cbf
......@@ -23,7 +23,7 @@ metric_tracker = CategoricalAccuracy()
loss_tracker = Mean(name='loss')
class CAMModel(Model):
class_index = 0 # hardcoded class caries
class_index_dict = {'caries': 0, 'no_caries': 1} # hardcoded classes
def train_step(self, data):
# Unpack the data. Its structure depends on your model and
......@@ -32,56 +32,80 @@ class CAMModel(Model):
img = x['img']
mouth_filter = x['mouth']
with GradientTape() as tape:
y_pred, convOutputs = self(img, training=True) # Forward pass
with GradientTape(persistent=True) as tape:
y_pred, conv_outputs = self(img, training=True) # Forward pass
# Compute the loss for the class_indes
class_loss = y_pred[:, self.class_index]
loss_caries = y_pred[:, self.class_index_dict['caries']]
loss_no_caries = y_pred[:, self.class_index_dict['no_caries']]
#compute CAM grads
cam_gradients = tape.gradient(class_loss, convOutputs)
cam_gradients_caries = tape.gradient(loss_caries, conv_outputs)
cam_gradients_no_caries = tape.gradient(loss_no_caries, conv_outputs)
del tape
# compute the guided gradients
castConvOutputs = cast(convOutputs > 0, "float32")
castGrads = cast(cam_gradients > 0, "float32")
guidedGrads = castConvOutputs * castGrads * cam_gradients
cast_conv_outputs = cast(conv_outputs > 0, "float32")
cast_grads_caries = cast(cam_gradients_caries > 0, "float32")
cast_grads_no_caries = cast(cam_gradients_no_caries > 0, "float32")
guided_grads_caries = cast_conv_outputs * cast_grads_caries * cam_gradients_caries
guided_grads_no_caries = cast_conv_outputs * cast_grads_no_caries * cam_gradients_no_caries
#save the shape of the convolution to reshape later
conv_shape = convOutputs.shape[1:]
conv_shape = conv_outputs.shape[1:]
# compute the average of the gradient values, and using them as weights
weights = reduce_mean(guidedGrads, axis=(1, 2))
weights_caries = reduce_mean(guided_grads_caries, axis=(1, 2))
weights_no_caries = reduce_mean(guided_grads_no_caries, axis=(1, 2))
#flaten out the batch to the filter count dimension
convOutputs = transpose(convOutputs, [0,3, 1,2])
weights = reshape(weights, [-1,])
convOutputs = reshape(convOutputs, [-1,conv_shape[0], conv_shape[1]])
convOutputs = transpose(convOutputs, [1,2,0])
cam = multiply(weights, convOutputs)
conv_outputs = transpose(conv_outputs, [0,3, 1,2])
conv_outputs = reshape(conv_outputs, [-1,conv_shape[0], conv_shape[1]])
conv_outputs = transpose(conv_outputs, [1,2,0])
weights_caries = reshape(weights_caries, [-1, ])
weights_no_caries = reshape(weights_no_caries, [-1, ])
cam_caries = multiply(weights_caries, conv_outputs)
cam_no_caries = multiply(weights_no_caries, conv_outputs)
#rebatch
cam = reshape(cam, [conv_shape[0],conv_shape[1],conv_shape[2], -1])
cam = reduce_sum(cam, axis=-2)
cam = transpose(cam, [2,0,1])
cam_caries = reshape(cam_caries, [conv_shape[0],conv_shape[1],conv_shape[2], -1])
cam_no_caries = reshape(cam_no_caries, [conv_shape[0],conv_shape[1],conv_shape[2], -1])
cam_caries = reduce_sum(cam_caries, axis=-2)
cam_no_caries = reduce_sum(cam_no_caries, axis=-2)
cam_caries = transpose(cam_caries, [2,0,1])
cam_no_caries = transpose(cam_no_caries, [2,0,1])
#ad axis for using the tf.image.resize function
cam = cam[..., newaxis]
heatmap = resize(cam, [img.shape[2], img.shape[1]])
cam_caries = cam_caries[..., newaxis]
cam_no_caries = cam_no_caries[..., newaxis]
heatmap_caries = resize(cam_caries, [img.shape[2], img.shape[1]])
heatmap_no_caries = resize(cam_no_caries, [img.shape[2], img.shape[1]])
#remove now unnecessary axis
heatmap = squeeze(heatmap)
#spread the values between 0 and 1
numer = heatmap - reduce_min(heatmap)
denom = reduce_max(heatmap) - reduce_min(heatmap)
heatmap_caries = squeeze(heatmap_caries)
heatmap_no_caries = squeeze(heatmap_no_caries)
#spread the values between 0 and 1 for caries
numer = heatmap_caries - reduce_min(heatmap_caries)
denom = reduce_max(heatmap_caries) - reduce_min(heatmap_caries)
if not denom <= 0:
heatmap = divide(numer, denom)
heatmap_caries = divide(numer, denom)
heatmap = multiply(heatmap, mouth_filter)
# spread the values between 0 and 1 for no_caries
numer = heatmap_no_caries - reduce_min(heatmap_no_caries)
denom = reduce_max(heatmap_no_caries) - reduce_min(heatmap_no_caries)
if not denom <= 0:
heatmap_no_caries = divide(numer, denom)
loss_addition = reduce_mean(heatmap)
heatmap_caries = multiply(heatmap_caries, mouth_filter)
heatmap_no_caries = multiply(heatmap_no_caries, mouth_filter)
loss_addition_caries = reduce_mean(heatmap_caries)
loss_addition_no_caries = reduce_mean(heatmap_no_caries)
loss_cam = tf.divide(tf.add(loss_addition_caries, loss_addition_no_caries), 2)
with GradientTape() as tape:
y_pred, conv_out = self(img, training=True) # Forward pass
# Compute the loss value
loss = custom_loss(y, y_pred, loss_addition)
loss = custom_loss(y, y_pred, loss_cam)
# Compute gradients
trainable_vars = self.trainable_variables
......@@ -97,7 +121,8 @@ class CAMModel(Model):
def test_step(self, data):
# Unpack the data
x, y = data
x = x['img']
if type(x) == dict:
x = x['img']
# Compute predictions and skip the convolution output
y_pred, _ = self(x, training=False)
#calculate the loss
......
......@@ -38,7 +38,8 @@ def rotate(x, label, size):
img = x['img']
mask = x['mouth']
mask = tf.expand_dims(mask, -1)
random_value = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
# rotate either 0 or 180 degrees (0 times or 2 times 90 degrees)
random_value = tf.multiply(tf.random.uniform(shape=[], minval=0, maxval=2, dtype=tf.int32),2)
img = tf.image.rot90(img, random_value)
mask = tf.image.rot90(mask, random_value)
mask = tf.squeeze(mask)
......@@ -99,7 +100,7 @@ class DataLoader(object):
self.IMG_HEIGHT = img_height
self.CHANNELS = channels
self.AUGMENT = augment
self.AUGMENTATIONS = [flip, color, zoom]
self.AUGMENTATIONS = [flip, color, zoom, rotate]
self.annotation = annotation
self.classes = [item.name for item in data_path.glob('*') if item.name != '.DS_Store']
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment