from PIL import Image
import glob
import helpers.dataset_util as dataset_util
import hashlib
import io
from helpers.AnnotationLocationLoader import AnnotationLocationLoader
from pathlib import Path
import tensorflow as tf

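# Set DEBUG to True to print per-image sizes and per-split annotation statistics.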
DEBUG = False


def create_tf_example(filename, annotations, base_folder, labelmap):
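    """Builds a tf.train.Example for one image and its bounding-box annotations.

    filename    -- image file name, searched for recursively below base_folder
    annotations -- list of (class_name, box) tuples where box is
                   ((xmin, ymin), (xmax, ymax)) in absolute pixel coordinates
    base_folder -- folder of the dataset split that contains the image
    labelmap    -- mapping from class name to numeric label id

    Returns None for positive examples without localisation information.
    """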
    search_string = str(base_folder) + '/**/' + filename
    # there should be exactly one file matching this name
    full_image_path = glob.glob(search_string, recursive=True)[0]

    # skip positive ('caries') examples that have no localisation information
    if len(annotations) == 0 and Path(full_image_path).parent.name == 'caries':
        return None

    with tf.io.gfile.GFile(full_image_path, 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    if image.format != 'JPEG':
        raise ValueError('Image format not JPEG')
    image_format = image.format
    width, height = image.size
    key = hashlib.sha256(encoded_jpg).hexdigest()
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    if DEBUG: print('size:', width, height)

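    # Convert each annotated box from absolute pixel coordinates to
    # coordinates normalised by the image width and height.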
    for annotation in annotations:
        annotated_class = annotation[0].upper()
        box = annotation[1]
        
        xmins.append(box[0][0] / width)
        xmaxs.append(box[1][0] / width)
        ymins.append(box[0][1] / height)
        ymaxs.append(box[1][1] / height)

        classes_text.append(annotated_class.encode('utf8'))
        classes.append(labelmap[annotated_class]) # Class as Number

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename.encode('utf8')),
        'image/source_id': dataset_util.bytes_feature(filename.encode('utf8')),
        'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format.encode('utf8')),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example


def _show_inside_dataset_information(dataset, classes):
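    """Prints the number of annotated images and annotated regions of a dataset split."""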
    print('Number of annotated images', len(dataset.get_all_annotated_images()))
    counter = 0
    for image in dataset.get_all_annotated_images():    
        counter += len(dataset.get_annotations(image, classes))
    print('Number of annotated regions', counter)

def _create_label_map(classes, out_path):
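    """Writes a label map in .pbtxt format to out_path and returns a dict mapping each
    class name (plus an appended 'BACKGROUND' entry) to its numeric id."""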
    classes.sort()
    classes.append('BACKGROUND')
    label_map = {}     

    with open(out_path, "w") as out:
        for i, item in enumerate(classes):
            label_map[item] = i + 1  # keep the numeric labels consistent with the 1-based ids written to the .pbtxt
            out.write('item {\n  id: '+str(i+1)+'\n  name: \''+item+'\'\n}\n')
            if i < len(classes)-1:
                out.write('\n')
    print('Successfully created label map.')
    return label_map

if __name__ == '__main__':
    """
        This script generates the TF Records for the dataset.
        Beforehand, the image orientation needs to be updated: the EXIF orientation flag is reset,
        because it was ignored during annotation.
        This has NOT been done for the images delivered with this project.
        Both the processed version and the input dataset can be provided by the chair
        'Intelligente Systeme' at the University of Duisburg-Essen.
        
        exiftool -Orientation=1 -n -overwrite_original -r <PATH TO THE DATASET FOLDER>
    """

    base_folder_data_origin = Path('../input')
    base_folder_output = Path('../input/tf_records')
    label_map_destination = base_folder_output / 'caries_label_map.pbtxt'
    annotation_file = base_folder_data_origin / "caries_dataset_annotation.json"
    training_data_folder = base_folder_data_origin / "training_data"
    test_data_folder = base_folder_data_origin / "test_data"
    evaluation_data_folder = base_folder_data_origin / "evaluation_data"

    if not base_folder_output.exists():
        base_folder_output.mkdir(parents=True)

    sets = [training_data_folder, test_data_folder, evaluation_data_folder]

    classes = ['CARIES']

    labelmap = _create_label_map(classes, label_map_destination)


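    # Build one TFRecord file per dataset split (training, test, evaluation).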
    for kind in sets:
        dataset = AnnotationLocationLoader(annotation_file=annotation_file, images_base_folder=kind, mouth_annotations_folder=None)
        if DEBUG:
            print(kind.name)
            _show_inside_dataset_information(dataset, classes)
            print()

        record_file = base_folder_output / (kind.name + '.tfrecord')

        writer = tf.io.TFRecordWriter(str(record_file))    

        for image in dataset.get_all_available_images():
            annotations = dataset.get_annotations(image, classes)
            tf_example = create_tf_example(image, annotations, kind, labelmap)
            # only write the example to the record file if one could be created
            if tf_example: 
                writer.write(tf_example.SerializeToString())
        writer.close()
        print('Successfully created the TFRecords for {}'.format(kind.name))
        del dataset