import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tkinter as tk
import os
import matplotlib
from tkinter import filedialog
from skimage import transform
import SimpleITK as sitk
import argparse
#import os
#import pydot
#from graphviz import Digraph
#import shutil
#from tensorflow.keras import layers, models
#from utils.dataLoader import *
from utils.other_functions import *
from nets.Unet import *

# Default location / naming of the saved cross-validation folds.
DEFAULT_FOLDER_PATH = "finalResults/complete_segmr/mr_unet_cv_unetpred/"
# Other model types used with this script: "Cluster_", "Class_", "Cluster_class_".
DEFAULT_MODELTYPE = "Unet_"
DEFAULT_FOLDS = ("12", "34", "56", "78", "910", "1112", "1314", "1516")

# Directory with the un-thresholded volumes; the true segmentations are listed from here.
SEG_DATA_DIR = "data/npy/"
# Directory used for inputs/GT when the model was trained on a thresholded ground truth.
THRESH_DATA_DIR = "data/npy_thresh/"
# gt_type values whose inputs/GT are read from THRESH_DATA_DIR.
THRESH_DIR_GT_TYPES = ("thresh", "ctthresh_gt", "_mr_ctunetpred")
# gt_type values for which each test patient occupies three consecutive file entries.
MULTI_FILE_GT_TYPES = ("thresh", "ctthresh_gt")
FILES_PER_PATIENT_MULTI = 3


def _select_test_files(full_list, seg_list, test_patients, gt_type, is_mr):
    """Pick the input, GT and true-segmentation file names of the two test patients.

    A file belongs to a test patient when its name starts with ``P<NN>`` built
    from ``test_patients[0]`` or ``test_patients[1]`` (zero-padded to 2 digits).
    For the thresholded GT variant of the modality (``"ctthresh_gt"`` for MR,
    ``"thresh"`` for CT) every matching input and true-segmentation name is
    listed three times, so each patient occupies three consecutive entries —
    the same layout ``main`` slices per patient.

    Args:
        full_list: file names found in the input/GT directory.
        seg_list: file names found in ``SEG_DATA_DIR``.
        test_patients: patient numbers; only the first two are used.
        gt_type: GT suffix stored in the training parameters.
        is_mr: True for MR data (``*T1.gipl.npy`` / ``*segmr.gipl.npy``),
            False for CT data (``*ct.gipl.npy`` / ``*seg.gipl.npy``).

    Returns:
        Three sorted lists: input files, GT files, true-segmentation files.
    """
    if is_mr:
        input_suffix, true_suffix, triple_gt = "T1.gipl.npy", "segmr.gipl.npy", "ctthresh_gt"
    else:
        input_suffix, true_suffix, triple_gt = "ct.gipl.npy", "seg.gipl.npy", "thresh"
    repeats = 3 if gt_type == triple_gt else 1
    gt_suffix = gt_type + ".gipl.npy"
    prefixes = tuple('P' + str(p).zfill(2) for p in test_patients[:2])

    x_files, gt_files, ytrue_files = [], [], []
    for name in full_list:
        if not name.startswith(prefixes):
            continue
        # An input file is never also counted as GT (mirrors the original if/elif).
        if name.endswith(input_suffix):
            x_files.extend([name] * repeats)
        elif name.endswith(gt_suffix):
            gt_files.append(name)
    for name in seg_list:
        if name.startswith(prefixes) and name.endswith(true_suffix):
            ytrue_files.extend([name] * repeats)
    return sorted(x_files), sorted(gt_files), sorted(ytrue_files)


def _load_volume(directory, names):
    """Load the ``.npy`` files *names* from *directory* and concatenate them along axis 0.

    Raises:
        FileNotFoundError: if *names* is empty (no matching file was found).
    """
    if not names:
        raise FileNotFoundError("no matching test files found in " + directory)
    return np.concatenate([np.load(os.path.join(directory, name)) for name in names], axis=0)


def _resize_in_plane(volume, new_size):
    """Resize every slice of *volume* to ``new_size[0] x new_size[1]`` (nearest neighbour).

    Axis 0 (the slice axis) is kept; ``order=0`` and ``preserve_range=True`` keep
    label values unchanged.
    """
    return transform.resize(volume, (volume.shape[0], new_size[0], new_size[1]), order=0,
                            preserve_range=True, mode='constant', anti_aliasing=False,
                            anti_aliasing_sigma=None)


def _hausdorff(mask_a, mask_b):
    """Return SimpleITK's Hausdorff distance between two mask arrays, or None.

    SimpleITK surfaces ITK failures as ``RuntimeError`` (presumably e.g. when a
    mask has no foreground pixels); such slices/volumes are skipped instead of
    aborting the evaluation. Any other exception propagates.
    """
    hd_filter = sitk.HausdorffDistanceImageFilter()
    try:
        hd_filter.Execute(sitk.GetImageFromArray(mask_a), sitk.GetImageFromArray(mask_b))
    except RuntimeError:
        return None
    return hd_filter.GetHausdorffDistance()


def _evaluate_patient(X_test, ytrue, weights, filter_multiplier):
    """Run the U-Net slice by slice on one patient and collect metrics.

    ``onehotencode``, ``Unet`` and ``dice_loss`` come from the project star
    imports above.

    Returns:
        ``(test_loss, hdd, hdd2, true_slices, pred_slices)``: per-slice Dice
        losses, per-slice Hausdorff distances (true vs. pred and pred vs. true),
        and the binarised (> 0.5) channel-0 masks of every slice, in dataset order.
    """
    test_dataset = tf.data.Dataset.from_tensor_slices((X_test, ytrue)).batch(batch_size=1)
    test_loss, hdd, hdd2 = [], [], []
    true_slices, pred_slices = [], []
    for image, y_true in test_dataset:
        y_true = onehotencode(y_true)
        y_pred = Unet(image, weights, filter_multiplier, training=False)
        loss = dice_loss(y_pred, y_true)
        test_loss.append(tf.make_ndarray(tf.make_tensor_proto(loss)))

        # np.asarray works for eager tensors and plain ndarrays alike;
        # np.float64 replaces np.float_, which NumPy 2.0 removed.
        true_mask = (np.asarray(y_true[0, :, :, 0]) > 0.5).astype(np.float64)
        pred_mask = (np.asarray(y_pred[0, :, :, 0]) > 0.5).astype(np.float64)

        # Stack every slice (even ones whose 2D metric fails) so the 3D volume
        # keeps the true slice spacing along z.
        true_slices.append(true_mask)
        pred_slices.append(pred_mask)

        distance = _hausdorff(true_mask, pred_mask)
        if distance is not None:
            hdd.append(distance)
        # Hausdorff distance is symmetric by definition; the reversed-argument value
        # is kept for parity with the script's earlier output.
        distance2 = _hausdorff(pred_mask, true_mask)
        if distance2 is not None:
            hdd2.append(distance2)
    return test_loss, hdd, hdd2, true_slices, pred_slices


def main(folder_path=DEFAULT_FOLDER_PATH, modeltype=DEFAULT_MODELTYPE, folds=DEFAULT_FOLDS):
    """Evaluate every cross-validation fold of a saved model on its two test patients.

    For each fold this loads ``<folder_path>TPs<fold><modeltype>model.npy``
    (weights) and ``...params.npy`` (training parameters), runs the network
    slice by slice on both test patients and prints the mean Dice loss, the
    mean per-slice Hausdorff distances and the 3D Hausdorff distance of the
    stacked masks. The 3D distances of all folds are printed at the end.

    Args:
        folder_path: directory prefix holding the per-fold ``.npy`` files.
        modeltype: model-type suffix in the file names (e.g. ``"Unet_"``).
        folds: fold identifiers inserted after ``"TPs"`` in the file names.
    """
    all3dhdds = []
    for fold in folds:
        file_path = folder_path + "TPs" + fold + modeltype

        ### read out files ###
        # NOTE(review): allow_pickle=True unpickles arbitrary objects — only load
        # result files you produced yourself.
        weights = np.load(file_path + "model.npy", allow_pickle=True)
        # Checkpoints such as autoencoder_model__e100_switchclass1024_nohiddenclusternet
        # were saved without gt_type and val_patients; this unpacking expects 14 entries.
        [test_patients, val_patients, number_patients, img_path, shrink_data, newSize, lr,
         batch_size, num_epochs, e, augment, save_path, gt_type,
         filter_multiplier] = np.load(file_path + "params.npy", allow_pickle=True)
        print('Training Parameters:')
        print('-----------------')
        print('Number of Patients: ', number_patients)
        print('Number of epochs: ', num_epochs)
        print('Test Patient number: ', test_patients)
        print('Image Size: ', newSize)
        print('Filter Multiplier: ', filter_multiplier)
        print('Data Augmentation: ', augment)
        print('Learning rate: ', lr)
        print('Image Path: ', img_path)
        print('Save Path: ', save_path)
        print('GT Type:', gt_type)

        ### Load test patients
        img_path = THRESH_DATA_DIR if gt_type in THRESH_DIR_GT_TYPES else SEG_DATA_DIR
        X_img_list, GT_img_list, ytrue_img_list = _select_test_files(
            os.listdir(img_path), os.listdir(SEG_DATA_DIR), test_patients, gt_type,
            is_mr="mr" in save_path)
        print("Input Image List", X_img_list)
        print("GT Image List", GT_img_list)
        print("True Segmentation Image List", ytrue_img_list)

        per_patient = FILES_PER_PATIENT_MULTI if gt_type in MULTI_FILE_GT_TYPES else 1
        for j in range(2):
            window = slice(j * per_patient, (j + 1) * per_patient)
            if per_patient > 1:
                for name in GT_img_list[window]:
                    print(name)
            X_img_npys = _load_volume(img_path, X_img_list[window])
            GT_img_npys = _load_volume(img_path, GT_img_list[window])
            # The true segmentations were listed from SEG_DATA_DIR, so load them from there.
            ytrue_img_npys = _load_volume(SEG_DATA_DIR, ytrue_img_list[window])
            print("Input shape: ", np.shape(X_img_npys))
            print("GT shape: ", np.shape(GT_img_npys))
            print("True Segm shape: ", np.shape(ytrue_img_npys))

            # Add a trailing channel axis: (slices, H, W) -> (slices, H, W, 1).
            X_test = _resize_in_plane(X_img_npys, newSize)[..., np.newaxis]
            ytrue = _resize_in_plane(ytrue_img_npys, newSize)[..., np.newaxis]

            print(test_patients)
            patient = test_patients[j]
            test_loss, test_loss_hdd, test_loss_hdd2, true_slices, pred_slices = \
                _evaluate_patient(X_test, ytrue, weights, filter_multiplier)
            print("TestLoss Mean for P", patient, ": ", np.mean(test_loss))
            print("Hausdorff-Distance for P", patient, ":", np.mean(test_loss_hdd))
            print("Hausdorff-Distance2 for P", patient, ":", np.mean(test_loss_hdd2))

            #### Hausdorff 3D over the full stack of binarised slices
            hdd3d = None
            if true_slices:
                hdd3d = _hausdorff(np.stack(true_slices), np.stack(pred_slices))
            if hdd3d is None:
                print("3D HDD for patient", patient, ": could not be computed")
            else:
                print("3D HDD for patient", patient, ":", hdd3d)
                all3dhdds.append(hdd3d)

    print("All 3d-HDDs: ", all3dhdds)
    print(folder_path)
    print(modeltype)


if __name__ == "__main__":
    main()