Commit 9bcf8e78 authored by Jonas Müller's avatar Jonas Müller

save best epoch

parent 1286f41b
......@@ -29,7 +29,7 @@ class Config:
self.img_scaling_factor = 1.0
# number of ROIs at once
self.num_rois = 4
self.num_rois = 8
# stride at the RPN (this depends on the network configuration)
self.rpn_stride = 16
......
import keras_frcnn.roi_helpers as roi_helpers
from keras.utils import generic_utils
from keras import backend as K
import numpy as np
import random
#def get_validation_lossv2(data_gen_val, epoch_length, model_rpn, model_classifier, model_classifier_only, C, class_mapping_inv, class_to_color, writer_tensorboard=None, num_epoch=0):
def get_validation_loss(data_gen_val, epoch_length, model_rpn, model_classifier, C, class_mapping_inv, writer_tensorboard=None, num_epoch=0):
'''
compute loss on validation data. Can also print images on tensorboard with the boxes predicted.
threshold : bbox, rpn, classifier
'''
losses = np.zeros((epoch_length, 5))
rpn_accuracy_rpn_monitor = []
rpn_accuracy_for_epoch = []
threshold = C.threshold
class_mapping = C.class_mapping
iter_num = 0
progbar = generic_utils.Progbar(epoch_length)
print('Validating')
'''try:
load_model_weights(model_classifier, model_classifier_only)
except Exception as e:
print('Exception Validation: iter num {}, {}'.format(iter_num, e))
PrintException()
exit()'''
for epoch_num in range(0, epoch_length):
try:
#print('begin try, iter num : {} , epoch num {}'.format(iter_num, epoch_num))
if len(rpn_accuracy_rpn_monitor) == epoch_length and C.verbose:
mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor))/len(rpn_accuracy_rpn_monitor)
rpn_accuracy_rpn_monitor = []
print('Average number of overlapping bounding boxes from RPN = {} for {} previous iterations (validation)'.format(mean_overlapping_bboxes, epoch_length))
if mean_overlapping_bboxes == 0:
print('RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')
X, Y, img_data = next(data_gen_val)
loss_rpn = model_rpn.evaluate(X, Y, verbose=0)
P_rpn = model_rpn.predict_on_batch(X)
R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], C, K.image_dim_ordering(), use_regr=True, overlap_thresh=threshold[1], max_boxes=300)
# note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
X2, Y1, Y2, IouS = roi_helpers.calc_iou(R, img_data, C, class_mapping)
if X2 is None:
rpn_accuracy_rpn_monitor.append(0)
rpn_accuracy_for_epoch.append(0)
iter_num +=1
continue
neg_samples = np.where(Y1[0, :, -1] == 1)
pos_samples = np.where(Y1[0, :, -1] == 0)
if len(neg_samples) > 0:
neg_samples = neg_samples[0]
else:
neg_samples = []
if len(pos_samples) > 0:
pos_samples = pos_samples[0]
else:
pos_samples = []
rpn_accuracy_rpn_monitor.append(len(pos_samples))
rpn_accuracy_for_epoch.append((len(pos_samples)))
if C.num_rois > 1:
if len(pos_samples) < C.num_rois//2:
selected_pos_samples = pos_samples.tolist()
else:
selected_pos_samples = np.random.choice(pos_samples, C.num_rois//2, replace=False).tolist()
try:
selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=False).tolist()
except:
selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=True).tolist()
sel_samples = selected_pos_samples + selected_neg_samples
else:
# in the extreme case where num_rois = 1, we pick a random pos or neg sample
selected_pos_samples = pos_samples.tolist()
selected_neg_samples = neg_samples.tolist()
if np.random.randint(0, 2):
sel_samples = random.choice(neg_samples)
else:
sel_samples = random.choice(pos_samples)
loss_class = model_classifier.evaluate([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]], verbose=0)
losses[iter_num, 0] = loss_rpn[1]
losses[iter_num, 1] = loss_rpn[2]
losses[iter_num, 2] = loss_class[1]
losses[iter_num, 3] = loss_class[2]
losses[iter_num, 4] = loss_class[3]
iter_num += 1
progbar.update(iter_num,
[('val_rpn_cls', np.mean(losses[:iter_num, 0])), ('val_rpn_regr', np.mean(losses[:iter_num, 1])),
('val_detector_cls', np.mean(losses[:iter_num, 2])), ('val_detector_regr', np.mean(losses[:iter_num, 3]))]
)
#print('end try before if, iter num : {} , epoch num {}'.format(iter_num, epoch_num))
if iter_num == epoch_length or iter_num == epoch_length-1:
loss_rpn_cls = np.mean(losses[:, 0])
loss_rpn_regr = np.mean(losses[:, 1])
loss_class_cls = np.mean(losses[:, 2])
loss_class_regr = np.mean(losses[:, 3])
class_acc = np.mean(losses[:, 4])
curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
'''if iter_num - 1 < C.tensorboard_images:
img, all_dets = predict_on_image(np.copy(X), model_rpn, model_classifier_only, C, class_mapping_inv, class_to_color, bbox_threshold = threshold[0], overlap_thresh_rpn = threshold[1], overlap_thresh_classifier = threshold[2], tensorboard=True)
TensorboardImage(writer_tensorboard, img, iter_num - 1, num_epoch)'''
#print('end try end if, iter num : {} , epoch num {}'.format(iter_num, epoch_num))
except Exception as e:
print('Exception Validation: iter num {}, {}'.format(iter_num, e))
#PrintException()
continue
#print('end validation loss loop, iter num' + str(iter_num))
return {'loss_rpn_cls': loss_rpn_cls, 'loss_rpn_regr': loss_rpn_regr, 'loss_class_cls': loss_class_cls,
'loss_class_regr': loss_class_regr, 'class_acc': class_acc, 'curr_loss': curr_loss,
'mean_overlapping_bboxes': mean_overlapping_bboxes}
......@@ -18,6 +18,9 @@ from keras_frcnn import losses as losses
import keras_frcnn.roi_helpers as roi_helpers
from keras.utils import generic_utils
from keras_frcnn.get_validation_loss import get_validation_loss
def get_roi_samples(pos_samples, neg_samples):
if C.num_rois > 1:
if len(pos_samples) < C.num_rois // 2:
......@@ -63,6 +66,10 @@ parser.add_option("--config_filename", dest="config_filename", help=
default="config.pickle")
parser.add_option("--output_weight_path", dest="output_weight_path", help="Output path for weights.", default='./model_frcnn.hdf5')
parser.add_option("--input_weight_path", dest="input_weight_path", help="Input path for weights. If not specified, will try to load default weights provided by keras.")
parser.add_option("-b", "--bbox_threshold", dest="bbox_threshold", help="bbox_threshold", default=0.8)
parser.add_option("-r", "--overlap_threshold_rpn", dest="overlap_threshold_rpn", help="overlap_threshold_rpn", default=0.7)
parser.add_option("-c", "--overlap_threshold_classifier", dest="overlap_threshold_classifier", help="overlap_thresh_classifier", default=0.5)
parser.add_option("--output_log", dest="out_log", help="Output path for logs.", default='./log.txt')
(options, args) = parser.parse_args()
......@@ -85,6 +92,7 @@ C.rot_90 = bool(options.rot_90)
C.model_path = options.output_weight_path
C.num_rois = int(options.num_rois)
C.threshold = [float(options.bbox_threshold), float(options.overlap_threshold_rpn), float(options.overlap_threshold_classifier)]
if options.network == 'vgg':
C.network = 'vgg'
......@@ -189,18 +197,20 @@ num_epochs = int(options.num_epochs)
iter_num = 0
losses = np.zeros((epoch_length, 5))
losses_val = np.zeros((len(val_imgs),5))
#losses_val = np.zeros((len(val_imgs),5))
rpn_accuracy_rpn_monitor = []
rpn_accuracy_for_epoch = []
start_time = time.time()
best_loss = np.Inf
best_loss_epoch_num = -1
epochs_without_improvement_before_early_termination = 5
class_mapping_inv = {v: k for k, v in class_mapping.items()}
print('Starting training')
vis = True
log = open(options.out_log, "w")
for epoch_num in range(num_epochs):
progbar = None
......@@ -268,35 +278,19 @@ for epoch_num in range(num_epochs):
# End of epoch
if iter_num == epoch_length:
for idx, data in enumerate(data_gen_val):
X = copy.deepcopy(data[0])
Y = copy.deepcopy(data[1])
img_data = copy.deepcopy(data[2])
loss_rpn = model_rpn.evaluate(X, Y, verbose=0)
P_rpn = model_rpn.predict_on_batch(X)
R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], C, K.image_dim_ordering(), use_regr=True, overlap_thresh=0.7, max_boxes=300)
# note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
X2, Y1, Y2, IouS = roi_helpers.calc_iou(R, img_data, C, class_mapping)
#loss_class = model_classifier.test_on_batch([X, X2], [Y1, Y2])
loss_class = model_classifier.evaluate([X, X2[:, :, :]], [Y1[:, :, :], Y2[:, :, :]], verbose=0)
losses_val[idx, 0] = loss_rpn[1]
losses_val[idx, 1] = loss_rpn[2]
losses_val[idx, 2] = loss_class[1]
losses_val[idx, 3] = loss_class[2]
losses_val[idx, 4] = loss_class[3]
loss_rpn_cls = np.mean(losses_val[:, 0])
loss_rpn_regr = np.mean(losses_val[:, 1])
loss_class_cls = np.mean(losses_val[:, 2])
loss_class_regr = np.mean(losses_val[:, 3])
class_acc = np.mean(losses_val[:, 4])
mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
val_losses = get_validation_loss(data_gen_val, len(val_imgs),
model_rpn, model_classifier, C,
class_mapping_inv,
num_epoch=epoch_num)
#print(val_losses)
loss_rpn_cls = val_losses.get('loss_rpn_cls')#np.mean(losses[:, 0])
loss_rpn_regr = val_losses.get('loss_rpn_regr')#np.mean(losses[:, 1])
loss_class_cls = val_losses.get('loss_class_cls') #np.mean(losses[:, 2])
loss_class_regr = val_losses.get('loss_class_regr') #np.mean(losses[:, 3])
class_acc = val_losses.get('class_acc')#np.mean(losses[:, 4])
mean_overlapping_bboxes = val_losses.get('mean_overlapping_bboxes')#float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
rpn_accuracy_for_epoch = []
if C.verbose:
......@@ -308,20 +302,39 @@ for epoch_num in range(num_epochs):
print('Loss Detector regression: {}'.format(loss_class_regr))
print('Elapsed time: {}'.format(time.time() - start_time))
curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
curr_loss = val_losses.get('curr_loss')# loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
iter_num = 0
start_time = time.time()
if curr_loss < best_loss:
out = 'Total loss decreased from {} to {}, saving weights'.format(best_loss, curr_loss)
if C.verbose:
print('Total loss decreased from {} to {}, saving weights'.format(best_loss,curr_loss))
print(out)
log.write("---\nEpoch "+str(epoch_num)+"\n")
log.write(out + "\n")
best_loss = curr_loss
best_loss_epoch_num = epoch_num
model_all.save_weights(C.model_path)
'''else:
if best_loss_epoch_num + epochs_without_improvement_before_early_termination <= epoch_num:
print('Early terminated at epoch {}.'.format(epoch_num))
break
'''
#save last 3 epochs
for i in range(3):
t_minus_i = epoch_num - i
if t_minus_i >= 0:
path = C.model_path[:-5] + '_last_minus_' + str(i) + '.hdf5'
model_all.save_weights(path)
break
except Exception as e:
print('Exception: {}'.format(e))
continue
log.close()
print('Training complete, exiting.')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment