NoCAMModel.py 1.93 KB
Newer Older
sjjsmuel's avatar
fix  
sjjsmuel committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14
from tensorflow.keras import Model
from tensorflow import GradientTape, cast, reduce_mean, reduce_sum, multiply, newaxis, reshape, transpose, squeeze
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.metrics import CategoricalAccuracy, Mean

'''
    Based on default example from Keras Docs.
    https://keras.io/guides/customizing_what_happens_in_fit/
'''

# Module-level running accumulators updated by NoCAMModel.train_step / test_step.
# NOTE(review): because they are globals, every NoCAMModel instance in the
# process shares the same two objects.
loss_tracker = Mean(name='loss')        # running mean of the per-sample losses
metric_tracker = CategoricalAccuracy()  # running categorical accuracy of y_pred vs y

class NoCAMModel(Model):
    """Keras model with custom train/test steps that ignore the conv output.

    Calling the model (``self(img, ...)``) must return a 2-tuple
    ``(y_pred, conv_out)``; only ``y_pred`` is used for the loss and the
    accuracy, and ``conv_out`` is discarded. Loss and accuracy are
    accumulated in the module-level ``loss_tracker`` and ``metric_tracker``.
    """

    # Hardcoded label -> index mapping (not referenced by the methods below).
    class_index_dict = {'caries': 0, 'no_caries': 1}  # hardcoded classes

    @property
    def metrics(self):
        """Metric objects that Keras resets at each epoch and each evaluate().

        The trackers are module-level, so Keras cannot discover them by
        itself. Without listing them here, their state is never reset: the
        reported values would average over all past epochs, and validation
        results would include training batches.
        """
        return [loss_tracker, metric_tracker]

    @staticmethod
    def _extract_image(x):
        """Return ``x['img']`` when the inputs are a dict, otherwise ``x`` itself."""
        return x['img'] if isinstance(x, dict) else x

    def train_step(self, data):
        """Run one optimisation step on a batch.

        Args:
            data: ``(x, y)`` pair as yielded by the dataset passed to ``fit()``.
                ``x`` is either a dict holding the image under ``'img'`` or
                the image tensor itself; ``y`` are the targets for
                ``categorical_crossentropy`` (presumably one-hot -- confirm
                against the data pipeline).

        Returns:
            dict with the running ``'loss'`` and ``'accuracy'`` for the epoch.
        """
        x, y = data
        img = self._extract_image(x)

        with GradientTape() as tape:
            # Forward pass; the convolution output is not needed for training.
            y_pred, _ = self(img, training=True)
            # One loss value per sample (the last axis is reduced).
            loss = categorical_crossentropy(y, y_pred)

        trainable_vars = self.trainable_variables
        # NOTE(review): `loss` is a per-sample vector, so tape.gradient
        # differentiates its SUM. Gradient magnitude therefore grows with batch
        # size, unlike Keras' compiled losses, which average over the batch.
        # Kept as-is to preserve existing training behaviour; consider
        # reduce_mean(loss).
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        loss_tracker.update_state(loss)
        metric_tracker.update_state(y, y_pred)
        return {'loss': loss_tracker.result(), 'accuracy': metric_tracker.result()}

    def test_step(self, data):
        """Evaluate one batch without updating weights.

        Args:
            data: ``(x, y)`` pair; ``x`` is either a dict holding the image
                under ``'img'`` or the image tensor itself.

        Returns:
            dict with the running ``'loss'`` and ``'accuracy'`` since the last
            metric reset.
        """
        x, y = data
        img = self._extract_image(x)

        # Inference pass; skip the convolution output.
        y_pred, _ = self(img, training=False)
        loss = categorical_crossentropy(y, y_pred)

        loss_tracker.update_state(loss)
        metric_tracker.update_state(y, y_pred)
        return {'loss': loss_tracker.result(), 'accuracy': metric_tracker.result()}