Resnet50.py 1.76 KB
Newer Older
sjjsmuel's avatar
sjjsmuel committed
1 2
from tensorflow.keras import regularizers

sjjsmuel's avatar
sjjsmuel committed
3
from helpers.CAMModel import CAMModel
sjjsmuel's avatar
sjjsmuel committed
4
from network.NetworkBase import NetworkBase
sjjsmuel's avatar
sjjsmuel committed
5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, Flatten, AveragePooling2D
from tensorflow.keras.applications.resnet_v2 import ResNet50V2

import pathlib

class Resnet50(NetworkBase):
    """Image classifier built on a frozen/partially fine-tuned ResNet50V2 backbone.

    ``get_model`` builds a Keras ``ResNet50V2`` backbone (``include_top=False``),
    adds a small classification head, and wraps everything in a ``CAMModel``
    with two outputs: the softmax class prediction and the backbone's last
    layer output (the final feature map; presumably consumed by ``CAMModel``
    for class-activation maps — TODO confirm against helpers.CAMModel).
    """

    # Where freshly obtained ImageNet backbone weights are cached by get_model().
    DEFAULT_BASE_MODEL_PATH = 'input/resnet_50_base_model.h5'

    # Index into ``model.layers`` (of the assembled CAMModel) from which layers
    # are made trainable again for fine-tuning; earlier layers stay frozen.
    FINE_TUNE_FROM_LAYER = 154

    def __init__(self, number_of_classes, img_width=224, img_height=224, channels=3, weights_path=None):
        """Store the configuration for the network.

        Args:
            number_of_classes: Number of units in the final softmax layer.
            img_width: Input image width in pixels (forwarded to NetworkBase).
            img_height: Input image height in pixels (forwarded to NetworkBase).
            channels: Number of input image channels (forwarded to NetworkBase).
            weights_path: Optional path to a weights file for the ResNet
                backbone. If the file exists it is loaded instead of the
                ImageNet weights; otherwise a warning is printed and the
                ImageNet weights are used.
        """
        super().__init__(img_width=img_width, img_height=img_height, channels=channels)
        self.NUM_CLASSES = number_of_classes
        self.WEIGHTS_PATH = weights_path

    def get_model(self):
        """Build and return the classification model.

        When ImageNet weights are used (no ``weights_path`` given, or the
        given file is missing), the backbone is saved to
        ``DEFAULT_BASE_MODEL_PATH`` so it can be passed as ``weights_path``
        on later runs. The parent directory is created if needed.

        Returns:
            A ``CAMModel`` with outputs ``[prediction, backbone_feature_map]``.
        """
        weights = 'imagenet'
        should_save = True

        if self.WEIGHTS_PATH:
            if pathlib.Path(self.WEIGHTS_PATH).exists():
                weights = self.WEIGHTS_PATH
                should_save = False
            else:
                print('Given weights-file for base ResNet not found. (Missing: {})'.format(self.WEIGHTS_PATH))

        # Keras channels-last image tensors are (height, width, channels).
        # Passing width first would transpose non-square inputs.
        input_tensor = Input(shape=(self.IMG_HEIGHT, self.IMG_WIDTH, self.CHANNELS))

        base_model = ResNet50V2(weights=weights, input_tensor=input_tensor, include_top=False)

        if should_save:
            # Ensure the cache directory exists; saving into a missing directory fails.
            save_path = pathlib.Path(self.DEFAULT_BASE_MODEL_PATH)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            base_model.save(str(save_path))

        # Freeze the whole backbone first; selected layers are re-enabled below.
        for layer in base_model.layers:
            layer.trainable = False

        x = GlobalAveragePooling2D()(base_model.output)
        x = Dropout(0.4)(x)
        # NOTE(review): this Dense layer has no activation, so it is linear and
        # adds no non-linearity before the softmax. Kept as-is to preserve the
        # existing architecture/weights — confirm whether e.g. 'relu' was intended.
        x = Dense(128)(x)
        x = Dropout(0.2)(x)
        out = Dense(self.NUM_CLASSES, activation='softmax', name='prediction')(x)

        # Second output: the backbone's last layer, exposed for CAMModel.
        model = CAMModel(inputs=[input_tensor], outputs=[out, base_model.layers[-1].output])

        # Unfreeze the tail of the network (late backbone layers plus the head)
        # for fine-tuning.
        for layer in model.layers[self.FINE_TUNE_FROM_LAYER:]:
            layer.trainable = True

        return model