# PredictionLocationLoader.py
# Author: sjjsmuel
import ast
from pathlib import Path, PosixPath


class PredictionLocationLoader:
    """
    Loads network predictions from a prediction file and offers per-image lookups.

    Each non-blank line of the prediction file has the form
    ``<filename>;<python literal>``. The literal is parsed with
    ``ast.literal_eval`` and is expected to be a list of entries whose first
    element is the annotation type (a string),
    e.g. ``a.png;[('caries', [(x1, y1), (x2, y2)])]``.
    """

    def __init__(self, prediction_file='out/predictions.txt', images_base_folder='input/test_data/'):
        """
        :param prediction_file: path (str or Path) of the prediction file to load
        :param images_base_folder: folder (str or Path) whose direct sub-folders
            contain the image files
        """
        self._prediction_file = prediction_file
        self._annotated_images = set()
        self._available_annotations = set()
        self._available_images = None
        self._data = {}

        # Path() accepts str as well as every Path flavour (PosixPath, WindowsPath),
        # so no explicit type check is needed.
        self._img_path = Path(images_base_folder)

        # get the names of the images which are available as files
        self._available_images = self._get_names_from_available_images()
        self._load_predictions()

    def _get_names_from_available_images(self):
        """
        Finds the names of all pictures available in the image base folder.

        Only regular files inside the direct sub-folders of the base folder are
        considered. Files located directly in the base folder, and hidden files
        (names starting with '.'), are ignored.

        :return: list of file names (without their directory part)
        """
        names = []
        for folder in self._img_path.iterdir():
            if not folder.is_dir():
                continue
            names.extend(
                entry.name
                for entry in folder.iterdir()
                if entry.is_file() and not entry.name.startswith('.')
            )
        return names

    def _load_predictions(self):
        """
        Loads the predictions from a network based on the prediction file.

        Fills ``self._data`` (filename -> parsed prediction list). It also
        fills the list of image names present in the file and the list of
        upper-cased annotation types seen at least once. Blank lines are skipped.

        :raises ValueError: if a non-blank line has no ';' separator, or (like
            SyntaxError) if its prediction part is not a valid Python literal
        """
        annotated_images = set(self._annotated_images)
        annotation_types = set(self._available_annotations)
        with open(self._prediction_file) as file:
            for line_number, line in enumerate(file, start=1):
                if not line.strip():
                    continue  # tolerate empty lines, e.g. a trailing newline
                # partition splits at the first ';' only, so a ';' inside the
                # prediction literal does not break parsing.
                filename, separator, raw_predictions = line.partition(';')
                if not separator:
                    raise ValueError(
                        f"{self._prediction_file}:{line_number}: missing ';' separator"
                    )
                filename = filename.strip()
                # literal_eval only accepts Python literals, so the file content
                # cannot execute arbitrary code here.
                predictions = ast.literal_eval(raw_predictions.strip())
                self._data[filename] = predictions

                # prepare meta-structures
                for prediction in predictions:
                    annotation_types.add(prediction[0].upper())
                annotated_images.add(filename)
        self._annotated_images = list(annotated_images)
        self._available_annotations = list(annotation_types)

    def get_all_types_of_annotations(self):
        """
        :return: list of all (upper-cased) types of annotations which appeared
            at least once in the prediction file
        """
        return self._available_annotations

    def get_all_annotated_images(self):
        """
        :return: list of the names of all images listed in the prediction file
            (an image listed with an empty prediction list is included)
        """
        return self._annotated_images

    def is_annotated(self, image_name):
        """
        Checks whether the prediction file contains an entry for the given filename.

        :param image_name: complete name of the file including the file type, as a string
        :return: boolean whether there is an entry for the image
        """
        # _data has exactly the same keys as _annotated_images, but a dict
        # lookup is O(1) instead of a linear list scan.
        return image_name in self._data

    def get_annotations(self, image_name, filter=None):
        """
        Returns a list of annotations for the given image_name,
        e.g. [ ('caries', [(x1,y1), (x2,y2)]),  _more_entries_ ]

        :param image_name: complete name of the file including the file type
        :param filter: optional string or iterable of strings naming the
            annotation types to keep. Matching is case-insensitive, so the names
            returned by get_all_types_of_annotations() can be passed directly.
            An empty or missing filter returns all annotations.
        :return: the (possibly filtered) annotations, or [] for unknown images
        """
        # NOTE: the parameter name shadows the builtin ``filter``; it is kept
        # for backward compatibility with keyword callers.
        if not self.is_annotated(image_name):
            return []
        annotations = self._data[image_name]
        if not filter:
            return annotations
        if isinstance(filter, str):
            # a single type name; iterating a str would yield single characters
            filter = [filter]
        wanted = {category.lower() for category in filter}
        return [annotation for annotation in annotations if annotation[0].lower() in wanted]