Commit f66f6ab0 authored by sjromuel's avatar sjromuel
Browse files

d

parent 404926dc
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tkinter as tk
import os
import matplotlib
from tkinter import filedialog
from skimage import transform
import SimpleITK as sitk
import argparse
#import os
#import pydot
#from graphviz import Digraph
#import shutil
#from tensorflow.keras import layers, models
#from utils.dataLoader import *
from utils.other_functions import *
from nets.Unet import *
def main():
models = ["12", "34", "56", "78", "910", "1112", "1314", "1516"]
folder_path = "finalResults/complete_seg/ae_class_cv_seg/"
#modeltype = "Cluster_"
modeltype = "Class_"
#modeltype = "Cluster_class_"
#modeltype = "Unet_"
##################### U-Net #####################
for fold in models:
file_path = folder_path+"TPs"+fold+modeltype
### read out files ###
weights = np.load(file_path+"model.npy", allow_pickle=True)
[test_patients,
val_patients,
number_patients,
img_path,
shrink_data,
newSize,
lr,
batch_size,
num_epochs,
e,
augment,
save_path,
gt_type,
filter_multiplier] = np.load(file_path+"params.npy", allow_pickle=True)
# autoencoder_model__e100_switchclass1024_nohiddenclusternet needs to comment out gt_type and val_patients
print('Training Parameters:')
print('-----------------')
print('Number of Patients: ', number_patients)
print('Number of epochs: ', num_epochs)
print('Test Patient number: ', test_patients)
print('Image Size: ', newSize)
print('Filter Multiplier: ', filter_multiplier)
print('Data Augmentation: ', augment)
print('Learning rate: ', lr)
print('Image Path: ', img_path)
print('Save Path: ', save_path)
print('GT Type:', gt_type)
### Load test patient
if gt_type == "thresh" or gt_type == "ctthresh_gt":
img_path = "data/npy_thresh/"
else:
img_path = "data/npy/"
full_list = os.listdir(img_path)
seg_list = os.listdir("data/npy/")
X_img_list = []
GT_img_list = []
ytrue_img_list = []
# thresh_img_list = []
if "mr" in save_path:
for elem in full_list:
if elem.endswith("T1.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
X_img_list.append(elem)
if gt_type == "ctthresh_gt":
X_img_list.append(elem)
X_img_list.append(elem)
elif elem.endswith(gt_type + ".gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
GT_img_list.append(elem)
for elem in seg_list:
if elem.endswith("segmr.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
ytrue_img_list.append(elem)
if gt_type == "ctthresh_gt":
ytrue_img_list.append(elem)
ytrue_img_list.append(elem)
else:
for elem in full_list:
if elem.endswith("ct.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
X_img_list.append(elem)
if gt_type == "thresh":
X_img_list.append(elem)
X_img_list.append(elem)
elif elem.endswith(gt_type+".gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
GT_img_list.append(elem)
for elem in seg_list:
if elem.endswith("seg.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
ytrue_img_list.append(elem)
if gt_type == "thresh":
ytrue_img_list.append(elem)
ytrue_img_list.append(elem)
list.sort(X_img_list)
list.sort(GT_img_list)
list.sort(ytrue_img_list)
print("Input Image List", X_img_list)
print("GT Image List", GT_img_list)
print("True Segmentation Image List", ytrue_img_list)
for j in range(2):
if gt_type == "thresh" or gt_type == "ctthresh_gt":
X_img_npys = np.load(img_path + X_img_list[j*3])
GT_img_npys = np.load(img_path + GT_img_list[j*3])
ytrue_img_npys = np.load(img_path + ytrue_img_list[j*3])
print(GT_img_list[j*3])
print(GT_img_list[j * 3+1])
print(GT_img_list[j * 3+2])
X_img_npys = np.append(X_img_npys, np.load(img_path + X_img_list[j * 3+1]), axis=0)
GT_img_npys = np.append(GT_img_npys, np.load(img_path + GT_img_list[j * 3+1]), axis=0)
ytrue_img_npys = np.append(ytrue_img_npys, np.load(img_path + ytrue_img_list[j * 3+1]), axis=0)
X_img_npys = np.append(X_img_npys, np.load(img_path + X_img_list[j * 3+2]), axis=0)
GT_img_npys = np.append(GT_img_npys, np.load(img_path + GT_img_list[j * 3+2]), axis=0)
ytrue_img_npys = np.append(ytrue_img_npys, np.load(img_path + ytrue_img_list[j * 3+2]), axis=0)
print("Input shape: ", np.shape(X_img_npys))
print("GT shape: ", np.shape(GT_img_npys))
print("True Segm shape: ", np.shape(ytrue_img_npys))
else:
X_img_npys = np.load(img_path + X_img_list[j])
GT_img_npys = np.load(img_path + GT_img_list[j])
ytrue_img_npys = np.load(img_path + ytrue_img_list[j])
print("Input shape: ", np.shape(X_img_npys))
print("GT shape: ", np.shape(GT_img_npys))
print("True Segm shape: ", np.shape(ytrue_img_npys))
X_img_npys = transform.resize(X_img_npys, (X_img_npys.shape[0], newSize[0], newSize[1]), order=0,
preserve_range=True, mode='constant', anti_aliasing=False,
anti_aliasing_sigma=None)
GT_img_npys = transform.resize(GT_img_npys, (GT_img_npys.shape[0], newSize[0], newSize[1]), order=0,
preserve_range=True, mode='constant', anti_aliasing=False,
anti_aliasing_sigma=None)
ytrue_img_npys = transform.resize(ytrue_img_npys, (ytrue_img_npys.shape[0], newSize[0], newSize[1]),
order=0,
preserve_range=True, mode='constant', anti_aliasing=False,
anti_aliasing_sigma=None)
X_test = np.reshape(X_img_npys, (X_img_npys.shape[0], X_img_npys.shape[1], X_img_npys.shape[2], 1))
GT_test = np.reshape(GT_img_npys, (GT_img_npys.shape[0], GT_img_npys.shape[1], GT_img_npys.shape[2], 1))
ytrue = np.reshape(ytrue_img_npys,
(ytrue_img_npys.shape[0], ytrue_img_npys.shape[1], ytrue_img_npys.shape[2], 1))
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, ytrue))
test_dataset = test_dataset.batch(batch_size=1)
print(test_patients)
TP_num = test_patients[j]
###################################################################################
detailed_images = False
npys3d = True
###################################################################################
test_loss = []
test_loss_hdd = []
test_loss_hdd2 = []
y_pred3d = []
for features in test_dataset:
image, y_true = features
y_true = onehotencode(y_true)
y_pred = Unet(image, weights, filter_multiplier, training=False)
loss = dice_loss(y_pred, y_true)
loss = tf.make_ndarray(tf.make_tensor_proto(loss))
test_loss.append(loss)
#print(test_loss)
try:
y_true_np = np.squeeze(y_true[0, :, :, 0].numpy() > 0.5)
y_true_np = y_true_np.astype(np.float_)
pred_np = np.squeeze(y_pred[0, :, :, 0].numpy() > 0.5)
pred_np = pred_np.astype(np.float_)
hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter.Execute(sitk.GetImageFromArray(y_true_np), sitk.GetImageFromArray(pred_np))
test_loss_hdd.append(hausdorff_distance_filter.GetHausdorffDistance())
hausdorff_distance_filter2 = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter2.Execute(sitk.GetImageFromArray(pred_np), sitk.GetImageFromArray(y_true_np))
test_loss_hdd2.append(hausdorff_distance_filter2.GetHausdorffDistance())
except:
pass
#plt.show()
#print(test_loss)
print("TestLoss Mean for P", test_patients[j], ": ", np.mean(test_loss))
#print(test_loss_hdd)
print("Hausdorff-Distance for P", test_patients[j],":", np.mean(test_loss_hdd))
print("Hausdorff-Distance2 for P", test_patients[j], ":", np.mean(test_loss_hdd2))
#####
if __name__ == "__main__":
main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment