Commit 287a9b5e authored by sjjsmuel's avatar sjjsmuel


parent e580ed83
from classifier.NetworkBase import NetworkBase
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, Flatten, AveragePooling2D
from tensorflow.keras.applications.resnet_v2 import ResNet50V2
import pathlib
class Resnet50(NetworkBase):
def __init__(self, number_of_classes, img_width=224, img_height=224, channels=3, weights_path=None):
super().__init__(img_width=img_width, img_height=img_height, channels=channels)
self.NUM_CLASSES = number_of_classes
self.WEIGHTS_PATH = weights_path
def get_model(self):
weights = 'imagenet'
shouldSave = True
if pathlib.Path(self.WEIGHTS_PATH).exists():
weights = self.WEIGHTS_PATH
shouldSave = False
print('Given weigths-file for base ResNet not found. (Missing: {})'.format(self.WEIGHTS_PATH))
input_tensor = Input(shape=(self.IMG_WIDTH, self.IMG_HEIGHT, self.CHANNELS))
base_model = ResNet50V2(weights=weights, input_tensor=input_tensor, include_top=False)
if shouldSave:'data/resnet_50_base_model.h5')
x = Flatten()(base_model.output)
x = Dropout(0.4)(x)
x = Dense(128, activation='relu')(x)
x = Dropout(0.4)(x)
out = Dense(self.NUM_CLASSES, activation='softmax', name='probs')(x)
model = Model(base_model.input, out)
return model
......@@ -4,6 +4,7 @@ from optparse import OptionParser
from PIL import ImageFile
from keras_preprocessing.image import ImageDataGenerator
from classifier.Resnet152 import Resnet152
from classifier.Resnet50 import Resnet50
from classifier.DataLoader import DataLoader
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
......@@ -86,22 +87,13 @@ validation_dataset, test_dataset = test_loader.load_dataset()
# Create Network
network = Resnet152(n_classes, img_width, img_height, channels, resnet_file)
network = Resnet50(n_classes, img_width, img_height, channels, resnet_file)
model = network.get_model()
#compile the model
#model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
model.compile(optimizer=Adam(lr=0.000001), loss='categorical_crossentropy', metrics=['accuracy'])
# Was ist mir Dropout einfügen?!
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.7)(x)
predictions = Dense(num_classes, activation= 'softmax')(x)
model = Model(inputs = base_model.input, outputs = predictions)
# Print Network summary
# model.summary()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment