Commit ec31a4c5 authored by sjjsmuel

refined training

parent b7f6ecb4
# Data dir
input/evaluation_data/
input/training_data/
input/training_data_one/
input/test_data/
input/test_data_mini/
input/training_data_small/
......
from network_helpers.NetworkBase import NetworkBase
from network.NetworkBase import NetworkBase
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, Flatten, AveragePooling2D
from tensorflow.keras.applications.resnet_v2 import ResNet152V2
......
from tensorflow.keras import regularizers
from network.NetworkBase import NetworkBase
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, Flatten, AveragePooling2D
......@@ -32,12 +34,10 @@ class Resnet50(NetworkBase):
for layer in base_model.layers:
layer.trainable = False
#x = GlobalAveragePooling2D()(base_model.output)
x = Flatten()(base_model.output)
x = Dropout(0.3)(x)
x = Dense(128)(x)
x = Dropout(0.3)(x)
out = Dense(self.NUM_CLASSES, activation='softmax', name='probs')(x)
x = GlobalAveragePooling2D()(base_model.output)
x = Dropout(0.2)(x)
x = Dense(64, activation='relu', activity_regularizer=regularizers.l2(0.01))(x)
out = Dense(self.NUM_CLASSES, activation='softmax', activity_regularizer=regularizers.l2(0.01), name='probs')(x)
model = Model(base_model.input, out)
......
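For context on the head swap above: the old head flattened the final 7x7x2048 feature map (about 100k values, assuming a 224x224 input) into a Dense(128) layer, while the new head pools it to a 2048-vector first, shrinking the head from roughly 12.8M to about 131k weights and adding L2 activity regularization. A minimal, self-contained sketch of the new head; the input shape and NUM_CLASSES are placeholder assumptions, not values from this repo:

import tensorflow as tf
from tensorflow.keras import regularizers
from tensorflow.keras.applications.resnet_v2 import ResNet152V2
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.models import Model

NUM_CLASSES = 2  # placeholder; the real value comes from NetworkBase

base_model = ResNet152V2(include_top=False, input_shape=(224, 224, 3))
for layer in base_model.layers:
    layer.trainable = False  # backbone stays frozen during the pre-training phase

x = GlobalAveragePooling2D()(base_model.output)  # 7x7x2048 -> 2048
x = Dropout(0.2)(x)
x = Dense(64, activation='relu', activity_regularizer=regularizers.l2(0.01))(x)
out = Dense(NUM_CLASSES, activation='softmax',
            activity_regularizer=regularizers.l2(0.01), name='probs')(x)
model = Model(base_model.input, out)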
......@@ -16,10 +16,11 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True
parser = OptionParser()
parser.add_option("-p", "--path_train", dest="train_path", help="Path to training input.", default="./input/train_data")
parser.add_option("-t", "--path_test", dest="test_path", help="Path to test input.", default="./input/test_data")
parser.add_option("-v", "--path_validation", dest="validation_path", help="Path to validation input.", default="./input/validation_data")
parser.add_option("--train_size", type="int", dest="train_size", default=200)
parser.add_option("--validation_size", type="int", dest="validation_size", default=200)
parser.add_option("-o", "--path_output", dest="output_path", help="Path to base folder for output data.", default='./out')
parser.add_option("-t", "--path_test", dest="test_path", help="Path to test input.", default="./input/test_data")
parser.add_option("--base_network_file", dest="base_net_file", help="Optional link to local file of Resnet 152 V2 for TF without top.")
parser.add_option("--num_epochs", type="int", dest="num_epochs", help="Number of epochs.", default=100)
parser.add_option("--num_epochs_pre_train", type="int", dest="num_epochs_pre_train", help="Number of epochs for the basic training of the FC layers.", default=50)
......@@ -45,6 +46,7 @@ def get_curr_time():
train_dir = pathlib.Path(options.train_path)
test_dir = pathlib.Path(options.test_path)
validation_dir = pathlib.Path(options.validation_path)
time = get_curr_time()
checkpoint_path = pathlib.Path(options.output_path +'/checkpoints/{}'.format(time))
if not checkpoint_path.exists():
......@@ -73,17 +75,26 @@ train_loader = DataLoader(data_path=train_dir,
)
train_dataset = train_loader.load_dataset()
# Validation Data
validation_loader = DataLoader(data_path=validation_dir,
batch_size=options.batch_size,
should_size_dataset1=min_size_validation_dataset,
img_width=img_width,
img_height=img_height,
channels=channels,
#split_size=options.split_size,
augment=True
)
validation_dataset = validation_loader.load_dataset()
# Test Data
test_loader = DataLoader(data_path=test_dir,
batch_size=options.batch_size,
should_size_dataset1=min_size_validation_dataset,
img_width=img_width,
img_height=img_height,
channels=channels,
split_size=options.split_size,
augment=True
)
validation_dataset, test_dataset = test_loader.load_dataset()
test_loader = DataLoader(data_path=test_dir,
batch_size=options.batch_size,
img_width=img_width,
img_height=img_height,
channels=channels,
)
test_dataset = test_loader.load_dataset()
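The DataLoader class is local to this repo, but the new three-loader layout mirrors the standard Keras pattern of one dataset per directory; note that the test loader drops augmentation, the split_size, and the minimum-size constraint, which is what you want for an untouched evaluation set. A rough stand-in using stock Keras utilities, not the project's DataLoader (directory names are the parser defaults; image size and batch size are assumptions):

import tensorflow as tf

IMG_SIZE = (224, 224)  # assumed; the script sets img_width/img_height elsewhere
BATCH = 32             # assumed

def make_ds(path, shuffle):
    # TF >= 2.9; older versions expose this under tf.keras.preprocessing
    return tf.keras.utils.image_dataset_from_directory(
        path, label_mode='categorical', image_size=IMG_SIZE,
        batch_size=BATCH, shuffle=shuffle)

train_ds = make_ds('./input/train_data', shuffle=True)
val_ds = make_ds('./input/validation_data', shuffle=False)
test_ds = make_ds('./input/test_data', shuffle=False)  # no augmentation here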
# Create Network
......@@ -124,17 +135,18 @@ print('\nHistory dict:', history.history)
'''
model = load_model(str(checkpoint_path) + '/best_pre_train.hdf5')
for layer in model.layers[-39:]:
for layer in model.layers[154:]:
layer.trainable = True
model.compile(optimizer=SGD(lr=1e-4, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy'])
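The fine-tuning cut-off changed from a relative slice (the last 39 layers) to an absolute index, which is easier to check against the model's actual layer list; note also that compile() must be called again after flipping trainable flags, as the SGD line above does, or the change has no effect. A quick way to confirm where index 154 lands (a sketch, run against the model loaded from the checkpoint):

# Print indices, names, and trainable flags around the unfreeze boundary;
# layer numbering can shift between Keras versions, so verify before training.
for i, layer in enumerate(model.layers):
    if 150 <= i <= 158:
        print(i, layer.name, layer.trainable)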
# Reload training data so it is shuffled and augmented differently than before
train_dataset = train_loader.load_dataset()
validation_dataset = validation_loader.load_dataset()
callbacks = [
ModelCheckpoint(
filepath= str(checkpoint_path) + '/model.{epoch:04d}-{val_loss:.3f}.hdf5',
filepath= str(checkpoint_path) + '/model.{epoch:03d}-{val_loss:.3f}.hdf5',
# Path where to save the model
# The two parameters below mean that we will overwrite
# the current checkpoint if and only if
......@@ -143,7 +155,8 @@ callbacks = [
monitor='val_loss',
verbose=1),
TensorBoard(options.output_path +'/logs/{}'.format(time)),
EarlyStopping(monitor='val_loss', patience=options.early_stopping_patience)
# restore_best_weights reloads the weights from the best epoch (lowest val_loss), so later evaluation uses the best model seen during training rather than the last one.
EarlyStopping(monitor='val_loss', patience=options.early_stopping_patience, restore_best_weights=True)
]
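For orientation, here is a generic sketch of the fit-and-evaluate step such a callback list feeds into; the commit's own call is elided by the diff, so this is not the author's code, only the variable names above reused. With restore_best_weights=True, the evaluate call sees the best epoch's weights rather than the last epoch's:

history = model.fit(train_dataset,
                    epochs=options.num_epochs,
                    validation_data=validation_dataset,
                    callbacks=callbacks)
results = model.evaluate(test_dataset)
print('test loss, test acc:', results)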
print('# Fit model on training data')
......