Commit d1791a62 authored by sjromuel's avatar sjromuel
Browse files

d

parent 6a3aba14
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tkinter as tk
import os
import matplotlib
from tkinter import filedialog
from skimage import transform
import SimpleITK as sitk
import argparse
#import os
#import pydot
#from graphviz import Digraph
#import shutil
#from tensorflow.keras import layers, models
#from utils.dataLoader import *
from utils.other_functions import *
from nets.Unet import *
def main():
folds = ["12", "34", "56", "78", "910", "1112", "1314", "1516"]
folder_path = "finalResults/complete_"
unet_models_paths = ["ctfgt2/unet_cv_ctfgt2/",
"seg/unet_seg/",
"thresh/unet_cv_thresh/",
"segmr/mr_unet_cv_ctsegmr/",
"segmr/mr_unet_cv_ctseg/",
"segmr/mr_unet_cv_ctthresh/",
"segmr/mr_unet_cv_unetpred/"]
cluster_models_paths = ["ctfgt2/ae_noclass_cv_ctfgt2/",
"seg/ae_noclass_seg/",
"thresh/ae_noclass_cv_thresh/"]
class_models_paths = ["ctfgt2/ae_class_cv_ctfgt2/",
"seg/ae_class_seg/",
"thresh/ae_class_cv_thresh/"]
clusterclass_models_paths = ["ctfgt2/ae_class_cv_ctfgt2/",
"seg/ae_class_seg/",
"thresh/ae_class_cv_thresh/"]
modeltypes = ["Unet_", "Cluster_", "Class_", "Cluster_class_"]
for modeltype in modeltypes:
if modeltype == "Unet_": model_paths = unet_models_paths
elif modeltype == "Cluster_": model_paths == cluster_models_paths
elif modeltype == "Class_": model_paths = class_models_paths
elif modeltype == "Cluster_class_": model_paths = clusterclass_models_paths
print("!!!!!!!!!! Lets start the ", modeltype, " models !!!!!!!!!!")
for model_path in model_paths:
for fold in folds:
print("Starting fold ", fold, "in ", model_path)
file_path = folder_path+model_path+"TPs"+fold+modeltype
print(file_path)
### read out files ###
weights = np.load(file_path+"model.npy", allow_pickle=True)
[test_patients,
val_patients,
number_patients,
img_path,
shrink_data,
newSize,
lr,
batch_size,
num_epochs,
e,
augment,
save_path,
gt_type,
filter_multiplier] = np.load(file_path+"params.npy", allow_pickle=True)
# autoencoder_model__e100_switchclass1024_nohiddenclusternet needs to comment out gt_type and val_patients
'''print('Training Parameters:')
print('-----------------')
print('Number of Patients: ', number_patients)
print('Number of epochs: ', num_epochs)
print('Test Patient number: ', test_patients)
print('Image Size: ', newSize)
print('Filter Multiplier: ', filter_multiplier)
print('Data Augmentation: ', augment)
print('Learning rate: ', lr)
print('Image Path: ', img_path)
print('Save Path: ', save_path)
print('GT Type:', gt_type)'''
### Load test patient
if gt_type == "thresh" or gt_type == "ctthresh_gt" or gt_type == "_mr_ctunetpred":
img_path = "data/npy_thresh/"
else:
img_path = "data/npy/"
full_list = os.listdir(img_path)
seg_list = os.listdir("data/npy/")
X_img_list = []
GT_img_list = []
ytrue_img_list = []
# thresh_img_list = []
if "mr" in save_path:
for elem in full_list:
if elem.endswith("T1.gipl.npy") and (
elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith(
'P' + str(test_patients[1]).zfill(2))):
X_img_list.append(elem)
if gt_type == "ctthresh_gt":
X_img_list.append(elem)
X_img_list.append(elem)
elif elem.endswith(gt_type + ".gipl.npy") and (
elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith(
'P' + str(test_patients[1]).zfill(2))):
GT_img_list.append(elem)
for elem in seg_list:
if elem.endswith("segmr.gipl.npy") and (
elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith(
'P' + str(test_patients[1]).zfill(2))):
ytrue_img_list.append(elem)
if gt_type == "ctthresh_gt":
ytrue_img_list.append(elem)
ytrue_img_list.append(elem)
else:
for elem in full_list:
if elem.endswith("ct.gipl.npy") and (
elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith(
'P' + str(test_patients[1]).zfill(2))):
X_img_list.append(elem)
if gt_type == "thresh":
X_img_list.append(elem)
X_img_list.append(elem)
elif elem.endswith(gt_type + ".gipl.npy") and (
elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith(
'P' + str(test_patients[1]).zfill(2))):
GT_img_list.append(elem)
for elem in seg_list:
if elem.endswith("seg.gipl.npy") and (
elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith(
'P' + str(test_patients[1]).zfill(2))):
ytrue_img_list.append(elem)
if gt_type == "thresh":
ytrue_img_list.append(elem)
ytrue_img_list.append(elem)
list.sort(X_img_list)
list.sort(GT_img_list)
list.sort(ytrue_img_list)
'''print("Input Image List", X_img_list)
print("GT Image List", GT_img_list)
print("True Segmentation Image List", ytrue_img_list)'''
for j in range(2):
if gt_type == "thresh" or gt_type == "ctthresh_gt":
X_img_npys = np.load(img_path + X_img_list[j*3])
GT_img_npys = np.load(img_path + GT_img_list[j*3])
ytrue_img_npys = np.load(img_path + ytrue_img_list[j*3])
print(GT_img_list[j*3])
print(GT_img_list[j * 3+1])
print(GT_img_list[j * 3+2])
X_img_npys = np.append(X_img_npys, np.load(img_path + X_img_list[j * 3+1]), axis=0)
GT_img_npys = np.append(GT_img_npys, np.load(img_path + GT_img_list[j * 3+1]), axis=0)
ytrue_img_npys = np.append(ytrue_img_npys, np.load(img_path + ytrue_img_list[j * 3+1]), axis=0)
X_img_npys = np.append(X_img_npys, np.load(img_path + X_img_list[j * 3+2]), axis=0)
GT_img_npys = np.append(GT_img_npys, np.load(img_path + GT_img_list[j * 3+2]), axis=0)
ytrue_img_npys = np.append(ytrue_img_npys, np.load(img_path + ytrue_img_list[j * 3+2]), axis=0)
else:
X_img_npys = np.load(img_path + X_img_list[j])
GT_img_npys = np.load(img_path + GT_img_list[j])
ytrue_img_npys = np.load(img_path + ytrue_img_list[j])
X_img_npys = transform.resize(X_img_npys, (X_img_npys.shape[0], newSize[0], newSize[1]), order=0,
preserve_range=True, mode='constant', anti_aliasing=False,
anti_aliasing_sigma=None)
GT_img_npys = transform.resize(GT_img_npys, (GT_img_npys.shape[0], newSize[0], newSize[1]), order=0,
preserve_range=True, mode='constant', anti_aliasing=False,
anti_aliasing_sigma=None)
ytrue_img_npys = transform.resize(ytrue_img_npys, (ytrue_img_npys.shape[0], newSize[0], newSize[1]),
order=0,
preserve_range=True, mode='constant', anti_aliasing=False,
anti_aliasing_sigma=None)
X_test = np.reshape(X_img_npys, (X_img_npys.shape[0], X_img_npys.shape[1], X_img_npys.shape[2], 1))
GT_test = np.reshape(GT_img_npys, (GT_img_npys.shape[0], GT_img_npys.shape[1], GT_img_npys.shape[2], 1))
ytrue = np.reshape(ytrue_img_npys,
(ytrue_img_npys.shape[0], ytrue_img_npys.shape[1], ytrue_img_npys.shape[2], 1))
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, ytrue))
test_dataset = test_dataset.batch(batch_size=1)
TP_num = test_patients[j]
###################################################################################
detailed_images = False
npys3d = True
###################################################################################
test_loss = []
test_loss_hdd = []
test_loss_hdd2 = []
y_pred3d = []
for features in test_dataset:
image, y_true = features
y_true = onehotencode(y_true)
y_pred = Unet(image, weights, filter_multiplier, training=False)
loss = dice_loss(y_pred, y_true)
loss = tf.make_ndarray(tf.make_tensor_proto(loss))
test_loss.append(loss)
#print(test_loss)
try:
y_true_np = np.squeeze(y_true[0, :, :, 0].numpy() > 0.5)
y_true_np = y_true_np.astype(np.float_)
pred_np = np.squeeze(y_pred[0, :, :, 0].numpy() > 0.5)
pred_np = pred_np.astype(np.float_)
hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter.Execute(sitk.GetImageFromArray(y_true_np), sitk.GetImageFromArray(pred_np))
test_loss_hdd.append(hausdorff_distance_filter.GetHausdorffDistance())
hausdorff_distance_filter2 = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter2.Execute(sitk.GetImageFromArray(pred_np), sitk.GetImageFromArray(y_true_np))
test_loss_hdd2.append(hausdorff_distance_filter2.GetHausdorffDistance())
if y_pred3d == []:
y_pred3d = pred_np
y_true3d = y_true_np
else:
y_pred3d = np.append(pred_np, y_pred3d, axis=0)
y_true3d = np.append(y_true_np, y_true3d, axis=0)
except:
pass
npy_savepath = "finalResults/3dnpys/" + model_path
if not os.path.exists(npy_savepath):
os.makedirs(npy_savepath)
np.save(npy_savepath + "P" + str(TP_num).zfill(2) + "_pred3D", y_pred3d)
#####
if __name__ == "__main__":
main()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment