VGG_16.py 1.54 KB
Newer Older
sjjsmuel's avatar
try vgg  
sjjsmuel committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
from tensorflow.keras import regularizers

from network.NetworkBase import NetworkBase
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, Flatten, AveragePooling2D
from tensorflow.keras.applications.vgg16 import VGG16

import pathlib

class VGG_16(NetworkBase):
    """Image classifier built on a VGG16 convolutional base with a new dense head.

    The base is ``tensorflow.keras.applications.VGG16`` without its original
    top. A global-average-pooling layer and two 4096-unit ReLU layers feed a
    softmax output layer (named ``'probs'``) with ``number_of_classes`` units.
    """

    # Where a freshly built (ImageNet-initialised) base model is written so
    # later runs can pass it back in as ``weights_path``.
    BASE_MODEL_SAVE_PATH = 'input/vgg_base_model.h5'

    def __init__(self, number_of_classes, img_width=224, img_height=224, channels=3, weights_path=None):
        """
        :param number_of_classes: number of units in the final softmax layer.
        :param img_width: input image width in pixels.
        :param img_height: input image height in pixels.
        :param channels: number of input channels.
        :param weights_path: optional path to weights for the VGG16 base. If the
            file exists it is loaded instead of the ImageNet weights; if it is
            missing, a message is printed and ImageNet weights are used.
        """
        super().__init__(img_width=img_width, img_height=img_height, channels=channels)
        self.NUM_CLASSES = number_of_classes
        self.WEIGHTS_PATH = weights_path

    def get_model(self):
        """Build and return the full classification model.

        If ``WEIGHTS_PATH`` points to an existing file, the VGG16 base is
        initialised from it. Otherwise ImageNet weights are used, and the base
        model is saved to ``BASE_MODEL_SAVE_PATH``.

        :return: an uncompiled ``tensorflow.keras.models.Model`` mapping the
            input image tensor to class probabilities.
        """
        weights = 'imagenet'
        should_save = True

        if self.WEIGHTS_PATH:
            if pathlib.Path(self.WEIGHTS_PATH).exists():
                weights = self.WEIGHTS_PATH
                # The weights already come from disk; no need to re-save them.
                should_save = False
            else:
                print('Given weights-file for base VGG16 not found. (Missing: {})'.format(self.WEIGHTS_PATH))

        # Keras image tensors are (rows, cols, channels), i.e. height first.
        # NOTE(review): IMG_HEIGHT / IMG_WIDTH / CHANNELS are presumably set by
        # NetworkBase.__init__; confirm the data pipeline yields (H, W, C).
        input_tensor = Input(shape=(self.IMG_HEIGHT, self.IMG_WIDTH, self.CHANNELS))

        base_model = VGG16(weights=weights, input_tensor=input_tensor, include_top=False)

        if should_save:
            save_path = pathlib.Path(self.BASE_MODEL_SAVE_PATH)
            # Writing the .h5 file fails if the target directory is missing.
            save_path.parent.mkdir(parents=True, exist_ok=True)
            base_model.save(str(save_path))

        x = GlobalAveragePooling2D()(base_model.output)
        x = Dense(4096, activation='relu')(x)
        x = Dense(4096, activation='relu')(x)
        out = Dense(self.NUM_CLASSES, activation='softmax', name='probs')(x)

        model = Model(base_model.input, out)

        # summary() prints the table itself and returns None; wrapping it in
        # print() would also emit a stray "None".
        model.summary()
        return model