train_classifier.py 6.98 KB
Newer Older
sjjsmuel's avatar
sjjsmuel committed
1
from datetime import datetime
sjjsmuel's avatar
sjjsmuel committed
2
from optparse import OptionParser
sjjsmuel's avatar
sjjsmuel committed
3

sjjsmuel's avatar
sjjsmuel committed
4
from PIL import ImageFile
sjjsmuel's avatar
sjjsmuel committed
5 6

from helpers.AnnotationLocationLoader import AnnotationLocationLoader
sjjsmuel's avatar
sjjsmuel committed
7
from network.Resnet50 import Resnet50
sjjsmuel's avatar
try vgg  
sjjsmuel committed
8
from network.VGG_16 import VGG_16
sjjsmuel's avatar
sjjsmuel committed
9
from helpers.DataLoader import DataLoader
sjjsmuel's avatar
sjjsmuel committed
10
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
sjjsmuel's avatar
sjjsmuel committed
11
from tensorflow.keras.optimizers import RMSprop, SGD
sjjsmuel's avatar
sjjsmuel committed
12
from tensorflow import version as tfv
sjjsmuel's avatar
sjjsmuel committed
13

sjjsmuel's avatar
sjjsmuel committed
14 15
import pathlib

sjjsmuel's avatar
sjjsmuel committed
16 17 18

# Let PIL load image files that are truncated instead of raising an error while decoding them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Command line options. Every path option has a default, so the "must be specified" checks after
# parse_args() only trigger when a path is explicitly passed as an empty string.
parser = OptionParser()

parser.add_option("-p", "--path_train", dest="train_path", help="Path to training input.", default="./input/train_data")
parser.add_option("-t", "--path_test", dest="test_path", help="Path to test input.", default="./input/test_data")
parser.add_option("-v", "--path_validation", dest="validation_path", help="Path to validation input.", default="./input/validation_data")
parser.add_option("--mouth_annotation", dest="mouth_annotation_path", help="Path to folder containing the mouth annotation files.", default='./input/mouth_annotations/')
parser.add_option("--train_size", type="int", dest="train_size", default=200)
parser.add_option("--validation_size", type="int", dest="validation_size", default=200)
parser.add_option("-o", "--path_output", dest="output_path", help="Path to base folder for output data.", default='./out')
parser.add_option("--base_network_file", dest="base_net_file", help="Optional path to a local file of the Resnet50 base network for TF without top.")
parser.add_option("--num_epochs", type="int", dest="num_epochs", help="Number of epochs.", default=100)
# NOTE(review): options.num_epochs_pre_train is currently not read anywhere in this script.
parser.add_option("--num_epochs_pre_train", type="int", dest="num_epochs_pre_train", help="Number of epochs for the basic training of the FC layers.", default=50)
parser.add_option("--batch_size", type="int", dest="batch_size", help="Size of batches.", default=10)
parser.add_option("--patience", type="int", dest="early_stopping_patience", help="Number of epochs without improvement of the validation loss before training is stopped early.", default=5)
# NOTE(review): options.split_size is currently unused; the validation DataLoader no longer receives it.
parser.add_option("--split_size", type="float", dest="split_size", help="Proportion of validation examples out of all test examples. Set as value between 0 and 1.", default=0.5)
parser.add_option("--width", type="int", dest="width", default=224)
parser.add_option("--height", type="int", dest="height", default=224)

(options, args) = parser.parse_args()

# parser.error() prints the usage line plus the message (optparse adds its own "error:" prefix)
# and terminates the program, so the messages below carry no extra "Error:" prefix.
if not options.train_path:
    parser.error('path to training input must be specified. Pass --path_train to command line')
if not options.test_path:
    parser.error('path to test input must be specified. Pass --path_test to command line')
# The validation folder is read just like the training and test folders, so it is checked the same way.
if not options.validation_path:
    parser.error('path to validation input must be specified. Pass --path_validation to command line')
if not options.mouth_annotation_path:
    parser.error('path to mouth annotations must be specified. Pass --mouth_annotation to command line')
sjjsmuel's avatar
sjjsmuel committed
47

sjjsmuel's avatar
sjjsmuel committed
48 49 50
def get_curr_time():
    """Return the current local time formatted as 'YYYY.MM.DD.HH.MM.SS'.

    The zero-padded, most-significant-first layout makes the strings sort
    chronologically; the script uses it to name per-run output folders.
    """
    timestamp_format = '{:%Y.%m.%d.%H.%M.%S}'
    return timestamp_format.format(datetime.now())

sjjsmuel's avatar
sjjsmuel committed
51 52 53

print('Tensorflow Version {}'.format(tfv.VERSION))

## Arguments and Settings

train_dir = pathlib.Path(options.train_path)
test_dir = pathlib.Path(options.test_path)
validation_dir = pathlib.Path(options.validation_path)

# One timestamp per run. It names the checkpoint folder here and is also read when building the
# TensorBoard log folder. The name shadows the stdlib `time` module, but it is kept because later
# code refers to it.
time = get_curr_time()
checkpoint_path = pathlib.Path(options.output_path) / 'checkpoints' / time
# parents=True: on a fresh output folder "<output>/checkpoints" does not exist yet, and a plain
# mkdir() would raise FileNotFoundError. exist_ok=True replaces the separate exists() check.
checkpoint_path.mkdir(parents=True, exist_ok=True)

# Optional local file for the base network; None when --base_network_file is absent or empty.
base_model_file = options.base_net_file or None

img_width = options.width
img_height = options.height
channels = 3
n_classes = 2

batch_size = options.batch_size

min_size_train_dataset = options.train_size
min_size_validation_dataset = options.validation_size

# Built on the training image folder only.
annot_loader = AnnotationLocationLoader(images_base_folder=train_dir, mouth_annotations_folder=options.mouth_annotation_path)

sjjsmuel's avatar
sjjsmuel committed
77 78 79 80
def _load_split(data_path, **split_kwargs):
    """Create a DataLoader for one data split and load its dataset.

    The batch size, image size and channel count are shared by every split;
    split-specific arguments are forwarded through ``split_kwargs``.
    Returns a ``(loader, dataset)`` pair.
    """
    loader = DataLoader(data_path=data_path,
                        batch_size=options.batch_size,
                        img_width=img_width,
                        img_height=img_height,
                        channels=channels,
                        **split_kwargs)
    return loader, loader.load_dataset()


# Load the dataset into TF Datasets.
# Training data: the only split that receives the mouth annotation loader.
# TODO reverse augment to True after testing
train_loader, train_dataset = _load_split(train_dir,
                                          should_size_dataset1=min_size_train_dataset,
                                          augment=False,
                                          annotation=annot_loader)

# Validation data.
# NOTE(review): this loader used to receive split_size=options.split_size; that argument is disabled.
# TODO reverse augment to True after testing
validation_loader, validation_dataset = _load_split(validation_dir,
                                                    should_size_dataset1=min_size_validation_dataset,
                                                    augment=False)

# Test data: only the shared settings, everything else is left to DataLoader's defaults.
test_loader, test_dataset = _load_split(test_dir)
sjjsmuel's avatar
sjjsmuel committed
112

113 114

# Create Network
network = Resnet50(n_classes, img_width, img_height, channels, base_model_file)
model = network.get_model()

# Compile the model.
# NOTE(review): neither `loss` nor `metrics` is passed here. Training can only work if the model
# returned by Resnet50.get_model() provides its own loss (e.g. add_loss / a custom train_step) — confirm.
model.compile(optimizer=RMSprop())

# Print Network summary
# model.summary()

callbacks = [
    # Saves the model only when `val_loss` improves. The epoch number and the loss value are part
    # of the file name, so every improvement is written to a new file and earlier best
    # checkpoints are kept rather than overwritten.
    ModelCheckpoint(
        filepath=str(checkpoint_path / 'model.{epoch:03d}-{val_loss:.3f}.hdf5'),
        save_best_only=True,
        monitor='val_loss',
        verbose=1),
    # TensorBoard logs go to <output>/logs/<timestamp>, with the same timestamp as the checkpoint folder.
    TensorBoard(str(pathlib.Path(options.output_path) / 'logs' / time)),
    # Stops after `--patience` epochs without val_loss improvement. Restoring the best weights means
    # the later evaluation uses the best model seen during training, not the last one.
    EarlyStopping(monitor='val_loss', patience=options.early_stopping_patience, restore_best_weights=True),
]
138

sjjsmuel's avatar
sjjsmuel committed
139
print('# Fit model on training data')
history = model.fit(train_dataset,
                    epochs=options.num_epochs,
                    validation_data=validation_dataset,
                    callbacks=callbacks,
                    verbose=2)

print('\nHistory dict:', history.history)

# Evaluate the model on the test data using `evaluate`.
print('\n# Evaluate on test data')
results = model.evaluate(test_dataset)
# The compile() call in this script registers no metrics, so `results` is not guaranteed to
# hold an accuracy value. The values are therefore labelled with the model's own metric names
# rather than with a hard-coded "loss, acc".
print('test results ({}):'.format(', '.join(model.metrics_names)), results)