Commit b80c8d80 authored by sjjsmuel

reworked grad_cam

parent d6b7378d
import json
from pathlib import Path


class AnnotationLocationLoader:

    def __init__(self, annotation_file='input/caries_dataset_annotation.json', images_base_folder='input/test_data/'):
        self._annotation_file = annotation_file
        self._img_path = images_base_folder
        # per-instance containers (avoids sharing mutable state between instances)
        self._annotated_images = set()
        self._available_annotations = set()
        self._data = {}
        # get the names of the images which are available as files
        self._available_images = self._get_names_from_available_images()
        self._load_annotations()

    def _get_names_from_available_images(self):
        names_from_available_images = []
        for path in [path for path in Path(self._img_path).iterdir() if path.is_dir()]:
            names_from_available_images.extend([filename.name for filename in path.iterdir() if filename.is_file() and not filename.name.startswith('.')])
        return names_from_available_images
    def _load_annotations(self):
        with open(self._annotation_file) as file:
            json_data = json.load(file)

        for picture in json_data:
            picture_filename = picture['External ID']
            self._data[picture_filename] = []
            if picture_filename not in self._available_images:
                # print('File "{}" not found.'.format(picture_filename))
                continue
            # Skip the 'Skip' entries in the file
            if not isinstance(picture['Label'], dict):
                continue
            for annotation_type in picture['Label'].keys():
                self._available_annotations.add(annotation_type)
                self._annotated_images.add(picture_filename)
                for box in picture['Label'][annotation_type]:
                    x_all = []
                    y_all = []
                    for point in box['geometry']:
                        x_all.append(point['x'])
                        y_all.append(point['y'])
                    box_coord = [(min(x_all), min(y_all)), (max(x_all), max(y_all))]
                    self._data[picture_filename].append((annotation_type.lower(), box_coord))

        self._annotated_images = list(self._annotated_images)
        self._available_annotations = list(self._available_annotations)
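    # Illustrative shape of a single entry in the annotation file, inferred from the keys
    # accessed above; the values are made up and only document the expected structure:
    #   {
    #     "External ID": "some_image.jpg",
    #     "Label": {
    #       "Caries": [
    #         {"geometry": [{"x": 10, "y": 20}, {"x": 30, "y": 20}, {"x": 30, "y": 45}, {"x": 10, "y": 45}]}
    #       ]
    #     }
    #   }
    # Entries skipped during labelling carry a non-dict "Label" value and are ignored above.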
    def get_all_types_of_annotations(self):
        """
        :return: list of all the types of annotations which appeared at least once in the annotation_file
        """
        return self._available_annotations

    def get_all_annotated_images(self):
        """
        :return: list of the names of all images which have at least one annotation
        """
        return self._annotated_images

    def is_annotated(self, image_name):
        """
        Checks whether an annotation exists for the given filename.
        :param image_name: complete name of the file including the filetype as a string
        :return: boolean whether there is an annotation for the image
        """
        return image_name in self._annotated_images

    def get_annotations(self, image_name, filter=None):
        """
        Returns a list of annotations for the given image_name,
        e.g. [ ('caries', [(x1,y1), (x2,y2)]), _more_entries_ ]
        :param image_name: complete name of the file including the filetype as a string
        :param filter: a list of strings representing the types of annotations the user wants to retrieve
        """
        if self.is_annotated(image_name):
            if filter and len(filter) > 0:
                filter = [category.lower() for category in filter]
                return [annotation for annotation in self._data[image_name] if annotation[0] in filter]
            return self._data[image_name]
        else:
            return []
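A minimal usage sketch of the loader, assuming the default paths from __init__ and a hypothetical image filename; it mirrors how the Grad-CAM script below constructs the loader and queries caries boxes:

loader = AnnotationLocationLoader()
print(loader.get_all_types_of_annotations())      # annotation types seen in the file, e.g. ['Caries']
if loader.is_annotated('example_tooth.jpg'):      # 'example_tooth.jpg' is a made-up name
    # returns [('caries', [(x_min, y_min), (x_max, y_max)]), ...] filtered to caries annotations
    boxes = loader.get_annotations('example_tooth.jpg', ['Caries'])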
import PIL
from pathlib import Path
from helpers import metrics
import imutils
from PIL import Image
from helpers.AnnotationLocationLoader import AnnotationLocationLoader
from grad_cam.gradcam import GradCAM
from network_helpers.Resnet152 import Resnet152
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
import numpy as np
@@ -19,8 +16,6 @@ out = 'out/'
#model_file = 'out/checkpoints/2020.04.10.13.22.27/model.0016-0.471.hdf5'
model_file = 'out/checkpoints/2020.04.16.10.12.48/model.0005-0.498.hdf5'
resnet_file = 'input/resnet152v2_weights_tf_dim_ordering_tf_kernels_notop.h5'
class_index_map = {'caries': 0, 'no_caries': 1}
index_class_map = {}
for element in class_index_map:
@@ -34,6 +29,15 @@ model = load_model(model_file)
_, network_input_width, network_input_height, channels = model.input.shape
network_input_size = (network_input_width , network_input_height)
n_classes = model.output.shape[1]
print('[INFO] Model input dimensions are ({}, {}, {}).'.format(network_input_width, network_input_height, channels))
# loading the location annotations
print('[INFO] loading annotations...')
annotation_loader = AnnotationLocationLoader()
# clean up predicted boxes from last run
with open(out + "/predictions.txt", "w") as predictions_file:
    predictions_file.write("")
# iterate over all (both) folders of classes
for path in [path for path in Path(img_raw_path).iterdir() if path.is_dir()]:
@@ -47,7 +51,7 @@ for path in [path for path in Path(img_raw_path).iterdir() if path.is_dir()]:
    print('[INFO] Starting to process folder \'{}\' with index {}'.format(path.name, class_index))

-    # get the filenames of all files in the folder
+    # get the filenames of all test-images in the current folder
    filenames = [item for item in path.glob('*') if item.name != '.DS_Store']
    #filenames = filenames[:3] # ------ simplification for testing | remove afterwards -------------------
@@ -55,7 +59,7 @@ for path in [path for path in Path(img_raw_path).iterdir() if path.is_dir()]:
    for img in filenames:
        #print("Starting image {}".format(img.name))

        # load original image
-        orig = cv2.imread(str(img))
+        orig = cv2.imread(str(img), cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR)
        orig_image_size = (orig.shape[1], orig.shape[0])

        # load image for processing
@@ -78,18 +82,62 @@ for path in [path for path in Path(img_raw_path).iterdir() if path.is_dir()]:
        heatmap = cam.compute_heatmap(image)

        # resize the resulting heatmap to the original input image dimensions
        # and then overlay heatmap on top of the image
        heatmap = cv2.resize(heatmap, (orig.shape[1], orig.shape[0]))

        # resolve bounding boxes from heatmap
        # https://stackoverflow.com/a/58421765
        # Grayscale then Otsu's threshold
        #grayscale_heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2GRAY)
        thresh = cv2.threshold(heatmap, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1]

        # overlay heatmap on top of the image
        (heatmap, output) = cam.overlay_heatmap(heatmap, orig, alpha=0.5)

        # Find contours and draw bounding boxes of prediction
        cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cnts = cnts[0] if len(cnts) == 2 else cnts[1]
        boxes = []
        for c in cnts:
            x, y, w, h = cv2.boundingRect(c)
            p1 = (x, y)
            p2 = (x + w, y + h)
            # add new box to list of all boxes for the image (tl: top left corner; br: bottom right corner)
            boxes.append({'tl': p1, 'br': p2})
            cv2.rectangle(heatmap, p1, p2, (255, 255, 255), 3)
            cv2.rectangle(output, p1, p2, (255, 255, 255), 3)

        # draw the predicted label on the output image
        cv2.rectangle(output, (0, 0), (1300, 150), (0, 0, 0), -1)
        cv2.putText(output, label, (10, 120), cv2.FONT_HERSHEY_SIMPLEX, 4, (255, 255, 255), 2)

        # write boxes to file
        prepared_boxes = [('caries', box['tl'], box['br']) for box in boxes]
        with open(out + "/predictions.txt", "a") as predictions_file:
            predictions_file.write('{}; {}\n'.format(img.name, str(prepared_boxes)))

        # draw the annotated labels to the original picture
        annotations = annotation_loader.get_annotations(img.name, ['Caries'])
        for annotation in annotations:
            cv2.rectangle(orig, annotation[1][0], annotation[1][1], (0, 255, 0), 3)
            cv2.rectangle(output, annotation[1][0], annotation[1][1], (0, 255, 0), 3)

        # Calculation of metric
        #p, r, tp, fp, fn = metrics.calculate_precision_recall(annotations, boxes)
        #print("{}: Precision: {}; Recall: {}; TP: {}; FP: {}; FN: {}".format(img.name, p, r, tp, fp, fn))

        # display the original image and resulting heatmap and output image
        # to our screen
        output = np.hstack([orig, heatmap, output])
        output = imutils.resize(output, height=700)

        out_file = str(out_path / f'{img.name[:-4]}.png')
        cv2.imwrite(out_file, output)

        out_path_pictures = out_path / 'pictures'
        if not out_path_pictures.exists():
            out_path_pictures.mkdir()
        out_path_heatmap = out_path / 'heatmap'
        if not out_path_heatmap.exists():
            out_path_heatmap.mkdir()

        out_file_picture = str(out_path_pictures / f'{img.name[:-4]}.png')
        out_file_heatmap = str(out_path_heatmap / f'{img.name[:-4]}.png')
        cv2.imwrite(out_file_picture, output)
        cv2.imwrite(out_file_heatmap, heatmap)

    print()
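Each line of out/predictions.txt pairs an image name with the list of predicted boxes written above. A minimal parsing sketch for downstream evaluation, assuming exactly the 'name; [boxes]' format produced by the write call (read_predictions is a hypothetical helper, not part of the repository):

import ast

def read_predictions(path='out/predictions.txt'):
    predictions = {}
    with open(path) as file:
        for line in file:
            if not line.strip():
                continue
            # the script writes '<image name>; <python literal list of boxes>' per image
            name, boxes_literal = line.split('; ', 1)
            # each box entry is ('caries', (x1, y1), (x2, y2))
            predictions[name] = ast.literal_eval(boxes_literal)
    return predictions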