Commit fb1d9052 authored by sjjsmuel's avatar sjjsmuel

CAM based loss

New training step and loss that additionally penalize class activation in areas outside the mouth
parent cb760532
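
The core idea, condensed into a short sketch (an illustration, not part of the commit; the tensor names follow CAMModel.train_step below): the Grad-CAM heatmap for the caries class is rescaled to [0, 1], everything inside the annotated mouth box is masked out, and the mean of the remaining activation is added to the cross-entropy loss.

import tensorflow as tf

def cam_penalty(heatmap, mouth_filter):
    # heatmap: [batch, H, W] class activation map, rescaled to [0, 1]
    # mouth_filter: [batch, H, W] mask, 0 inside the mouth box and 1 elsewhere
    return tf.reduce_mean(heatmap * mouth_filter)

# total loss per batch:
# loss = categorical_crossentropy(y, y_pred) + cam_penalty(heatmap, mouth_filter)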
@@ -6,6 +6,7 @@ input/test_data/
input/test_data_mini/
input/training_data_small/
input/training_data_mini/
input/mouth_annotations
#output
out/
...
import json
from pathlib import Path


class AnnotationLocationLoader:
    _annotation_file = None
    _img_path = None
    _mouth_annotations_folder = None
    _annotated_images = set()
    _available_annotations = set()
    _available_images = None
    _data = {}

    def __init__(self, annotation_file='input/caries_dataset_annotation.json', images_base_folder=Path('input/test_data/'), mouth_annotations_folder=Path('input/mouth_annotations/')):
        self._annotation_file = annotation_file
        if not isinstance(images_base_folder, Path):
            images_base_folder = Path(images_base_folder)
        self._img_path = images_base_folder
        if not isinstance(mouth_annotations_folder, Path):
            mouth_annotations_folder = Path(mouth_annotations_folder)
        self._mouth_annotations_folder = mouth_annotations_folder

        # get the names of the images which are available as files
        self._available_images = self._get_names_from_available_images()
        self._load_annotations()
        self._load_mouth_annotation_addition()
        self._annotated_images = list(self._annotated_images)
        self._available_annotations = list(self._available_annotations)

    def _get_names_from_available_images(self):
        names_from_available_images = []
        for path in [path for path in self._img_path.iterdir() if path.is_dir()]:
            names_from_available_images.extend([filename.name for filename in path.iterdir() if filename.is_file() and not filename.name.startswith('.')])
        return names_from_available_images
@@ -33,12 +45,14 @@ class AnnotationLocationLoader:
        for picture in json_data:
            picture_filename = picture['External ID']

            if picture_filename not in self._available_images:
                #print('File ”{}” not found.'.format(picture_filename))
                continue

            if picture_filename not in self._data:
                self._data[picture_filename] = []

            # Skip the 'Skip' entries in the file
            if not isinstance(picture['Label'], dict):
                continue
@@ -54,8 +68,33 @@ class AnnotationLocationLoader:
                        y_all.append(point['y'])
                    box_coord = [(min(x_all), min(y_all)), (max(x_all), max(y_all))]
                    self._data[picture_filename].append((annotation_type.lower(), box_coord))
    def _load_mouth_annotation_addition(self):
        annotation_files = [file for file in self._mouth_annotations_folder.iterdir() if file.is_file() and not file.name.startswith('.')]
        counter_number_of_annotated_but_missing_files = 0
        for annotation_file in annotation_files:
            with open(annotation_file) as file:
                mouth_annotation = json.load(file)
            picture_filename = mouth_annotation['asset']['name']
            if picture_filename not in self._available_images:
                #print('File ”{}” not found.'.format(picture_filename))
                counter_number_of_annotated_but_missing_files += 1
                continue
            if picture_filename not in self._data:
                self._data[picture_filename] = []
            self._annotated_images.add(picture_filename)
            bb = mouth_annotation['regions'][0]['boundingBox']
            top_left = (round(bb['left']), round(bb['top']))
            bottom_right = (round(bb['left'] + bb['width']), round(bb['top'] + bb['height']))
            img_width = mouth_annotation['asset']['size']['width']
            img_height = mouth_annotation['asset']['size']['height']
            self._data[picture_filename].append(('mouth', [(img_width, img_height), top_left, bottom_right]))
        print('[INFO] {} mouth annotations were skipped because the corresponding image files were not found in the assigned folder.'.format(counter_number_of_annotated_but_missing_files))
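
For reference, a minimal mouth-annotation file that this loader would accept could look like the following sketch (the structure is inferred from the keys accessed above and resembles a VoTT export; all names and values are made up):

# hypothetical content of one mouth-annotation JSON file, shown as a Python dict
example_annotation = {
    "asset": {"name": "patient_001.jpg", "size": {"width": 1000, "height": 800}},
    "regions": [
        {"boundingBox": {"left": 100.0, "top": 200.0, "width": 800.0, "height": 500.0}}
    ]
}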
    def get_all_types_of_annotations(self):
@@ -83,8 +122,11 @@ class AnnotationLocationLoader:
        """
        Returns a list of annotations for the given image_name,
        e.g. [ ('caries', [(x1,y1), (x2,y2)]), _more_entries_ ]
        Coordinates are given as [(top left), (bottom right)].
        :param filter: a list of strings representing the types of annotations the user wants to retrieve
        """
        if not filter:
            filter = self._available_annotations  # defaults to every annotation type except the mouth annotation
        if self.is_annotated(image_name):
            if filter and len(filter) > 0:
...
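
A minimal usage sketch of the extended loader (file names and paths are placeholders):

annot_loader = AnnotationLocationLoader(
    annotation_file='input/caries_dataset_annotation.json',
    images_base_folder='input/training_data_small/',
    mouth_annotations_folder='input/mouth_annotations/')
print(annot_loader.get_annotations('patient_001.jpg'))             # all types except 'mouth'
print(annot_loader.get_annotations('patient_001.jpg', ['mouth']))  # [('mouth', [(w, h), (x1, y1), (x2, y2)])]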
from tensorflow import GradientTape, cast, reduce_mean, reduce_sum, newaxis, reshape, transpose, squeeze
from tensorflow.image import resize
from tensorflow.math import multiply, reduce_min, reduce_max, divide_no_nan
from tensorflow.keras import Model
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.metrics import CategoricalAccuracy, Mean

'''
Based on the default example from the Keras docs:
https://keras.io/guides/customizing_what_happens_in_fit/
'''

def custom_loss(y, y_pred, cam_loss):
    # standard cross-entropy plus the CAM-based penalty
    return categorical_crossentropy(y, y_pred) + cam_loss

metric_tracker = CategoricalAccuracy()
loss_tracker = Mean(name='loss')
class CAMModel(Model):
    class_index = 0  # hardcoded index of the 'caries' class

    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        img = x['img']
        mouth_filter = x['mouth']

        with GradientTape() as tape:
            y_pred, convOutputs = self(img, training=True)  # Forward pass
            # Take the prediction score of the target class (class_index)
            class_loss = y_pred[:, self.class_index]

        # compute the CAM gradients
        cam_gradients = tape.gradient(class_loss, convOutputs)
        # compute the guided gradients
        castConvOutputs = cast(convOutputs > 0, "float32")
        castGrads = cast(cam_gradients > 0, "float32")
        guidedGrads = castConvOutputs * castGrads * cam_gradients
        # save the shape of the convolution output to reshape later
        conv_shape = convOutputs.shape[1:]
        # compute the average of the gradient values and use them as weights
        weights = reduce_mean(guidedGrads, axis=(1, 2))
        # flatten the batch out into the filter-count dimension
        convOutputs = transpose(convOutputs, [0, 3, 1, 2])
        weights = reshape(weights, [-1])
        convOutputs = reshape(convOutputs, [-1, conv_shape[0], conv_shape[1]])
        convOutputs = transpose(convOutputs, [1, 2, 0])
        cam = multiply(weights, convOutputs)
        # rebatch
        cam = reshape(cam, [conv_shape[0], conv_shape[1], conv_shape[2], -1])
        cam = reduce_sum(cam, axis=-2)
        cam = transpose(cam, [2, 0, 1])
        # add an axis so that tf.image.resize can be used
        cam = cam[..., newaxis]
        # note: resize expects [height, width]; this ordering relies on square inputs
        heatmap = resize(cam, [img.shape[2], img.shape[1]])
        # remove the now unnecessary axis
        heatmap = squeeze(heatmap)
        # rescale the values to the range [0, 1]; divide_no_nan guards against an all-zero heatmap
        numer = heatmap - reduce_min(heatmap)
        denom = reduce_max(heatmap) - reduce_min(heatmap)
        heatmap = divide_no_nan(numer, denom)
        # keep only the activation outside the mouth; its mean becomes the penalty
        heatmap = multiply(heatmap, mouth_filter)
        loss_addition = reduce_mean(heatmap)

        with GradientTape() as tape:
            y_pred, conv_out = self(img, training=True)  # Forward pass
            # Compute the loss value
            loss = custom_loss(y, y_pred, loss_addition)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        loss_tracker.update_state(loss)
        metric_tracker.update_state(y, y_pred)
        return {'loss': loss_tracker.result(), 'accuracy': metric_tracker.result()}
    def test_step(self, data):
        # Unpack the data
        x, y = data
        x = x['img']
        # Compute the predictions and skip the convolution output
        y_pred, _ = self(x, training=False)
        # calculate the loss
        loss = categorical_crossentropy(y, y_pred)
        # Update the metric tracking the loss
        loss_tracker.update_state(loss)
        # Update the accuracy metric.
        metric_tracker.update_state(y, y_pred)
        # Return a dict mapping metric names to current values.
        return {'loss': loss_tracker.result(), 'accuracy': metric_tracker.result()}
\ No newline at end of file
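
A rough sketch of how this model could be assembled and trained end to end (the backbone construction and dataset layout mirror Resnet50.py and DataLoader below; the input shape, class count, and optimizer are placeholder assumptions):

from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input

inp = Input(shape=(224, 224, 3))
backbone = ResNet50(include_top=False, weights='imagenet', input_tensor=inp)
x = GlobalAveragePooling2D()(backbone.output)
out = Dense(2, activation='softmax', name='prediction')(x)
# expose both the prediction and the last convolutional feature map
model = CAMModel(inputs=[inp], outputs=[out, backbone.layers[-1].output])
# run_eagerly avoids graph-mode surprises in the custom train_step (assumption)
model.compile(optimizer='rmsprop', run_eagerly=True)  # the loss is computed inside train_step
# model.fit(dataset, epochs=10)  # dataset yields ({'img': image, 'mouth': mask}, one_hot_label)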
@@ -5,11 +5,6 @@ import numpy as np
import tensorflow as tf

'''
Augmentation functions
@@ -53,7 +48,7 @@ def zoom(x: tf.Tensor, label, size):
class DataLoader(object):
    def __init__(self, data_path, batch_size, img_width, img_height, channels, should_size_dataset1=0, should_size_dataset2=0, split_size=0.0, augment=False, annotation=None):
        self.data_path = data_path
        self.batch_size = batch_size
        self.split_size = split_size
@@ -64,6 +59,7 @@ class DataLoader(object):
        self.CHANNELS = channels
        self.AUGMENT = augment
        self.AUGMENTATIONS = [flip, color, zoom]
        self.annotation = annotation
        self.classes = [item.name for item in data_path.glob('*') if item.name != '.DS_Store']
        self.n_classes = len(self.classes)
@@ -79,20 +75,26 @@ class DataLoader(object):
        self.dataset_1_repeat_factor = 1
        self.dataset_2_repeat_factor = 1

    def _create_dataset(self, filenames_list, mouths_list, labels_list):
        # Creating constants for the dataset
        filenames = tf.constant(filenames_list)
        mouths = tf.constant(mouths_list, dtype='float32')
        labels = tf.constant(labels_list)
        return tf.data.Dataset.from_tensor_slices(({'img': filenames, 'mouth': mouths}, labels))
    def decode_img(self, inputs, one_hot_encodings):
        filepath = inputs['img']
        # load the raw data from the file as a string
        img = tf.io.read_file(filepath)
        # convert the compressed string to a 3D uint8 tensor
        img = tf.image.decode_jpeg(img, channels=self.CHANNELS)
        # Use `convert_image_dtype` to convert to floats in the [0,1] range.
        img = tf.image.convert_image_dtype(img, tf.float32)
        # resize the image to the desired size
        # (note: tf.image.resize expects [height, width]; this ordering relies on square inputs)
        inputs['img'] = tf.image.resize(img, [self.IMG_WIDTH, self.IMG_HEIGHT])
        return inputs, one_hot_encodings
'''
Count the actual number of elements for each class in each dataset.
@@ -112,8 +114,9 @@ class DataLoader(object):
    def _load_data_as_filename_lists(self, data_path):
        filenames = []
        mouth_masks = []
        one_hot_encodings = []
        filenames_of_class = []
        # Length of the list including all files from previous directories (one directory == one class of training data)
        last_length = 0
        for class_index, classname in enumerate(self.classes):
@@ -122,15 +125,39 @@ class DataLoader(object):
            # Add the filenames to the list of all filenames
            filenames.extend(filenames_of_class)
            one_hot_encodings.extend([self._get_one_hot_for_label(classname)] * len(filenames_of_class))

        if self.annotation:
            for file_full_path in filenames:
                filename = file_full_path.split('/')[-1]
                annotations = self.annotation.get_annotations(filename, ['mouth'])
                mask = np.ones(shape=(self.IMG_WIDTH, self.IMG_HEIGHT))
                if len(annotations) > 0:
                    annotations = annotations[0][1]  # if a mouth annotation is present, extract the corners
                    width_orig, height_orig = annotations[0]
                    width_scalar = self.IMG_WIDTH / width_orig
                    height_scalar = self.IMG_HEIGHT / height_orig
                    x1, y1 = annotations[1]
                    x2, y2 = annotations[2]
                    # scale the box corners from the original resolution to the network input resolution
                    x1 = round(x1 * width_scalar)
                    x2 = round(x2 * width_scalar)
                    y1 = round(y1 * height_scalar)
                    y2 = round(y2 * height_scalar)
                    # zero out the mouth region, so only activation outside of it is penalized
                    mask[y1:y2 + 1, x1:x2 + 1] = 0
                mouth_masks.append(mask)
        else:
            mouth_masks = [np.ones(shape=(self.IMG_WIDTH, self.IMG_HEIGHT))] * len(filenames)  # all-ones masks if no mouth annotation loader is provided
        return filenames, mouth_masks, one_hot_encodings
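
To make the scaling above concrete, a small worked example (all numbers made up):

# hypothetical 1000x800 photo, mouth box from (100, 200) to (900, 700), network input 224x224
width_scalar = 224 / 1000    # 0.224
height_scalar = 224 / 800    # 0.28
x1, x2 = round(100 * 0.224), round(900 * 0.224)  # 22, 202
y1, y2 = round(200 * 0.28), round(700 * 0.28)    # 56, 196
# mask[56:197, 22:203] = 0, i.e. activation inside the mouth box is not penalized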
    def _filename_and_labels_split(self, filenames_list, mouths_list, labels_list):
        splitting_element = math.ceil(len(filenames_list) * self.split_size)
        filenames2_list = filenames_list[splitting_element:]
        filenames_list = filenames_list[:splitting_element]
        mouths2_list = mouths_list[splitting_element:]
        mouths_list = mouths_list[:splitting_element]
        labels2_list = labels_list[splitting_element:]
        labels_list = labels_list[:splitting_element]
        return filenames_list, mouths_list, labels_list, filenames2_list, mouths2_list, labels2_list

    def _get_weights(self, dataset_class_elem_count, dataset_size):
        weights = {k: 0 for k in range(self.n_classes)}
@@ -147,7 +174,7 @@ class DataLoader(object):
    def load_dataset(self):
        print('Loading dataset information')
        # Load filenames, mouth masks, and labels as corresponding lists
        filenames_list, mouths_list, labels_list = self._load_data_as_filename_lists(self.data_path)
        # Shuffle the lists (keeping filenames, masks, and labels aligned)
        dataset_list = list(zip(filenames_list, mouths_list, labels_list))
@@ -157,8 +184,8 @@
        # Skip the splitting if it would result in a full copy of the dataset in dataset_1 or dataset_2
        if self.split_size not in [0, 1]:
            filenames_list, mouths_list, labels_list, filenames2_list, mouths2_list, labels2_list = self._filename_and_labels_split(filenames_list, mouths_list, labels_list)
            self.dataset_2 = self._create_dataset(filenames2_list, mouths2_list, labels2_list)
            self.dataset_2 = self.dataset_2.shuffle(buffer_size=(len(filenames2_list)))
            self.dataset_2_size = len(filenames2_list)
            if self.dataset_2_size_min_count != 0:
@@ -166,7 +193,7 @@ class DataLoader(object):
            self.dataset_2_class_elem_count = self.count_class_elements(labels2_list)

        # Creating actual TF Dataset 1
        self.dataset_1 = self._create_dataset(filenames_list, mouths_list, labels_list)
        self.dataset_1 = self.dataset_1.shuffle(buffer_size=(len(filenames_list)))
        self.dataset_1_size = len(filenames_list)
        if self.dataset_1_size_min_count != 0:
@@ -176,7 +203,7 @@ class DataLoader(object):
        # Load images
        print('Loading images of dataset into memory.')
        self.dataset_1 = self.dataset_1.map(self.decode_img, num_parallel_calls=self.NR_THREADS)
        self.dataset_1 = self.dataset_1.cache()
        if self.AUGMENT:
            for f in self.AUGMENTATIONS:
@@ -184,7 +211,7 @@ class DataLoader(object):
                                                 num_parallel_calls=self.NR_THREADS)
            self.dataset_1 = self.dataset_1.map(lambda x, y: (tf.clip_by_value(x, 0, 1), y))
        if self.dataset_2:
            self.dataset_2 = self.dataset_2.map(self.decode_img, num_parallel_calls=self.NR_THREADS)
            self.dataset_2 = self.dataset_2.cache()
            for f in self.AUGMENTATIONS:
                self.dataset_2 = self.dataset_2.map(lambda x, y: tf.cond(tf.random.uniform([], 0, 1) > 0.1, lambda: f(x, y, (self.IMG_WIDTH, self.IMG_HEIGHT)),
...
@@ -36,11 +36,11 @@ class Resnet50(NetworkBase):
        x = GlobalAveragePooling2D()(base_model.output)
        x = Dropout(0.4)(x)
        x = Dense(128)(x)
        x = Dropout(0.2)(x)
        out = Dense(self.NUM_CLASSES, activation='softmax', name='prediction')(x)

        # expose the last convolutional feature map as a second output for the CAM loss
        model = CAMModel(inputs=[input_tensor], outputs=[out, base_model.layers[-1].output])

        for layer in model.layers[154:]:
            layer.trainable = True
...
@@ -2,12 +2,14 @@ from datetime import datetime
from optparse import OptionParser
from PIL import ImageFile

from helpers.AnnotationLocationLoader import AnnotationLocationLoader
from network.Resnet50 import Resnet50
from network.VGG_16 import VGG_16
from helpers.DataLoader import DataLoader
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard, EarlyStopping
from tensorflow.keras.optimizers import RMSprop, SGD
from tensorflow import version as tfv
import pathlib
@@ -19,6 +21,7 @@ parser = OptionParser()
parser.add_option("-p", "--path_train", dest="train_path", help="Path to training input.", default="./input/train_data")
parser.add_option("-t", "--path_test", dest="test_path", help="Path to test input.", default="./input/test_data")
parser.add_option("-v", "--path_validation", dest="validation_path", help="Path to validation input.", default="./input/validation_data")
parser.add_option("--mouth_annotation", dest="mouth_annotation_path", help="Path to the folder containing the mouth annotation files.", default='./input/mouth_annotations/')
parser.add_option("--train_size", type="int", dest="train_size", default=200)
parser.add_option("--validation_size", type="int", dest="validation_size", default=200)
parser.add_option("-o", "--path_output", dest="output_path", help="Path to base folder for output data.", default='./out')
@@ -39,10 +42,15 @@ if not options.train_path: # if folder name is not given
    parser.error('Error: path to training input must be specified. Pass --path_train to command line')
if not options.test_path:  # if folder name is not given
    parser.error('Error: path to test input must be specified. Pass --path_test to command line')
if not options.mouth_annotation_path:  # if folder name is not given
    parser.error('Error: path to mouth annotations must be specified. Pass --mouth_annotation to command line')

def get_curr_time():
    return datetime.now().strftime("%Y.%m.%d.%H.%M.%S")

print('TensorFlow version {}'.format(tfv.VERSION))

## Arguments and Settings
train_dir = pathlib.Path(options.train_path)
@@ -64,6 +72,8 @@ batch_size = options.batch_size
min_size_train_dataset = options.train_size
min_size_validation_dataset = options.validation_size

annot_loader = AnnotationLocationLoader(images_base_folder=train_dir, mouth_annotations_folder=options.mouth_annotation_path)

# Load the dataset into TF Datasets
# Training Data
train_loader = DataLoader(data_path=train_dir,
@@ -72,7 +82,9 @@ train_loader = DataLoader(data_path=train_dir,
                          img_width=img_width,
                          img_height=img_height,
                          channels=channels,
                          # TODO revert augment to True after testing; the augmentation
                          # functions still expect a bare image tensor, not the {'img', 'mouth'} dict
                          augment=False,
                          annotation=annot_loader
                          )