Commit e3c955f5 authored by sjjsmuel's avatar sjjsmuel

DS Shuffle optional

parent bf9c5ff3
......@@ -90,7 +90,7 @@ def clip(x, y):
'''
class DataLoader(object):
def __init__(self, data_path, batch_size, img_width, img_height, channels, should_size_dataset1 = 0, should_size_dataset2=0, split_size=0.0, augment=False, annotation=None):
def __init__(self, data_path, batch_size, img_width, img_height, channels, should_size_dataset1 = 0, should_size_dataset2=0, split_size=0.0, augment=False, annotation=None, shuffle=True):
self.data_path = data_path
self.batch_size = batch_size
self.split_size = split_size
......@@ -102,6 +102,7 @@ class DataLoader(object):
self.AUGMENT = augment
self.AUGMENTATIONS = [flip, color, zoom, rotate]
self.annotation = annotation
self.shuffle = shuffle
self.classes = [item.name for item in data_path.glob('*') if item.name != '.DS_Store']
self.n_classes = len(self.classes)
......@@ -219,16 +220,18 @@ class DataLoader(object):
filenames_list, mouths_list, labels_list = self._load_data_as_filename_lists(self.data_path)
# Shuffle the lists
dataset_list = list(zip(filenames_list, mouths_list, labels_list))
random.shuffle(dataset_list)
filenames_list, mouths_list, labels_list = zip(*dataset_list)
if self.shuffle:
dataset_list = list(zip(filenames_list, mouths_list, labels_list))
random.shuffle(dataset_list)
filenames_list, mouths_list, labels_list = zip(*dataset_list)
# Will skip the splitting if it would result in a full copy of dataset in dataset_1 or dataset_2
if self.split_size not in [0, 1]:
filenames_list, mouths_list, labels_list, filenames2_list, mouths2_list, labels2_list = self._filename_and_labels_split(filenames_list, labels_list)
self.dataset_2 = self._create_dataset(filenames2_list, mouths2_list, labels2_list)
self.dataset_2 = self.dataset_2.shuffle(buffer_size=(len(filenames2_list)))
if self.shuffle:
self.dataset_2 = self.dataset_2.shuffle(buffer_size=(len(filenames2_list)))
self.dataset_2_size = len(filenames2_list)
if self.dataset_2_size_min_count != 0:
self.dataset_2_repeat_factor = math.ceil(self.dataset_2_size_min_count / self.dataset_2_size)
......@@ -236,7 +239,8 @@ class DataLoader(object):
# Creating actual TF Dataset 1
self.dataset_1 = self._create_dataset(filenames_list, mouths_list, labels_list)
self.dataset_1 = self.dataset_1.shuffle(buffer_size=(len(filenames_list)))
if self.shuffle:
self.dataset_1 = self.dataset_1.shuffle(buffer_size=(len(filenames_list)))
self.dataset_1_size = len(filenames_list)
if self.dataset_1_size_min_count != 0:
self.dataset_1_repeat_factor = math.ceil(self.dataset_1_size_min_count / self.dataset_1_size)
......@@ -261,10 +265,14 @@ class DataLoader(object):
lambda: (x, y)), num_parallel_calls=self.NR_THREADS)
self.dataset_2 = self.dataset_2.map(lambda x, y: clip(x, y))
self.dataset_1 = self.dataset_1.shuffle(self.dataset_1_size).batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
if self.shuffle:
self.dataset_1 = self.dataset_1.shuffle(self.dataset_1_size)
self.dataset_1 = self.dataset_1.batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
if self.dataset_2:
self.dataset_2 = self.dataset_2.shuffle(self.dataset_2_size).batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
if self.shuffle:
self.dataset_2 = self.dataset_2.shuffle(self.dataset_2_size)
self.dataset_2 = self.dataset_2.batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
return self.dataset_1, self.dataset_2
else:
return self.dataset_1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment