Commit d8a97133 authored by sjjsmuel

add cache and sizing for augmentation

parent e3bc6ed6
# Data dir
data/training_data/
data/test_data/
data/test_data_mini/
data/training_data_small/
data/training_data_mini/
@@ -53,7 +53,7 @@ def zoom(x: tf.Tensor, label, size):
 class DataLoader(object):

-    def __init__(self, data_path, batch_size, img_width, img_height, channels, batch_count_dataset1, batch_count_dataset2=0, split_size=0.0, augment=False):
+    def __init__(self, data_path, batch_size, img_width, img_height, channels, should_size_dataset1=0, should_size_dataset2=0, split_size=0.0, augment=False):
         self.data_path = data_path
         self.batch_size = batch_size
         self.split_size = split_size
@@ -74,8 +74,11 @@ class DataLoader(object):
         self.dataset_2_size = -1
         self.dataset_1_class_elem_count = {}
         self.dataset_2_class_elem_count = {}
-        self.dataset_1_batch_count = batch_count_dataset1
-        self.dataset_2_batch_count = batch_count_dataset2
+        self.dataset_1_size_min_count = should_size_dataset1
+        self.dataset_2_size_min_count = should_size_dataset2
+        self.dataset_1_repeat_factor = 1
+        self.dataset_2_repeat_factor = 1

     def decode_img(self, img):
         # convert the compressed string to a 3D uint8 tensor
@@ -158,18 +161,23 @@ class DataLoader(object):
             self.dataset_2 = _create_dataset(filenames2_list, labels2_list)
             self.dataset_2 = self.dataset_2.shuffle(buffer_size=len(filenames2_list))
             self.dataset_2_size = len(filenames2_list)
+            if self.dataset_2_size_min_count != 0:
+                self.dataset_2_repeat_factor = math.ceil(self.dataset_2_size_min_count / self.dataset_2_size)
             self.dataset_2_class_elem_count = self.count_class_elements(labels2_list)

         # Creating actual TF Dataset 1
         self.dataset_1 = _create_dataset(filenames_list, labels_list)
         self.dataset_1 = self.dataset_1.shuffle(buffer_size=len(filenames_list))
         self.dataset_1_size = len(filenames_list)
+        if self.dataset_1_size_min_count != 0:
+            self.dataset_1_repeat_factor = math.ceil(self.dataset_1_size_min_count / self.dataset_1_size)
         self.dataset_1_class_elem_count = self.count_class_elements(labels_list)

         # Load images
         print('Loading images of dataset into memory.')
         self.dataset_1 = self.dataset_1.map(self.process_path, num_parallel_calls=self.NR_THREADS)
+        self.dataset_1 = self.dataset_1.cache()
         if self.AUGMENT:
             for f in self.AUGMENTATIONS:
                 self.dataset_1 = self.dataset_1.map(lambda x, y: tf.cond(tf.random.uniform([], 0, 1) > 0.6, lambda: f(x, y, (self.IMG_WIDTH, self.IMG_HEIGHT)), lambda: (x, y)),
@@ -177,15 +185,16 @@ class DataLoader(object):
             self.dataset_1 = self.dataset_1.map(lambda x, y: (tf.clip_by_value(x, 0, 1), y))

         if self.dataset_2:
             self.dataset_2 = self.dataset_2.map(self.process_path, num_parallel_calls=self.NR_THREADS)
+            self.dataset_2 = self.dataset_2.cache()
             for f in self.AUGMENTATIONS:
                 self.dataset_2 = self.dataset_2.map(lambda x, y: tf.cond(tf.random.uniform([], 0, 1) > 0.6, lambda: f(x, y, (self.IMG_WIDTH, self.IMG_HEIGHT)),
                                                                          lambda: (x, y)), num_parallel_calls=self.NR_THREADS)
             self.dataset_2 = self.dataset_2.map(lambda x, y: (tf.clip_by_value(x, 0, 1), y))

-        self.dataset_1 = self.dataset_1.repeat(self.batch_size * self.dataset_1_batch_count).batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
+        self.dataset_1 = self.dataset_1.repeat(self.dataset_1_repeat_factor).batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
         if self.dataset_2:
-            self.dataset_2 = self.dataset_2.repeat(self.batch_size * self.dataset_2_batch_count).batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
+            self.dataset_2 = self.dataset_2.repeat(self.dataset_2_repeat_factor).batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
             return self.dataset_1, self.dataset_2
         else:
             return self.dataset_1
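
Taken together, the loader changes do two things. The new `.cache()` calls sit after `process_path` but before the augmentation maps, so each image is decoded from disk only once, while `tf.cond(tf.random.uniform([], 0, 1) > 0.6, ...)` still draws a fresh random number every time an element is read; the augmentation is therefore not frozen into the cache. And the `should_size_*` parameters replace the old `repeat(self.batch_size * self.dataset_1_batch_count)` call, which repeated the entire dataset roughly batch_size * batch_count times, with a repeat factor just large enough to reach a requested minimum number of examples (note this relies on `math` already being imported in the module; the hunks shown do not add the import). A minimal standalone sketch of that sizing logic, with an illustrative helper name and toy data that are not part of the commit:

import math
import tensorflow as tf

def upsample_to_min_size(dataset: tf.data.Dataset, actual_size: int, min_size: int) -> tf.data.Dataset:
    # Repeat the dataset just often enough to yield at least `min_size`
    # examples; min_size == 0 leaves it unchanged (factor 1), mirroring
    # the `!= 0` guard in the commit.
    repeat_factor = math.ceil(min_size / actual_size) if min_size else 1
    return dataset.repeat(repeat_factor)

# Toy check: 3 elements with a minimum size of 8 -> ceil(8/3) = 3 repeats -> 9 elements.
ds = upsample_to_min_size(tf.data.Dataset.range(3), actual_size=3, min_size=8)
print(sum(1 for _ in ds))  # 9
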
@@ -16,6 +16,8 @@ ImageFile.LOAD_TRUNCATED_IMAGES = True

 parser = OptionParser()
 parser.add_option("-p", "--path_train", dest="train_path", help="Path to training data.", default="./data/train_data")
+parser.add_option("--train_size", type="int", dest="train_size", default=200)
+parser.add_option("--validation_size", type="int", dest="validation_size", default=200)
 parser.add_option("-t", "--path_test", dest="test_path", help="Path to test data.", default="./data/test_data")
 parser.add_option("-o", "--path_output", dest="output_path", help="Path to base folder for output data.", default='./out')
 parser.add_option("--base_network_file", dest="base_net_file", help="Optional link to local file of Resnet 152 V2 for TF without top.")
@@ -54,19 +56,15 @@ img_height = options.height
 channels = 3
 n_classes = 2
 batch_size = options.batch_size
-batch_count_train_dataset = 2000 // batch_size
-batch_count_validation_dataset = 500 // batch_size
-batch_count_test_dataset = 500 // batch_size
-batch_count_train_dataset = 200 // batch_size
-batch_count_validation_dataset = 50 // batch_size
-batch_count_test_dataset = 50 // batch_size
+min_size_train_dataset = options.train_size
+min_size_validation_dataset = options.validation_size

 # Load the dataset into TF Datasets
 # Training Data
 train_loader = DataLoader(data_path=train_dir,
                           batch_size=options.batch_size,
-                          batch_count_dataset1=batch_count_train_dataset,
+                          should_size_dataset1=min_size_train_dataset,
                           img_width=img_width,
                           img_height=img_height,
                           channels=channels,
@@ -75,15 +73,14 @@ train_loader = DataLoader(data_path=train_dir,
 train_dataset = train_loader.load_dataset()

 # Test Data
-test_loader = DataLoader( data_path=test_dir,
-                          batch_size=options.batch_size,
-                          batch_count_dataset1=batch_count_validation_dataset,
-                          batch_count_dataset2=batch_count_test_dataset,
-                          img_width=img_width,
-                          img_height=img_height,
-                          channels=channels,
-                          split_size=options.split_size,
-                          augment=True
+test_loader = DataLoader(data_path=test_dir,
+                         batch_size=options.batch_size,
+                         should_size_dataset1=min_size_validation_dataset,
+                         img_width=img_width,
+                         img_height=img_height,
+                         channels=channels,
+                         split_size=options.split_size,
+                         augment=True
                          )

validation_dataset, test_dataset = test_loader.load_dataset()
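
On the script side, the hard-coded batch counts become the `--train_size` and `--validation_size` command-line options, and the number of batches an epoch yields now follows from the repeat factor together with `drop_remainder=True` in the loader. A small helper, not part of the commit and with made-up example numbers, that mirrors that arithmetic:

import math

def effective_batches(dataset_size: int, min_size: int, batch_size: int) -> int:
    # The loader repeats the data ceil(min_size / dataset_size) times,
    # then batches with drop_remainder=True, discarding any partial batch.
    repeat_factor = math.ceil(min_size / dataset_size) if min_size else 1
    return (repeat_factor * dataset_size) // batch_size

# E.g. 130 training images with --train_size 200 and a batch size of 32:
# ceil(200 / 130) = 2 repeats -> 260 images -> 8 full batches per epoch.
print(effective_batches(130, 200, 32))  # 8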