Commit 7a5f3b8c authored by sjjsmuel's avatar sjjsmuel

try vgg

parent 6be13c07
......@@ -42,4 +42,7 @@ class Resnet50(NetworkBase):
model = Model(base_model.input, out)
for layer in model.layers[154:]:
layer.trainable = True
return model
from tensorflow.keras import regularizers
from network.NetworkBase import NetworkBase
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, Flatten, AveragePooling2D
from tensorflow.keras.applications.vgg16 import VGG16
import pathlib
class VGG_16(NetworkBase):
def __init__(self, number_of_classes, img_width=224, img_height=224, channels=3, weights_path=None):
super().__init__(img_width=img_width, img_height=img_height, channels=channels)
self.NUM_CLASSES = number_of_classes
self.WEIGHTS_PATH = weights_path
def get_model(self):
weights = 'imagenet'
shouldSave = True
if self.WEIGHTS_PATH:
if pathlib.Path(self.WEIGHTS_PATH).exists():
weights = self.WEIGHTS_PATH
shouldSave = False
else:
print('Given weigths-file for base ResNet not found. (Missing: {})'.format(self.WEIGHTS_PATH))
input_tensor = Input(shape=(self.IMG_WIDTH, self.IMG_HEIGHT, self.CHANNELS))
base_model = VGG16(weights=weights, input_tensor=input_tensor, include_top=False)
if shouldSave:
base_model.save('input/vgg_base_model.h5')
x = Flatten()(base_model.output)
x = Dense(4096, activation='relu')(x)
x = Dense(4096, activation='relu')(x)
out = Dense(self.NUM_CLASSES, activation='softmax', name='probs')(x)
model = Model(base_model.input, out)
print(model.summary())
return model
......@@ -3,6 +3,7 @@ from optparse import OptionParser
from PIL import ImageFile
from network.Resnet50 import Resnet50
from network.VGG_16 import VGG_16
from helpers.DataLoader import DataLoader
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
from tensorflow.keras.optimizers import RMSprop, SGD
......@@ -98,12 +99,9 @@ test_dataset = test_loader.load_dataset()
# Create Network
network = Resnet50(n_classes, img_width, img_height, channels, resnet_file)
network = VGG_16(n_classes, img_width, img_height, channels)
model = network.get_model()
for layer in model.layers[154:]:
layer.trainable = True
#compile the model
model.compile(optimizer=RMSprop(), loss='categorical_crossentropy', metrics=['accuracy'])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment