from tensorflow.keras import Model
from tensorflow import (GradientTape, cast, reduce_mean, reduce_sum,
                        multiply, newaxis, reshape, transpose, squeeze)
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.metrics import CategoricalAccuracy, Mean

'''
Based on default example from Keras Docs.
https://keras.io/guides/customizing_what_happens_in_fit/
'''

# Module-level trackers shared by every NoCAMModel instance.
# NOTE(review): two models trained in the same process would share (and
# corrupt) these statistics — consider moving them onto the instance.
metric_tracker = CategoricalAccuracy()
loss_tracker = Mean(name='loss')


class NoCAMModel(Model):
    """Keras Model whose forward pass returns a (predictions, conv_output) pair.

    The custom train/test steps compute categorical cross-entropy on the
    predictions only and discard the convolution output (presumably kept in
    the model's call() for CAM-style visualisation elsewhere — TODO confirm).
    """

    class_index_dict = {'caries': 0, 'no_caries': 1}  # hardcoded classes

    @property
    def metrics(self):
        # Registering the trackers here lets Keras reset them automatically
        # at the start of each epoch and of evaluate().  Without this, the
        # module-level trackers accumulate state across epochs forever, so
        # the reported loss/accuracy are running averages over all epochs.
        return [loss_tracker, metric_tracker]

    def train_step(self, data):
        """Run one optimisation step on a batch.

        `data` is the `(x, y)` pair delivered by `fit()`.  `x` may be a dict
        holding the image batch under the 'img' key, or the bare batch
        tensor itself (mirroring `test_step` so both pipelines can share a
        dataset format).

        Returns a dict of tracked metric results for Keras progress display.
        """
        x, y = data
        img = x['img'] if isinstance(x, dict) else x

        with GradientTape() as tape:
            y_pred, conv_out = self(img, training=True)  # forward pass
            loss = categorical_crossentropy(y, y_pred)

        # Compute gradients w.r.t. all trainable weights and apply them.
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        loss_tracker.update_state(loss)
        metric_tracker.update_state(y, y_pred)
        return {'loss': loss_tracker.result(),
                'accuracy': metric_tracker.result()}

    def test_step(self, data):
        """Run one evaluation step on a batch; no weights are updated.

        Accepts the same `(x, y)` structure as `train_step` and returns the
        tracked loss/accuracy results.
        """
        x, y = data
        if isinstance(x, dict):
            x = x['img']

        # Predictions only; the convolution output is not needed here.
        y_pred, _ = self(x, training=False)
        loss = categorical_crossentropy(y, y_pred)

        loss_tracker.update_state(loss)
        metric_tracker.update_state(y, y_pred)
        return {'loss': loss_tracker.result(),
                'accuracy': metric_tracker.result()}