DataLoader.py 12.5 KB
Newer Older
1
import glob
2 3
import math
import random
sjjsmuel's avatar
sjjsmuel committed
4
import numpy as np
5 6
import tensorflow as tf

7 8


sjjsmuel's avatar
sjjsmuel committed
9 10 11
'''
    Augmentation functions    
'''
sjjsmuel's avatar
sjjsmuel committed
12
def flip(x, label, size):
    """Randomly mirror the image and its mouth mask horizontally and/or vertically.

    Each flip direction is decided independently with probability 0.5; the
    mask temporarily gets a channel axis so the tf.image ops accept it.
    """
    image = x['img']
    mouth = tf.expand_dims(x['mouth'], -1)

    do_lr = tf.random.uniform(shape=[], minval=0, maxval=2, dtype=tf.int32)
    image, mouth = tf.cond(
        do_lr == 0,
        lambda: (image, mouth),
        lambda: (tf.image.flip_left_right(image), tf.image.flip_left_right(mouth)))

    do_ud = tf.random.uniform(shape=[], minval=0, maxval=2, dtype=tf.int32)
    image, mouth = tf.cond(
        do_ud == 0,
        lambda: (image, mouth),
        lambda: (tf.image.flip_up_down(image), tf.image.flip_up_down(mouth)))

    return {'img': image, 'mouth': tf.squeeze(mouth)}, label

def color(x, label, size):
    """Apply random photometric jitter (hue, saturation, brightness, contrast).

    Only the image is altered; the mouth mask passes through unchanged.
    """
    jittered = tf.image.random_hue(x['img'], 0.1)
    jittered = tf.image.random_saturation(jittered, 0.5, 1.7)
    jittered = tf.image.random_brightness(jittered, 0.2)
    jittered = tf.image.random_contrast(jittered, 0.5, 1.5)
    return {'img': jittered, 'mouth': x['mouth']}, label

def rotate(x, label, size):
    """Rotate image and mouth mask together by a random multiple of 90 degrees."""
    image = x['img']
    mouth = tf.expand_dims(x['mouth'], -1)
    # 0, 1, 2 or 3 quarter turns, chosen uniformly
    quarter_turns = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    image = tf.image.rot90(image, quarter_turns)
    mouth = tf.image.rot90(mouth, quarter_turns)
    return {'img': image, 'mouth': tf.squeeze(mouth)}, label

def zoom(x, label, size):
    """Randomly zoom into the image (and mouth mask) by cropping and resizing back.

    A random central crop covering between 60% and 99% of the frame is taken
    and resized back to `size`; otherwise the sample is returned unchanged.
    """
    # Central-crop scales from 0.60 to 0.99 in steps of 0.03 (~14 settings).
    scales = list(np.arange(0.6, 1.0, 0.03))
    boxes = np.zeros((len(scales), 4))
    for i, scale in enumerate(scales):
        x1 = y1 = 0.5 - (0.5 * scale)
        x2 = y2 = 0.5 + (0.5 * scale)
        boxes[i] = [x1, y1, x2, y2]

    def random_crop(sample):
        img = sample['img']
        mask = tf.expand_dims(sample['mouth'], -1)
        # BUGFIX: box_indices must be integer indices into the (single-image)
        # batch; np.zeros defaults to float64, which crop_and_resize rejects.
        box_indices = np.zeros(len(scales), dtype=np.int32)
        # Create every candidate crop of the image, resized back to `size`.
        crops_image = tf.image.crop_and_resize([img], boxes=boxes, box_indices=box_indices, crop_size=size)
        crops_mask = tf.image.crop_and_resize([mask], boxes=boxes, box_indices=box_indices, crop_size=size)
        # Select one crop at random.
        random_index = tf.random.uniform(shape=[], minval=0, maxval=len(scales), dtype=tf.int32)
        return {'img': crops_image[random_index], 'mouth': tf.squeeze(crops_mask[random_index])}

    choice = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)

    # NOTE(review): this crops only ~20% of the time (when choice >= 0.8);
    # the original comment claimed 80%. Behavior kept, comment corrected —
    # confirm which probability was intended.
    return tf.cond(choice < 0.8, lambda: x, lambda: random_crop(x)), label
76

sjjsmuel's avatar
sjjsmuel committed
77 78 79 80 81 82 83 84 85 86 87 88 89 90
'''
    Clipping helper
'''
def clip(x, y):
    """Clamp image values back into [0, 1] after augmentation; mask is untouched."""
    return {'img': tf.clip_by_value(x['img'], 0, 1), 'mouth': x['mouth']}, y


'''
    Data Loader Class
    provides the main functions for supplying the data in the correct form for training and testing
'''
class DataLoader(object):
    """Builds batched tf.data pipelines from a directory tree of class-labelled images.

    `data_path` must contain one sub-directory per class. The loader can
    split the data into two datasets (e.g. train/validation), repeat each
    dataset until a minimum element count is reached, attach a per-image
    binary 'mouth' mask (all ones when no annotation source is given) and
    optionally augment both datasets.
    """

    def __init__(self, data_path, batch_size, img_width, img_height, channels, should_size_dataset1 = 0, should_size_dataset2=0, split_size=0.0, augment=False, annotation=None, shuffle=True):
        self.data_path = data_path
        self.batch_size = batch_size
        self.split_size = split_size
        # -1 means detecting automatically how many threads are possible
        self.NR_THREADS = -1
        self.IMG_WIDTH = img_width
        self.IMG_HEIGHT = img_height
        self.CHANNELS = channels
        self.AUGMENT = augment
        self.AUGMENTATIONS = [flip, color, zoom, rotate]
        self.annotation = annotation
        self.shuffle = shuffle

        # One class per sub-directory of data_path (skip macOS '.DS_Store' entries).
        self.classes = [item.name for item in data_path.glob('*') if item.name != '.DS_Store']
        self.n_classes = len(self.classes)

        self.dataset_1 = None
        self.dataset_2 = None
        self.dataset_1_size = -1
        self.dataset_2_size = -1
        self.dataset_1_class_elem_count = {}
        self.dataset_2_class_elem_count = {}
        self.dataset_1_size_min_count = should_size_dataset1
        self.dataset_2_size_min_count = should_size_dataset2
        self.dataset_1_repeat_factor = 1
        self.dataset_2_repeat_factor = 1

    def _create_dataset(self, filenames_list, mouths_list, labels_list):
        """Wrap the three parallel lists into a Dataset of ({'img', 'mouth'}, label)."""
        filenames = tf.constant(filenames_list)
        mouths = tf.constant(mouths_list, dtype='float32')
        labels = tf.constant(labels_list)
        return tf.data.Dataset.from_tensor_slices(({'img': filenames, 'mouth': mouths}, labels))

    def decode_img(self, input, one_hot_encodings):
        """Replace the file path in input['img'] by the decoded, resized float image."""
        filepath = input['img']
        # load the raw data from the file as a string
        img = tf.io.read_file(filepath)
        # convert the compressed string to a 3D uint8 tensor
        img = tf.image.decode_jpeg(img, channels=self.CHANNELS)
        # Use `convert_image_dtype` to convert to floats in the [0,1] range.
        img = tf.image.convert_image_dtype(img, tf.float32)
        # resize the image to the desired size.
        input['img'] = tf.image.resize(img, [self.IMG_WIDTH, self.IMG_HEIGHT])
        return input, one_hot_encodings

    def count_class_elements(self, one_hot_encodings):
        """Count the actual number of elements per class in a dataset.

        This is the base information to calculate the weights that compensate
        for the imbalance of the dataset.
        """
        class_counts = {k: 0 for k in range(self.n_classes)}
        for one_hot_encoding in one_hot_encodings:
            class_counts[one_hot_encoding.index(1)] += 1
        return class_counts

    def _get_one_hot_for_label(self, classname):
        """Return the one-hot list encoding of `classname` w.r.t. self.classes."""
        index = self.classes.index(classname)
        one_hot = [0] * self.n_classes
        one_hot[index] = 1
        return one_hot

    def _load_data_as_filename_lists(self, data_path):
        """Collect file paths, mouth masks and one-hot labels for every class directory."""
        filenames = []
        mouth_masks = []
        one_hot_encodings = []

        # One directory == one class of training data.
        for classname in self.classes:
            filenames_of_class = [item for item in glob.glob(str(data_path) + '/' + classname + '/*')]
            filenames.extend(filenames_of_class)
            # One identical one-hot label per file of this class.
            one_hot_encodings.extend([self._get_one_hot_for_label(classname)] * len(filenames_of_class))

        if self.annotation:
            for file_full_path in filenames:
                filename = file_full_path.split('/')[-1]
                annotations = self.annotation.get_annotations(filename, ['mouth'])
                mask = np.ones(shape=(self.IMG_WIDTH, self.IMG_HEIGHT))
                if len(annotations) > 0:
                    annotations = annotations[0][1] # if there is an annotation of the mouth present, extract the corners
                    # annotations[0] presumably holds the original image size — TODO confirm
                    width_orig, height_orig = annotations[0]
                    width_scalar = self.IMG_WIDTH / width_orig
                    height_scalar = self.IMG_HEIGHT / height_orig
                    x1, y1 = annotations[1]
                    x2, y2 = annotations[2]
                    # Scale the annotated corner points to the target image size.
                    x1 = round(x1 * width_scalar)
                    x2 = round(x2 * width_scalar)
                    y1 = round(y1 * height_scalar)
                    y2 = round(y2 * height_scalar)
                    # Zero out the annotated mouth rectangle in the mask.
                    mask[y1:y2+1, x1:x2+1] = 0
                mouth_masks.append(mask)
        else:
            # No annotation provided: use an all-ones (no-op) mask for every image.
            mouth_masks = [np.ones(shape=(self.IMG_WIDTH, self.IMG_HEIGHT))] * len(filenames)
        return filenames, mouth_masks, one_hot_encodings

    def _filename_and_labels_split(self, filenames_list, mouths_list, labels_list):
        """Split the three parallel lists; the first ceil(len * split_size) elements stay in dataset 1."""
        splitting_element = math.ceil(len(filenames_list) * self.split_size)
        filenames2_list = filenames_list[splitting_element:]
        filenames_list = filenames_list[:splitting_element]
        mouths2_list = mouths_list[splitting_element:]
        mouths_list = mouths_list[:splitting_element]
        labels2_list = labels_list[splitting_element:]
        labels_list = labels_list[:splitting_element]
        return filenames_list, mouths_list, labels_list, filenames2_list, mouths2_list, labels2_list

    def _get_weights(self, dataset_class_elem_count, dataset_size):
        """Inverse-frequency class weights: size / (n_classes * count_per_class)."""
        weights = {k: 0 for k in range(self.n_classes)}
        for key, value in dataset_class_elem_count.items():
            weights[key] = dataset_size / (self.n_classes * value)
        return weights

    def get_weights_dataset1(self):
        return self._get_weights(self.dataset_1_class_elem_count, self.dataset_1_size)

    def get_weights_dataset2(self):
        return self._get_weights(self.dataset_2_class_elem_count, self.dataset_2_size)

    def _augment_pipeline(self, dataset):
        """Apply each augmentation with 90% probability per element, then clip to [0, 1]."""
        for f in self.AUGMENTATIONS:
            # f=f binds the current augmentation function explicitly; a plain
            # closure over the loop variable is the classic late-binding trap.
            dataset = dataset.map(
                lambda x, y, f=f: tf.cond(tf.random.uniform([], 0, 1) > 0.1,
                                          lambda: f(x, y, (self.IMG_WIDTH, self.IMG_HEIGHT)),
                                          lambda: (x, y)),
                num_parallel_calls=self.NR_THREADS)
        return dataset.map(lambda x, y: clip(x, y))

    def load_dataset(self):
        """Build the pipeline(s): load, split, decode, augment, shuffle, batch.

        Returns dataset_1, or (dataset_1, dataset_2) when a split was made.
        """
        print('Loading dataset information')
        # Load filenames, masks and labels as three corresponding lists
        filenames_list, mouths_list, labels_list = self._load_data_as_filename_lists(self.data_path)

        # Shuffle the three lists in unison
        if self.shuffle:
            dataset_list = list(zip(filenames_list, mouths_list, labels_list))
            random.shuffle(dataset_list)
            filenames_list, mouths_list, labels_list = zip(*dataset_list)

        # Skip the splitting if it would result in a full copy of the data in dataset_1 or dataset_2
        if self.split_size not in [0, 1]:
            # BUGFIX: the original call dropped mouths_list (passed only two
            # arguments to a three-argument method), raising a TypeError
            # whenever a split was requested.
            filenames_list, mouths_list, labels_list, filenames2_list, mouths2_list, labels2_list = \
                self._filename_and_labels_split(filenames_list, mouths_list, labels_list)
            self.dataset_2 = self._create_dataset(filenames2_list, mouths2_list, labels2_list)
            if self.shuffle:
                self.dataset_2 = self.dataset_2.shuffle(buffer_size=(len(filenames2_list)))
            self.dataset_2_size = len(filenames2_list)
            if self.dataset_2_size_min_count != 0:
                self.dataset_2_repeat_factor = math.ceil(self.dataset_2_size_min_count / self.dataset_2_size)
            self.dataset_2_class_elem_count = self.count_class_elements(labels2_list)

        # Creating actual TF Dataset 1
        self.dataset_1 = self._create_dataset(filenames_list, mouths_list, labels_list)
        if self.shuffle:
            self.dataset_1 = self.dataset_1.shuffle(buffer_size=(len(filenames_list)))
        self.dataset_1_size = len(filenames_list)
        if self.dataset_1_size_min_count != 0:
            self.dataset_1_repeat_factor = math.ceil(self.dataset_1_size_min_count / self.dataset_1_size)
        self.dataset_1_class_elem_count = self.count_class_elements(labels_list)

        # Load images
        print('Loading images of dataset into memory.')
        self.dataset_1 = self.dataset_1.map(self.decode_img, num_parallel_calls=self.NR_THREADS)
        self.dataset_1 = self.dataset_1.cache().repeat(self.dataset_1_repeat_factor)
        if self.dataset_2:
            # BUGFIX: decode/cache dataset_2 unconditionally; originally this
            # only happened inside `if self.AUGMENT:`, so with augmentation
            # disabled dataset_2 still contained file paths instead of images.
            self.dataset_2 = self.dataset_2.map(self.decode_img, num_parallel_calls=self.NR_THREADS)
            self.dataset_2 = self.dataset_2.cache().repeat(self.dataset_2_repeat_factor)

        if self.AUGMENT:
            self.dataset_1 = self._augment_pipeline(self.dataset_1)
            if self.dataset_2:
                self.dataset_2 = self._augment_pipeline(self.dataset_2)

        if self.shuffle:
            self.dataset_1 = self.dataset_1.shuffle(self.dataset_1_size)
        self.dataset_1 = self.dataset_1.batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)

        if self.dataset_2:
            if self.shuffle:
                self.dataset_2 = self.dataset_2.shuffle(self.dataset_2_size)
            self.dataset_2 = self.dataset_2.batch(self.batch_size, drop_remainder=True).prefetch(tf.data.experimental.AUTOTUNE)
            return self.dataset_1, self.dataset_2
        else:
            return self.dataset_1