Commit 410b8d3c authored by sjromuel

d

parent ef9d396b
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tkinter as tk
import os
import matplotlib
from tkinter import filedialog
from skimage import transform
import SimpleITK as sitk
import argparse
#import os
#import pydot
#from graphviz import Digraph
#import shutil
#from tensorflow.keras import layers, models
#from utils.dataLoader import *
from utils.other_functions import *
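
# Evaluation script: pick a trained run via a file dialog, load its saved
# weights (model.npy) and training parameters (params.npy), rerun the network
# on the held-out test patients and report Dice loss and Hausdorff distances.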
def main():
    root = tk.Tk()
    root.withdraw()
    file_path = filedialog.askopenfilename(initialdir="finalResults/complete_segmr/mr_unet_cv_ctthresh")
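    # the dialog returns the path of the selected file; dropping the last 9
    # characters (presumably the "model.npy" file name) leaves the run directory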
    file_path = file_path[:-9]
    print(file_path)

    ##################### U-Net #####################
    if 'Unet' in file_path or 'Cluster' in file_path or 'class' in file_path:
        from nets.Unet import run_test_patient, show_test_patient_pred, dice_loss, Unet
        print("UNet_model selected")

        ### read out files ###
        weights = np.load(file_path + "model.npy", allow_pickle=True)
        [test_patients,
         val_patients,
         number_patients,
         img_path,
         shrink_data,
         newSize,
         lr,
         batch_size,
         num_epochs,
         e,
         augment,
         save_path,
         gt_type,
         filter_multiplier] = np.load(file_path + "params.npy", allow_pickle=True)
        # autoencoder_model__e100_switchclass1024_nohiddenclusternet needs gt_type and val_patients commented out
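        # params.npy is a pickled object array written by the training run;
        # the unpacking order above has to match the order it was saved in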
        print('Training Parameters:')
        print('-----------------')
        print('Number of Patients: ', number_patients)
        print('Number of epochs: ', num_epochs)
        print('Test Patient number: ', test_patients)
        print('Image Size: ', newSize)
        print('Filter Multiplier: ', filter_multiplier)
        print('Data Augmentation: ', augment)
        print('Learning rate: ', lr)
        print('Image Path: ', img_path)
        print('Save Path: ', save_path)
        print('GT Type:', gt_type)

        ### Load test patient
        if gt_type == "thresh" or gt_type == "ctthresh_gt":
            img_path = "data/npy_thresh/"
        else:
            img_path = "data/npy/"

        full_list = os.listdir(img_path)
        seg_list = os.listdir("data/npy/")
        X_img_list = []
        GT_img_list = []
        ytrue_img_list = []
        # thresh_img_list = []

        if "mr" in save_path:
            for elem in full_list:
                if elem.endswith("T1.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
                    X_img_list.append(elem)
                elif elem.endswith("250_mr_ctthresh_gt.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
                    GT_img_list.append(elem)
            for elem in seg_list:
                if elem.endswith("segmr.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
                    ytrue_img_list.append(elem)

        X_img_list.sort()
        GT_img_list.sort()
        ytrue_img_list.sort()
        print("Input Image List", X_img_list)
        print("GT Image List", GT_img_list)
        print("True Segmentation Image List", ytrue_img_list)

        img_savepath = file_path + "imgresults/"
        print(img_savepath)
        if not os.path.exists(img_savepath):
            os.makedirs(img_savepath)

        for j in range(2):
            X_img_npys = np.load(img_path + X_img_list[j])
            GT_img_npys = np.load(img_path + GT_img_list[j])
            ytrue_img_npys = np.load(img_path + ytrue_img_list[j])
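            # note: the reference segmentations were listed from data/npy/ but are
            # loaded from img_path here, so the same files are assumed to exist there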
print("Input shape: ", np.shape(X_img_npys))
print("GT shape: ", np.shape(GT_img_npys))
print("True Segm shape: ", np.shape(ytrue_img_npys))
X_img_npys = transform.resize(X_img_npys, (X_img_npys.shape[0], newSize[0], newSize[1]), order=0,
preserve_range=True, mode='constant', anti_aliasing=False,
anti_aliasing_sigma=None)
GT_img_npys = transform.resize(GT_img_npys, (GT_img_npys.shape[0], newSize[0], newSize[1]), order=0,
preserve_range=True, mode='constant', anti_aliasing=False,
anti_aliasing_sigma=None)
ytrue_img_npys = transform.resize(ytrue_img_npys, (ytrue_img_npys.shape[0], newSize[0], newSize[1]),
order=0,
preserve_range=True, mode='constant', anti_aliasing=False,
anti_aliasing_sigma=None)
X_test = np.reshape(X_img_npys, (X_img_npys.shape[0], X_img_npys.shape[1], X_img_npys.shape[2], 1))
GT_test = np.reshape(GT_img_npys, (GT_img_npys.shape[0], GT_img_npys.shape[1], GT_img_npys.shape[2], 1))
ytrue = np.reshape(ytrue_img_npys,
(ytrue_img_npys.shape[0], ytrue_img_npys.shape[1], ytrue_img_npys.shape[2], 1))
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, ytrue))
test_dataset = test_dataset.batch(batch_size=1)
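            # batch size 1: the dataset yields one 2D slice (plus its reference mask) per step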
            print(test_patients)
            TP_num = test_patients[j]

            ###################################################################################
            detailed_images = False
            npys3d = True
            ###################################################################################
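            # detailed_images: additionally dump per-slice PNGs of input / prediction / masks
            # npys3d: stack the slice predictions and save the volumes as .npy files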
            if npys3d:
                # test_patient_pred = run_test_patient(test_dataset, weights, filter_multiplier)
                np.save(img_savepath + "P" + str(TP_num).zfill(2) + "_Inputimg3D", X_test)
                np.save(img_savepath + "P" + str(TP_num).zfill(2) + "_gt3D", GT_test)
                np.save(img_savepath + "P" + str(TP_num).zfill(2) + "_true3D", ytrue)

            test_loss = []
            test_loss_hdd = []
            test_loss_hdd2 = []
            counter = 0
            y_pred3d = []

            for features in test_dataset:
                image, y_true = features
                y_true = onehotencode(y_true)
                y_pred = Unet(image, weights, filter_multiplier, training=False)

                if npys3d:
                    if len(y_pred3d) == 0:
                        y_pred3d = y_pred[:, :, :, 0].numpy()
                    else:
                        y_pred3d = np.append(y_pred[:, :, :, 0].numpy(), y_pred3d, axis=0)
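                    # note: np.append puts the newest slice in front, so y_pred3d is
                    # built up in reverse slice order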
                    # print(np.shape(y_pred3d))

                # print(tf.shape(y_pred), tf.shape(y_true))
                # y_true = onehotencode(tf.reshape(y_true, (1, 512, 512, 1)), autoencoder=True)
                # y_pred = tf.reshape(y_pred, (1, 512, 512, 2))

                fig = plt.figure()
                fig.add_subplot(1, 4, 1)
                plt.title("Input_img")
                plt.axis('off')
                plt.imshow(image[0, :, :, 0], cmap=plt.cm.bone)
                fig.add_subplot(1, 4, 2)
                plt.title("Pred")
                plt.axis('off')
                plt.imshow(y_pred[0, :, :, 0], cmap=plt.cm.bone)
                fig.add_subplot(1, 4, 3)
                plt.title("True Seg")
                plt.axis('off')
                plt.imshow(y_true[0, :, :, 0], cmap=plt.cm.bone)
                fig.add_subplot(1, 4, 4)
                plt.title("Fake GT")
                plt.axis('off')
                plt.imshow(GT_test[counter, :, :, 0], cmap=plt.cm.bone)

                if not os.path.exists(img_savepath + "4Imgs_P" + str(TP_num).zfill(2)):
                    os.makedirs(img_savepath + "4Imgs_P" + str(TP_num).zfill(2))
                fig.savefig(img_savepath + "4Imgs_P" + str(TP_num).zfill(2) + "/" + str(counter + 1).zfill(2) + ".png")
                plt.close(fig)
                counter = counter + 1

                loss = dice_loss(y_pred, y_true)
                loss = tf.make_ndarray(tf.make_tensor_proto(loss))
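                # tf.make_tensor_proto / tf.make_ndarray turn the scalar loss tensor into a plain numpy value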
                test_loss.append(loss)
                # print(test_loss)

                try:
                    y_true_np = np.squeeze(y_true[0, :, :, 0].numpy() > 0.5)
                    y_true_np = y_true_np.astype(np.float_)
                    pred_np = np.squeeze(y_pred[0, :, :, 0].numpy() > 0.5)
                    pred_np = pred_np.astype(np.float_)
                    hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
                    hausdorff_distance_filter.Execute(sitk.GetImageFromArray(y_true_np), sitk.GetImageFromArray(pred_np))
                    test_loss_hdd.append(hausdorff_distance_filter.GetHausdorffDistance())
                    hausdorff_distance_filter2 = sitk.HausdorffDistanceImageFilter()
                    hausdorff_distance_filter2.Execute(sitk.GetImageFromArray(pred_np), sitk.GetImageFromArray(y_true_np))
                    test_loss_hdd2.append(hausdorff_distance_filter2.GetHausdorffDistance())
                except:
                    pass
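                # the Hausdorff filter can fail if either mask is empty, so such slices are
                # simply skipped; GetHausdorffDistance() is symmetric, so the two passes
                # should give the same value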
                # save images:
                if detailed_images:
                    matplotlib.image.imsave(img_savepath + "P" + str(TP_num).zfill(2) + "_" + str(counter).zfill(2) + "_Inputimg.png",
                                            image[0, :, :, 0], cmap=plt.cm.bone)
                    matplotlib.image.imsave(img_savepath + "P" + str(TP_num).zfill(2) + "_" + str(counter).zfill(2) + "_ytrue.png",
                                            y_true[0, :, :, 0], cmap=plt.cm.bone)
                    matplotlib.image.imsave(img_savepath + "P" + str(TP_num).zfill(2) + "_" + str(counter).zfill(2) + "_ytrain.png",
                                            GT_test[counter - 1, :, :, 0], cmap=plt.cm.bone)
                    if 'Cluster' in file_path:
                        matplotlib.image.imsave(img_savepath + "P" + str(TP_num).zfill(2) + "_" + str(counter).zfill(2) + "_clusterpred" + str(loss)[0:6] + ".png",
                                                y_pred[0, :, :, 0], cmap=plt.cm.bone)
                    elif 'class' in file_path:
                        matplotlib.image.imsave(img_savepath + "P" + str(TP_num).zfill(2) + "_" + str(counter).zfill(2) + "_classpred" + str(loss)[0:6] + ".png",
                                                y_pred[0, :, :, 0], cmap=plt.cm.bone)
                    else:
                        matplotlib.image.imsave(img_savepath + "P" + str(TP_num).zfill(2) + "_" + str(counter).zfill(2) + "_unetpred" + str(loss)[0:6] + ".png",
                                                y_pred[0, :, :, 0], cmap=plt.cm.bone)

            if npys3d:
                # np.save(img_savepath + "P" + str(TP_num).zfill(2) + "_pred3D", y_pred3d)
                np.save("data/npy_thresh/" + "P" + str(TP_num).zfill(2) + "_mr_ctunetpred.gipl", y_pred3d)
print("ypred shape: ", np.shape(y_pred3d))
#plt.show()
#print(test_loss)
print("TestLoss Mean for P", test_patients[j], ": ", np.mean(test_loss))
#print(test_loss_hdd)
print("Hausdorff-Distance for P", test_patients[j],":", np.mean(test_loss_hdd))
print("Hausdorff-Distance2 for P", test_patients[j], ":", np.mean(test_loss_hdd2))
#####
'''X_img_npys = np.load(img_path + X_img_list[0])
GT_img_npys = np.load(img_path + GT_img_list[0])
ytrue_img_npys = np.load("../data/npy/" + ytrue_img_list[0])
for i in range(len(X_img_list)-1):
X_img_npys = np.append(X_img_npys, np.load(img_path + X_img_list[i+1]), axis=0)
GT_img_npys = np.append(GT_img_npys, np.load(img_path + GT_img_list[i+1]), axis=0)
ytrue_img_npys = np.append(ytrue_img_npys, np.load("../data/npy/" + ytrue_img_list[i+1]), axis=0)
X_img_npys = transform.resize(X_img_npys, (X_img_npys.shape[0], newSize[0], newSize[1]), order=0,
preserve_range=True, mode='constant', anti_aliasing=False, anti_aliasing_sigma=None)
GT_img_npys = transform.resize(GT_img_npys, (GT_img_npys.shape[0], newSize[0], newSize[1]), order=0,
preserve_range=True, mode='constant', anti_aliasing=False, anti_aliasing_sigma=None)
ytrue_img_npys = transform.resize(ytrue_img_npys, (ytrue_img_npys.shape[0], newSize[0], newSize[1]), order=0,
preserve_range=True, mode='constant', anti_aliasing=False, anti_aliasing_sigma=None)
X_test = np.reshape(X_img_npys,(X_img_npys.shape[0], X_img_npys.shape[1], X_img_npys.shape[2], 1))
GT_test = np.reshape(GT_img_npys, (GT_img_npys.shape[0], GT_img_npys.shape[1], GT_img_npys.shape[2], 1))
ytrue = np.reshape(ytrue_img_npys, (ytrue_img_npys.shape[0], ytrue_img_npys.shape[1], ytrue_img_npys.shape[2], 1))
### recreate test-dataset ###
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, GT_test))
test_dataset = test_dataset.batch(batch_size)
print("x_test shape: " + str(X_test.shape))
### test model and show plots ###
test_patient_pred = run_test_patient(test_dataset, weights, filter_multiplier)
print("test_patient_pred shape: " + str(test_patient_pred.shape))
# show_test_patient_pred(test_patient_pred, test_dataset)
print(GT_test.shape)
### compute loss for each slice ###
test_loss = []
test_loss_hdd = []
for i in range(test_patient_pred.shape[0]):
y_pred = tf.convert_to_tensor(test_patient_pred[i])
y_true = tf.convert_to_tensor(ytrue[i])
#print(tf.shape(y_pred), tf.shape(y_true))
y_true = onehotencode(tf.reshape(y_true, (1, 512, 512, 1)), autoencoder=True)
y_pred = tf.reshape(y_pred, (1, 512, 512, 2))
loss = dice_loss(y_pred, y_true)
loss = tf.make_ndarray(tf.make_tensor_proto(loss))
test_loss.append(loss)
gt_np = np.squeeze(y_true[0, :, :, 0].numpy() > 0)
gt_np = gt_np.astype(np.float_)
pred_np = np.squeeze(y_pred[0, :, :, 0].numpy() > 0)
pred_np = pred_np.astype(np.float_)
hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter.Execute(sitk.GetImageFromArray(gt_np), sitk.GetImageFromArray(pred_np))
test_loss_hdd.append(hausdorff_distance_filter.GetHausdorffDistance())
print(test_loss)
print("TestLoss Mean: ", np.mean(test_loss))
print(test_loss_hdd)
print("Hausdorff-Distance: ", np.mean(test_loss_hdd))
#show_loss_plot(len(test_loss), test_loss)
print("test_patient_pred shape: " + str(test_patient_pred.shape))'''
#show_test_patient_pred(test_patient_pred, test_dataset)
    ################## Autoencoder ##################
    elif 'AE' in file_path:
        # from nets.AE_3conv32 import run_test_patient, show_test_patient_pred, dice_loss
        # from nets.AE_3conv128 import run_test_patient, show_test_patient_pred, dice_loss
        # from nets.AE_3conv1024 import run_test_patient, show_test_patient_pred, dice_loss
        # from nets.AE_4conv2048 import run_test_patient, show_test_patient_pred, dice_loss
        from nets.AE_ulike import run_test_patient, show_test_patient_pred, dice_loss
        print("Autoencoder_model selected")

        ### read out files ###
        weights = np.load(file_path + "model.npy", allow_pickle=True)
        [test_patients,
         val_patients,
         number_patients,
         img_path,
         shrink_data,
         newSize,
         lr,
         batch_size,
         num_epochs,
         e,
         augment,
         save_path,
         gt_type,
         filter_multiplier] = np.load(file_path + "params.npy", allow_pickle=True)

        print('Training Parameters:')
        print('-----------------')
        print('Number of Patients: ', number_patients)
        print('Number of epochs: ', num_epochs)
        print('Test Patient numbers: ', test_patients)
        # print('Validation Patient numbers: ', val_patients)
        print('Image Size: ', newSize)
        print('Data Augmentation: ', augment)
        print('Learning rate: ', lr)
        print('Image Path: ', img_path)
        print('Save Path: ', save_path)
        print('GT Type: ', gt_type)

        ### Load test patient
        if gt_type == "thresh":
            img_path = "data/npy_thresh/"
        else:
            img_path = "data/npy/"

        full_list = os.listdir(img_path)
        X_img_list = []
        print(test_patients)
        for elem in full_list:
            if elem.endswith(gt_type + ".gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
                X_img_list.append(elem)
                if gt_type == "thresh":
                    X_img_list.append(elem)
                    X_img_list.append(elem)
        print(X_img_list)

        test_patients_npys = np.append(np.load(img_path + X_img_list[0]), np.load(img_path + X_img_list[1]), axis=0)
        x_test = np.reshape(test_patients_npys, (test_patients_npys.shape[0], test_patients_npys.shape[1], test_patients_npys.shape[2], 1))

        ### recreate test-dataset ###
        test_dataset = tf.data.Dataset.from_tensor_slices((x_test, x_test))
        test_dataset = test_dataset.batch(batch_size=1)

        ### test model and show plots ###
        test_patient_pred = run_test_patient(test_dataset, weights)

        test_loss = []
        for i in range(test_patient_pred.shape[0]):
            y_pred = tf.convert_to_tensor(test_patient_pred[i])
            y_true = tf.convert_to_tensor(x_test[i])
            y_pred = tf.reshape(y_pred, (1, 512, 512, 2))
            y_true = onehotencode(tf.reshape(y_true, (1, 512, 512, 1)), autoencoder=True)
            loss = dice_loss(y_pred, y_true)
            loss = tf.make_ndarray(tf.make_tensor_proto(loss))
            test_loss.append(loss)
        print('Test losses: ', test_loss)
        print('Test loss-mean: ', np.mean(test_loss))
        print("test_patient_pred shape: " + str(test_patient_pred.shape))
        show_test_patient_pred(test_patient_pred, test_dataset)


def show_loss_plot(slice, loss_list):
    x_axis = list(range(1, slice + 1, 1))
    plt.plot(x_axis, loss_list)
    plt.xlabel('Slice Number')
    plt.ylabel('Loss')
    plt.show()


if __name__ == "__main__":
    main()
\ No newline at end of file
@@ -13,7 +13,7 @@ from datetime import datetime
img = np.load("data/npy_thresh/P01_mr_ctunetpred.gipl.npy")
print(np.shape(img))
plt.imshow(img[5,:, :], cmap=plt.cm.bone)
plt.imshow(img[10,:, :], cmap=plt.cm.bone)
plt.show()
'''mr_img = np.load("data/npy/P03_mr_T1.gipl.npy")
......