Commit 3f1de1da authored by Robin's avatar Robin
Browse files

d

parent fccb1833
......@@ -179,7 +179,7 @@ def main():
#number_patients = number_patients * 5
img_path = "data/npy_thresh/"
specificmodels = [6, 7]
specificmodels = [2]
if cross_val:
log = open("logs" + modelname + ".txt", "w+")
log.write(modelname + "\r")
......
......@@ -442,7 +442,7 @@ class BaseNetwork:
self.log.write('Loading Training Data ...'+"\r")
### initializing weights ###
weights = init_weights(manual_seed=np.sum(self.test_patients))
weights = init_weights(manual_seed=np.sum(self.test_patients)*10)
filter_multiplier = int(self.newSize[0] / 128)
......
......@@ -24,6 +24,7 @@ def main():
#modeltype = "Class_"
modeltype = "Cluster_class_"
#modeltype = "Unet_"
all3dhdds = []
##################### U-Net #####################
for fold in models:
......@@ -179,15 +180,16 @@ def main():
test_loss_hdd = []
test_loss_hdd2 = []
y_pred3d = []
y_true3d = []
for features in test_dataset:
image, y_true = features
y_true = onehotencode(y_true)
y_pred = Unet(image, weights, filter_multiplier, training=False)
if y_pred3d == []:
'''if y_pred3d == []:
y_pred3d = y_pred[:,:,:,0].numpy()
else:
y_pred3d = np.append(y_pred[:,:,:,0].numpy(), y_pred3d, axis=0)
y_pred3d = np.append(y_pred[:,:,:,0].numpy(), y_pred3d, axis=0)'''
loss = dice_loss(y_pred, y_true)
loss = tf.make_ndarray(tf.make_tensor_proto(loss))
......@@ -205,6 +207,15 @@ def main():
hausdorff_distance_filter2 = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter2.Execute(sitk.GetImageFromArray(pred_np), sitk.GetImageFromArray(y_true_np))
test_loss_hdd2.append(hausdorff_distance_filter2.GetHausdorffDistance())
if y_pred3d ==[]:
y_pred3d = pred_np
y_true3d = y_true_np
else:
y_pred3d = np.append(pred_np, y_pred3d, axis=0)
y_true3d = np.append(y_true_np, y_true3d, axis=0)
except:
pass
......@@ -223,18 +234,23 @@ def main():
print(tf.shape(ytrue))
y_true3d = np.squeeze(ytrue[:,:,:,0].numpy() > 0.5)
y_true3d = y_true.astype(np.float_)
y_true3d = np.float_(y_true)
print(np.shape(y_true3d))
print(type(y_true3d))
print(type(y_pred3d))
pred3d = np.squeeze(y_pred3d[:,:,:].numpy() > 0.5)
pred3d = pred3d.astype(np.float_)
pred3d = np.float(pred3d)
print(np.shape(pred3d))
print(type(pred3d))
print(type(pred3d))'''
hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter.Execute(sitk.GetImageFromArray(y_true3d), sitk.GetImageFromArray(pred3d))
hausdorff_distance_filter.Execute(sitk.GetImageFromArray(y_true3d), sitk.GetImageFromArray(y_pred3d))
hdd3d = hausdorff_distance_filter.GetHausdorffDistance()
print("3D HDD for patient", test_patients[j], ":", hdd3d)'''
print("3D HDD for patient", test_patients[j], ":", hdd3d)
all3dhdds.append(hdd3d)
print("All 3d-HDDs: ", all3dhdds)
print(folder_path)
print(modeltype)
#####
if __name__ == "__main__":
......
......@@ -274,9 +274,8 @@ def main():
matplotlib.image.imsave(img_savepath + "P" + str(TP_num).zfill(2) + "_" + str(counter).zfill(2) + "_unetpred"+ str(loss)[0:6] +".png",
y_pred[0, :, :, 0], cmap=plt.cm.bone)
if npys3d:
#if npys3d:
#np.save(img_savepath + "P" + str(TP_num).zfill(2) + "_pred3D", y_pred3d)
print("ypred shape: ", np.shape(y_pred3d))
#plt.show()
#print(test_loss)
print("TestLoss Mean for P", test_patients[j], ": ", np.mean(test_loss))
......
......@@ -180,7 +180,7 @@ def main():
if npys3d:
filename = "P" + str(TP_num).zfill(2)
y_pred3d = cut_mr_ctunetpred(filename, y_pred3d)
ythresh3d = np.flip(cut_mr_ctunetpred(filename, GT_test))
ythresh3d = cut_mr_ctunetpred(filename, np.flip(GT_test))
#np.save(img_savepath + "P" + str(TP_num).zfill(2) + "_pred3D", y_pred3d)
print(np.shape(ythresh3d[:,:,:,0]))
np.save("data/npy_thresh/" + "P" + str(TP_num).zfill(2) + "_mr_ctunetpred.gipl", ythresh3d[:,:,:,0])
......
......@@ -9,16 +9,17 @@ import matplotlib.pyplot as plt
from scipy import ndimage
from datetime import datetime
pat = "06"
inputimg = np.load("data/npy_thresh/P01_mr_T1.gipl.npy")
gt = np.load("data/npy_thresh/P01_mr_ctunetpred.gipl.npy")
ytrue = np.load("data/npy_thresh/P01_segmr.gipl.npy")
inputimg = np.load("data/npy_thresh/P"+pat+"_mr_T1.gipl.npy")
gt = np.load("data/npy_thresh/P"+pat+"_mr_ctunetpred.gipl.npy")
ytrue = np.load("data/npy_thresh/P"+pat+"_segmr.gipl.npy")
print(np.shape(inputimg))
print(np.shape(gt))
print(np.shape(ytrue))
for i in range(25):
for i in range(np.shape(inputimg)[0]):
fig = plt.figure()
fig.add_subplot(2, 3, 1)
plt.imshow(inputimg[i, :, :], cmap=plt.cm.bone)
......
......@@ -368,4 +368,4 @@ def slicenumber_to_class(self, x):
if (x[1] >= 27 and x[1] < 29): sliceclass = 10
else: sliceclass = 0
#print('Sliceclass:', sliceclass)
return sliceclass
\ No newline at end of file
return sliceclass
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment