from tensorflow.keras import regularizers
from network.NetworkBase import NetworkBase
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, Flatten, AveragePooling2D
from tensorflow.keras.applications.vgg16 import VGG16
import pathlib


class VGG_16(NetworkBase):
    """VGG16 convolutional base (no top) with a new dense classification head.

    The convolutional base is initialised from a local weights file when
    ``weights_path`` points to an existing file. Otherwise it uses the
    ``'imagenet'`` weights, and the freshly built base model is saved to
    ``base_model_save_path`` (presumably so it can be passed back as
    ``weights_path`` on a later run -- confirm with callers).
    """

    def __init__(self, number_of_classes, img_width=224, img_height=224, channels=3,
                 weights_path=None, base_model_save_path='input/vgg_base_model.h5'):
        """Store the model configuration.

        Args:
            number_of_classes: Number of units in the softmax output layer.
            img_width: Input image width, forwarded to ``NetworkBase``.
            img_height: Input image height, forwarded to ``NetworkBase``.
            channels: Number of input channels, forwarded to ``NetworkBase``.
            weights_path: Optional path to a weights file for the VGG16 base.
                If falsy or missing on disk, ImageNet weights are used instead.
            base_model_save_path: Where the ImageNet-initialised base model is
                saved when no usable ``weights_path`` was given.
        """
        # NOTE(review): get_model reads self.IMG_WIDTH / IMG_HEIGHT / CHANNELS,
        # which are assumed to be set by NetworkBase.__init__ -- confirm.
        super().__init__(img_width=img_width, img_height=img_height, channels=channels)
        self.NUM_CLASSES = number_of_classes
        self.WEIGHTS_PATH = weights_path
        self.BASE_MODEL_SAVE_PATH = base_model_save_path

    def get_model(self):
        """Build the full classification model and print its summary.

        Side effects:
            Saves the base model to ``self.BASE_MODEL_SAVE_PATH`` when ImageNet
            weights were used, creating parent directories as needed.

        Returns:
            A Keras ``Model`` whose input is an image tensor of shape
            ``(img_height, img_width, channels)``. Its output is a softmax over
            ``number_of_classes`` classes, and the output layer is named ``'probs'``.
        """
        weights = 'imagenet'
        should_save = True
        if self.WEIGHTS_PATH:
            if pathlib.Path(self.WEIGHTS_PATH).exists():
                weights = self.WEIGHTS_PATH
                should_save = False
            else:
                # Best-effort: fall back to ImageNet weights rather than failing.
                print('Given weights-file for base VGG16 not found. (Missing: {})'.format(self.WEIGHTS_PATH))

        # Keras channels-last image inputs are ordered (height, width, channels);
        # swapping width and height would transpose non-square images.
        input_tensor = Input(shape=(self.IMG_HEIGHT, self.IMG_WIDTH, self.CHANNELS))
        base_model = VGG16(weights=weights, input_tensor=input_tensor, include_top=False)

        if should_save:
            save_path = pathlib.Path(self.BASE_MODEL_SAVE_PATH)
            # Ensure the target directory exists; save() does not create it.
            save_path.parent.mkdir(parents=True, exist_ok=True)
            base_model.save(str(save_path))

        # Classification head: global average pooling -> 2 x Dense(4096, relu) -> softmax.
        x = GlobalAveragePooling2D()(base_model.output)
        x = Dense(4096, activation='relu')(x)
        x = Dense(4096, activation='relu')(x)
        out = Dense(self.NUM_CLASSES, activation='softmax', name='probs')(x)

        model = Model(base_model.input, out)
        # summary() prints on its own and returns None, so it is not wrapped in print().
        model.summary()
        return model