Commit d9754e15 authored by sjjsmuel's avatar sjjsmuel

resnet changes; option to disable cam training

parent e3c955f5
......@@ -33,9 +33,14 @@ class Resnet50(NetworkBase):
if shouldSave:'input/resnet_50_base_model.h5')
#Set the base_network untrainable, except the last Conv Block
for layer in base_model.layers[:-36]:
layer.trainable = False
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(self.NUM_CLASSES)(x)
x = Dropout(0.4)(x)
x = Dense(128)(x)
x = Dropout(0.2)(x)
out = Dense(self.NUM_CLASSES, activation='softmax', name='prediction')(x)
if cam:
......@@ -33,6 +33,7 @@ parser.add_option("--patience", type="int", dest="early_stopping_patience", help
parser.add_option("--split_size", type="float", dest="split_size", help="Proportion of validation examples out of all test examples. Set as value between 0 and 1.", default=0.5)
parser.add_option("--width", type="int", dest="width", default=224)
parser.add_option("--height", type="int", dest="height", default=224)
parser.add_option("--no_cam_training", action="store_false", dest="cam_training", default=True)
# parser.add_option("--output_weight_path", dest="output_weight_path", help="Output path for weights.", default='./model_frcnn.hdf5')
# parser.add_option("--input_weight_path", dest="input_weight_path", help="Input path for weights. If not specified, will try to load default weights provided by keras.")
......@@ -111,7 +112,7 @@ test_dataset = test_loader.load_dataset()
# Create Network
network = Resnet50(n_classes, img_width, img_height, channels, base_model_file)
model = network.get_model(cam=False)
model = network.get_model(cam=options.cam_training)
#compile the model
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment