import pathlib

from tensorflow.keras import regularizers
from tensorflow.keras.applications.resnet_v2 import ResNet50V2
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, Flatten, AveragePooling2D
from tensorflow.keras.models import Model

from helpers.CAMModel import CAMModel
from network.NetworkBase import NetworkBase


class Resnet50(NetworkBase):
    """ResNet50V2 transfer-learning classifier with an extra feature-map output.

    The backbone is loaded from ``weights_path`` when that file exists;
    otherwise it is built with ImageNet weights and a copy of the backbone is
    saved to ``base_model_save_path`` so it can be passed as ``weights_path``
    on later runs.
    """

    def __init__(self, number_of_classes, img_width=224, img_height=224, channels=3, weights_path=None,
                 fine_tune_from=154, base_model_save_path='input/resnet_50_base_model.h5'):
        """Store the network configuration.

        Args:
            number_of_classes: Number of units in the softmax ``'prediction'`` layer.
            img_width: Input image width in pixels.
            img_height: Input image height in pixels.
            channels: Number of input image channels.
            weights_path: Optional path to a previously saved backbone file.
                Ignored, with a printed warning, if the file does not exist.
            fine_tune_from: Index into the final model's ``layers``. Every
                layer from this index onward is made trainable; all earlier
                backbone layers stay frozen. Defaults to 154, the value that
                was previously hard-coded.
            base_model_save_path: Where the freshly built backbone is saved
                when ``weights_path`` is not usable.
        """
        super().__init__(img_width=img_width, img_height=img_height, channels=channels)
        self.NUM_CLASSES = number_of_classes
        self.WEIGHTS_PATH = weights_path
        self.FINE_TUNE_FROM = fine_tune_from
        self.BASE_MODEL_SAVE_PATH = base_model_save_path

    def get_model(self):
        """Build the classifier on top of a partially frozen ResNet50V2 backbone.

        Side effect: when no usable ``WEIGHTS_PATH`` is configured, the
        backbone is saved to ``BASE_MODEL_SAVE_PATH``, creating its parent
        directory if needed.

        Returns:
            A ``CAMModel`` with two outputs: the softmax class probabilities
            (layer ``'prediction'``) and the output of the backbone's last
            layer.
        """
        weights = 'imagenet'
        should_save = True
        if self.WEIGHTS_PATH:
            if pathlib.Path(self.WEIGHTS_PATH).exists():
                weights = self.WEIGHTS_PATH
                should_save = False
            else:
                print('Given weights-file for base ResNet not found. (Missing: {})'.format(self.WEIGHTS_PATH))

        # Channels-last Keras image inputs are (height, width, channels);
        # width-first would only work for square images.
        input_tensor = Input(shape=(self.IMG_HEIGHT, self.IMG_WIDTH, self.CHANNELS))
        base_model = ResNet50V2(weights=weights, input_tensor=input_tensor, include_top=False)

        if should_save:
            save_path = pathlib.Path(self.BASE_MODEL_SAVE_PATH)
            # Model.save does not create missing directories.
            save_path.parent.mkdir(parents=True, exist_ok=True)
            base_model.save(str(save_path))

        # Freeze the whole backbone first; a tail slice is re-enabled below.
        for layer in base_model.layers:
            layer.trainable = False

        x = GlobalAveragePooling2D()(base_model.output)
        x = Dropout(0.4)(x)
        # NOTE(review): this Dense has no activation, so together with the
        # softmax Dense below it is a single affine map before the softmax.
        # Confirm whether an activation (e.g. 'relu') was intended.
        x = Dense(128)(x)
        x = Dropout(0.2)(x)
        out = Dense(self.NUM_CLASSES, activation='softmax', name='prediction')(x)

        # The second output is the backbone's final feature map, presumably
        # consumed by CAMModel for class activation maps. Verify in CAMModel.
        model = CAMModel(inputs=[input_tensor], outputs=[out, base_model.layers[-1].output])

        # Unfreeze the tail of the combined model (late backbone layers plus
        # the new head) for fine-tuning.
        for layer in model.layers[self.FINE_TUNE_FROM:]:
            layer.trainable = True

        return model