Commit 98f45f9a authored by sjromuel's avatar sjromuel
Browse files

d

parent 407c6741
......@@ -20,7 +20,7 @@ def main():
root = tk.Tk()
root.withdraw()
file_path = filedialog.askopenfilename(initialdir="finalResults/complete_seg/ae_class_cv_seg/")
file_path = filedialog.askopenfilename(initialdir="saves/")
file_path = file_path[:-9]
print(file_path)
......@@ -61,9 +61,9 @@ def main():
### Load test patient
if gt_type == "thresh":
img_path = "../data/npy_thresh/"
img_path = "data/npy_thresh/"
else:
img_path = "../data/npy/"
img_path = "data/npy/"
full_list = os.listdir(img_path)
seg_list = os.listdir("../data/npy/")
X_img_list = []
......@@ -71,20 +71,36 @@ def main():
ytrue_img_list = []
# thresh_img_list = []
for elem in full_list:
if elem.endswith("ct.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
X_img_list.append(elem)
if gt_type == "thresh":
if "mr" in file_path:
for elem in full_list:
if elem.endswith("T1.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
X_img_list.append(elem)
X_img_list.append(elem)
elif elem.endswith(gt_type+".gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
GT_img_list.append(elem)
for elem in seg_list:
if elem.endswith("seg.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
ytrue_img_list.append(elem)
if gt_type == "thresh":
if gt_type == "ctthresh_gt":
X_img_list.append(elem)
X_img_list.append(elem)
elif elem.endswith(gt_type+".gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
GT_img_list.append(elem)
for elem in seg_list:
if elem.endswith("segmr.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
ytrue_img_list.append(elem)
if gt_type == "ctthresh_gt":
ytrue_img_list.append(elem)
ytrue_img_list.append(elem)
else:
for elem in full_list:
if elem.endswith("ct.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
X_img_list.append(elem)
if gt_type == "thresh":
X_img_list.append(elem)
X_img_list.append(elem)
elif elem.endswith(gt_type+".gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
GT_img_list.append(elem)
for elem in seg_list:
if elem.endswith("seg.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
ytrue_img_list.append(elem)
if gt_type == "thresh":
ytrue_img_list.append(elem)
ytrue_img_list.append(elem)
list.sort(X_img_list)
list.sort(GT_img_list)
list.sort(ytrue_img_list)
......@@ -128,7 +144,7 @@ def main():
# print(tf.shape(y_pred), tf.shape(y_true))
#y_true = onehotencode(tf.reshape(y_true, (1, 512, 512, 1)), autoencoder=True)
#y_pred = tf.reshape(y_pred, (1, 512, 512, 2))
'''fig = plt.figure()
fig = plt.figure()
fig.add_subplot(1, 3, 1)
plt.title("Prediction")
plt.imshow(y_pred[0, :, :, 0], cmap=plt.cm.bone)
......@@ -137,7 +153,7 @@ def main():
plt.imshow(y_true[0, :, :, 0], cmap=plt.cm.bone)
fig.add_subplot(1, 3, 3)
plt.title("Train-Segmentation (Fake GT)")
plt.imshow(GT_test[counter, :, :, 0], cmap=plt.cm.bone)'''
plt.imshow(GT_test[counter, :, :, 0], cmap=plt.cm.bone)
counter = counter+1
loss = dice_loss(y_pred, y_true)
......@@ -155,7 +171,7 @@ def main():
test_loss_hdd.append(hausdorff_distance_filter.GetHausdorffDistance())
except:
pass
#plt.show()
# plt.show()
#print(test_loss)
print("TestLoss Mean for P", test_patients[j], ": ", np.mean(test_loss))
......
......@@ -21,7 +21,7 @@ def main():
root = tk.Tk()
root.withdraw()
file_path = filedialog.askopenfilename(initialdir="finalResults/complete_thresh/")
file_path = filedialog.askopenfilename(initialdir="saves/mr_unet_cv_ctthresh/")
file_path = file_path[:-9]
print(file_path)
......@@ -62,9 +62,9 @@ def main():
### Load test patient
if gt_type == "thresh":
img_path = "../data/npy_thresh/"
img_path = "data/npy_thresh/"
else:
img_path = "../data/npy/"
img_path = "data/npy/"
full_list = os.listdir(img_path)
seg_list = os.listdir("../data/npy/")
......@@ -263,6 +263,7 @@ def main():
if npys3d:
np.save(img_savepath + "P" + str(TP_num).zfill(2) + "_pred3D", y_pred3d)
np.save("data/npy/" + "P" + str(TP_num).zfill(2) + "_mr_ctunetpred.gipl", y_pred3d)
#plt.show()
#print(test_loss)
print("TestLoss Mean for P", test_patients[j], ": ", np.mean(test_loss))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment