From e3c955f5dc8b81dc0760596fcf351920fcd9908d Mon Sep 17 00:00:00 2001
From: sjjsmuel
Date: Mon, 6 Jul 2020 10:49:28 +0200
Subject: [PATCH] DS Shuffle optional

---
 helpers/DataLoader.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/helpers/DataLoader.py b/helpers/DataLoader.py
index c4183c2..8c883f7 100644
--- a/helpers/DataLoader.py
+++ b/helpers/DataLoader.py
@@ -90,7 +90,7 @@ def clip(x, y):
     '''
 
 class DataLoader(object):
-    def __init__(self, data_path, batch_size, img_width, img_height, channels, should_size_dataset1 = 0, should_size_dataset2=0, split_size=0.0, augment=False, annotation=None):
+    def __init__(self, data_path, batch_size, img_width, img_height, channels, should_size_dataset1 = 0, should_size_dataset2=0, split_size=0.0, augment=False, annotation=None, shuffle=True):
         self.data_path = data_path
         self.batch_size = batch_size
         self.split_size = split_size
@@ -102,6 +102,7 @@ class DataLoader(object):
         self.AUGMENT = augment
         self.AUGMENTATIONS = [flip, color, zoom, rotate]
         self.annotation = annotation
+        self.shuffle = shuffle
         self.classes = [item.name for item in data_path.glob('*') if item.name != '.DS_Store']
         self.n_classes = len(self.classes)
 
@@ -219,16 +220,18 @@ class DataLoader(object):
         filenames_list, mouths_list, labels_list = self._load_data_as_filename_lists(self.data_path)
 
         # Shuffle the lists
-        dataset_list = list(zip(filenames_list, mouths_list, labels_list))
-        random.shuffle(dataset_list)
-        filenames_list, mouths_list, labels_list = zip(*dataset_list)
+        if self.shuffle:
+            dataset_list = list(zip(filenames_list, mouths_list, labels_list))
+            random.shuffle(dataset_list)
+            filenames_list, mouths_list, labels_list = zip(*dataset_list)
 
         # Will skipp the splitting if it would result in a full copy of dataset in dataset_1 or dataset_2
         if self.split_size not in [0, 1]:
             filenames_list, mouths_list, labels_list, filenames2_list, mouths2_list, labels2_list = self._filename_and_labels_split(filenames_list, labels_list)
 
             self.dataset_2 = self._create_dataset(filenames2_list, mouths2_list, labels2_list)
-            self.dataset_2 = self.dataset_2.shuffle(buffer_size=(len(filenames2_list)))
+            if self.shuffle:
+                self.dataset_2 = self.dataset_2.shuffle(buffer_size=(len(filenames2_list)))
             self.dataset_2_size = len(filenames2_list)
             if self.dataset_2_size_min_count != 0:
                 self.dataset_2_repeat_factor = math.ceil(self.dataset_2_size_min_count / self.dataset_2_size)
@@ -236,7 +239,8 @@ class DataLoader(object):
 
         # Creating actual TF Dataset 1
         self.dataset_1 = self._create_dataset(filenames_list, mouths_list, labels_list)
-        self.dataset_1 = self.dataset_1.shuffle(buffer_size=(len(filenames_list)))
+        if self.shuffle:
+            self.dataset_1 = self.dataset_1.shuffle(buffer_size=(len(filenames_list)))
         self.dataset_1_size = len(filenames_list)
         if self.dataset_1_size_min_count != 0:
             self.dataset_1_repeat_factor = math.ceil(self.dataset_1_size_min_count / self.dataset_1_size)
@@ -261,10 +265,14 @@ class DataLoader(object):
                                                             lambda: (x, y)), num_parallel_calls=self.NR_THREADS)
             self.dataset_2 = self.dataset_2.map(lambda x, y: clip(x, y))
 
-        self.dataset_1 = self.dataset_1.shuffle(self.dataset_1_size).batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
+        if self.shuffle:
+            self.dataset_1 = self.dataset_1.shuffle(self.dataset_1_size)
+        self.dataset_1 = self.dataset_1.batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
         if self.dataset_2:
-            self.dataset_2 = self.dataset_2.shuffle(self.dataset_2_size).batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
+            if self.shuffle:
+                self.dataset_2 = self.dataset_2.shuffle(self.dataset_2_size)
+            self.dataset_2 = self.dataset_2.batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
             return self.dataset_1, self.dataset_2
         else:
             return self.dataset_1
-- 
GitLab
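
Below is a minimal usage sketch of the new shuffle flag introduced by this patch. It is not part of the patch itself; all constructor arguments other than shuffle (the data path, image size, batch size, and split) are illustrative placeholders assumed from the __init__ signature shown above.

    from pathlib import Path
    from helpers.DataLoader import DataLoader

    # Placeholder values for illustration; only the shuffle keyword is new in this patch.
    loader = DataLoader(
        data_path=Path('data/train'),  # hypothetical dataset directory with one folder per class
        batch_size=32,
        img_width=224,
        img_height=224,
        channels=3,
        split_size=0.2,
        augment=False,
        annotation=None,
        shuffle=False,  # disable shuffling for a deterministic, reproducible sample order
    )

With shuffle=False the loader skips both the Python-level random.shuffle of the filename lists and every tf.data shuffle() call, so batches are produced in a stable order; the default shuffle=True preserves the previous behaviour.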