Commit 1286f41b authored by Jonas Müller's avatar Jonas Müller

try add validation loss

parent ca550eae
from __future__ import division
import copy
import random
import pprint
import sys
......@@ -16,6 +18,30 @@ from keras_frcnn import losses as losses
import keras_frcnn.roi_helpers as roi_helpers
from keras.utils import generic_utils
def get_roi_samples(pos_samples, neg_samples):
if C.num_rois > 1:
if len(pos_samples) < C.num_rois // 2:
selected_pos_samples = pos_samples.tolist()
else:
selected_pos_samples = np.random.choice(pos_samples, C.num_rois // 2, replace=False).tolist()
try:
selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples),
replace=False).tolist()
except:
selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples),
replace=True).tolist()
sel_samples = selected_pos_samples + selected_neg_samples
else:
# in the extreme case where num_rois = 1, we pick a random pos or neg sample
selected_pos_samples = pos_samples.tolist()
selected_neg_samples = neg_samples.tolist()
if np.random.randint(0, 2):
sel_samples = random.choice(neg_samples)
else:
sel_samples = random.choice(pos_samples)
return sel_samples
sys.setrecursionlimit(40000)
parser = OptionParser()
......@@ -163,6 +189,7 @@ num_epochs = int(options.num_epochs)
iter_num = 0
losses = np.zeros((epoch_length, 5))
losses_val = np.zeros((len(val_imgs),5))
rpn_accuracy_rpn_monitor = []
rpn_accuracy_for_epoch = []
start_time = time.time()
......@@ -178,7 +205,7 @@ for epoch_num in range(num_epochs):
progbar = None
if not options.remote:
progbar = generic_utils.Progbar(epoch_length)
progbar = generic_utils.Progbar(epoch_length)
print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))
while True:
......@@ -222,25 +249,7 @@ for epoch_num in range(num_epochs):
rpn_accuracy_rpn_monitor.append(len(pos_samples))
rpn_accuracy_for_epoch.append((len(pos_samples)))
if C.num_rois > 1:
if len(pos_samples) < C.num_rois//2:
selected_pos_samples = pos_samples.tolist()
else:
selected_pos_samples = np.random.choice(pos_samples, C.num_rois//2, replace=False).tolist()
try:
selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=False).tolist()
except:
selected_neg_samples = np.random.choice(neg_samples, C.num_rois - len(selected_pos_samples), replace=True).tolist()
sel_samples = selected_pos_samples + selected_neg_samples
else:
# in the extreme case where num_rois = 1, we pick a random pos or neg sample
selected_pos_samples = pos_samples.tolist()
selected_neg_samples = neg_samples.tolist()
if np.random.randint(0, 2):
sel_samples = random.choice(neg_samples)
else:
sel_samples = random.choice(pos_samples)
sel_samples = get_roi_samples(pos_samples, neg_samples)
loss_class = model_classifier.train_on_batch([X, X2[:, sel_samples, :]], [Y1[:, sel_samples, :], Y2[:, sel_samples, :]])
......@@ -256,13 +265,36 @@ for epoch_num in range(num_epochs):
('detector_cls', losses[iter_num, 2]), ('detector_regr', losses[iter_num, 3])])
iter_num += 1
# End of epoch
if iter_num == epoch_length:
loss_rpn_cls = np.mean(losses[:, 0])
loss_rpn_regr = np.mean(losses[:, 1])
loss_class_cls = np.mean(losses[:, 2])
loss_class_regr = np.mean(losses[:, 3])
class_acc = np.mean(losses[:, 4])
for idx, data in enumerate(data_gen_val):
X = copy.deepcopy(data[0])
Y = copy.deepcopy(data[1])
img_data = copy.deepcopy(data[2])
loss_rpn = model_rpn.evaluate(X, Y, verbose=0)
P_rpn = model_rpn.predict_on_batch(X)
R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], C, K.image_dim_ordering(), use_regr=True, overlap_thresh=0.7, max_boxes=300)
# note: calc_iou converts from (x1,y1,x2,y2) to (x,y,w,h) format
X2, Y1, Y2, IouS = roi_helpers.calc_iou(R, img_data, C, class_mapping)
#loss_class = model_classifier.test_on_batch([X, X2], [Y1, Y2])
loss_class = model_classifier.evaluate([X, X2[:, :, :]], [Y1[:, :, :], Y2[:, :, :]], verbose=0)
losses_val[idx, 0] = loss_rpn[1]
losses_val[idx, 1] = loss_rpn[2]
losses_val[idx, 2] = loss_class[1]
losses_val[idx, 3] = loss_class[2]
losses_val[idx, 4] = loss_class[3]
loss_rpn_cls = np.mean(losses_val[:, 0])
loss_rpn_regr = np.mean(losses_val[:, 1])
loss_class_cls = np.mean(losses_val[:, 2])
loss_class_regr = np.mean(losses_val[:, 3])
class_acc = np.mean(losses_val[:, 4])
mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
rpn_accuracy_for_epoch = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment