Commit 78bac83b authored by sjjsmuel's avatar sjjsmuel

revert to old order for 'cache'

parent 6934c408
......@@ -246,7 +246,7 @@ class DataLoader(object):
# Load images
print('Loading images of dataset into memory.')
self.dataset_1 = self.dataset_1.map(self.decode_img, num_parallel_calls=self.NR_THREADS)
self.dataset_1 = self.dataset_1.repeat(self.dataset_1_repeat_factor)
self.dataset_1 = self.dataset_1.cache().repeat(self.dataset_1_repeat_factor)
if self.AUGMENT:
for f in self.AUGMENTATIONS:
......@@ -255,16 +255,16 @@ class DataLoader(object):
self.dataset_1 = self.dataset_1.map(lambda x,y: clip(x,y))
if self.dataset_2:
self.dataset_2 = self.dataset_2.map(self.decode_img, num_parallel_calls=self.NR_THREADS)
self.dataset_2 = self.dataset_2.repeat(self.dataset_2_repeat_factor)
self.dataset_2 = self.dataset_2.cache().repeat(self.dataset_2_repeat_factor)
for f in self.AUGMENTATIONS:
self.dataset_2 = self.dataset_2.map(lambda x, y: tf.cond(tf.random.uniform([], 0, 1) > 0.1, lambda: f(x, y, (self.IMG_WIDTH, self.IMG_HEIGHT)),
lambda: (x, y)), num_parallel_calls=self.NR_THREADS)
self.dataset_2 = self.dataset_2.map(lambda x, y: clip(x, y))
self.dataset_1 = self.dataset_1.cache().shuffle(self.dataset_1_size).batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
self.dataset_1 = self.dataset_1.shuffle(self.dataset_1_size).batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
if self.dataset_2:
self.dataset_2 = self.dataset_2.cache().shuffle(self.dataset_2_size).batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
self.dataset_2 = self.dataset_2.shuffle(self.dataset_2_size).batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
return self.dataset_1, self.dataset_2
else:
return self.dataset_1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment