Commit a7b7cf67 authored by sjjsmuel's avatar sjjsmuel

Refine the training procedure: pre-train the new FC head first, then unfreeze and fine-tune the last layers

parent b80c8d80
from network_helpers.NetworkBase import NetworkBase
from network.NetworkBase import NetworkBase
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, Flatten, AveragePooling2D
from tensorflow.keras.applications.resnet_v2 import ResNet50V2
......@@ -29,13 +29,16 @@ class Resnet50(NetworkBase):
# NOTE(review): diff-view fragment of Resnet50.get_model -- the method header
# is outside this excerpt and the diff renderer flattened all indentation.
# Optionally cache the headless ResNet50V2 base model to disk for reuse.
if shouldSave:
base_model.save('input/resnet_50_base_model.h5')
# Freeze every layer of the pretrained base so that only the newly added
# classification head is trained during the first (pre-training) phase.
for layer in base_model.layers:
layer.trainable = False
# New classification head: global average pooling -> dropout -> softmax
# over self.NUM_CLASSES classes.
x = GlobalAveragePooling2D()(base_model.output)
x = Dropout(0.5)(x)
out = Dense(self.NUM_CLASSES, activation='softmax', name='probs')(x)
model = Model(base_model.input, out)
# NOTE(review): the diff shows both the old freeze loop and its commented-out
# replacement below; in the new revision the "freeze all but the last 39
# layers" step is presumably moved into the training script -- confirm
# against the full commit before editing.
for layer in model.layers[:-39]:
layer.trainable = False
#for layer in model.layers[:-39]:
# layer.trainable = False
return model
......@@ -2,12 +2,10 @@ from datetime import datetime
from optparse import OptionParser
from PIL import ImageFile
from keras_preprocessing.image import ImageDataGenerator
from network_helpers.Resnet152 import Resnet152
from network_helpers.Resnet50 import Resnet50
from network_helpers.DataLoader import DataLoader
from tensorflow.keras.optimizers import Adam
from network.Resnet50 import Resnet50
from helpers.DataLoader import DataLoader
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
from tensorflow.keras.optimizers import RMSprop, SGD
import pathlib
......@@ -23,6 +21,7 @@ parser.add_option("-o", "--path_output", dest="output_path", help="Path to base
# Command-line options for the training script. The OptionParser instance
# (`parser`) is created earlier in the file, outside this excerpt.
parser.add_option("-t", "--path_test", dest="test_path", help="Path to test input.", default="./input/test_data")
parser.add_option("--base_network_file", dest="base_net_file", help="Optional link to local file of Resnet 152 V2 for TF without top.")
parser.add_option("--num_epochs", type="int", dest="num_epochs", help="Number of epochs.", default=100)
# NOTE(review): presumably the single line added by this commit (hunk header
# says +1 line): a separate epoch budget for the head-only pre-training phase.
parser.add_option("--num_epochs_pre_train", type="int", dest="num_epochs_pre_train", help="Number of epochs for the basic training of the FC layers.", default=50)
parser.add_option("--batch_size", type="int", dest="batch_size", help="Size of batches.", default=10)
parser.add_option("--patience", type="int", dest="early_stopping_patience", help="How many epochs without improvement should the training going on?", default=5)
parser.add_option("--split_size", type="float", dest="split_size", help="Proportion of validation examples out of all test examples. Set as value between 0 and 1.", default=0.5)
......@@ -92,12 +91,35 @@ model = network.get_model()
# NOTE(review): diff-view fragment of the training script -- both the removed
# and the added revision of some lines appear, and indentation was flattened
# by the diff renderer.
#compile the model
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
#model.compile(optimizer=Adam(lr=0.000001), loss='categorical_crossentropy', metrics=['accuracy'])
# NOTE(review): the string-optimizer compile above is presumably the line
# removed by this commit, replaced by an explicit RMSprop instance below --
# confirm against the full diff.
model.compile(optimizer=RMSprop(), loss='categorical_crossentropy', metrics=['accuracy'])
# Print Network summary
# model.summary()
'''
Pre-Train FC Layers
'''
# Phase 1: train only the freshly added fully-connected head for
# `num_epochs_pre_train` epochs; the pretrained backbone layers were frozen
# when the model was built. Only TensorBoard logging here -- no checkpoints
# or early stopping during pre-training.
callbacks_prefit = [
TensorBoard(options.output_path +'/logs/{}_prefit'.format(time)),
]
history = model.fit(train_dataset,
epochs=options.num_epochs_pre_train,
validation_data=validation_dataset,
callbacks=callbacks_prefit,
verbose=2
)
#model.save_weights('last_pre_train_model.h5')
print('\nHistory dict:', history.history)
# Phase 2 (fine-tuning): unfreeze the last 39 layers and recompile with a
# small SGD learning rate so the pretrained weights are only nudged.
# Recompiling is required for changed `trainable` flags to take effect.
for layer in model.layers[-39:]:
layer.trainable = True
model.compile(optimizer=SGD(lr=1e-4, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy'])
# Reload the training data so it is shuffled and augmented differently than before
train_dataset = train_loader.load_dataset()
callbacks = [
ModelCheckpoint(
filepath= str(checkpoint_path) + '/model.{epoch:04d}-{val_loss:.3f}.hdf5',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment