Commit a480def9 authored by Robin's avatar Robin
Browse files

d

parent 404926dc
......@@ -124,7 +124,7 @@ def main():
continue_training = False
e = 0
gt_type = "_mr_ctunetpred"
modelname = "mr_unet_cv_ctthresh"
modelname = "mr_unet_cv_unetpred"
###############################
......
......@@ -442,7 +442,7 @@ class BaseNetwork:
self.log.write('Loading Training Data ...'+"\r")
### initializing weights ###
weights = init_weights(manual_seed=np.sum(self.test_patients)*7)
weights = init_weights(manual_seed=np.sum(self.test_patients))
filter_multiplier = int(self.newSize[0] / 128)
......
......@@ -21,7 +21,7 @@ def main():
root = tk.Tk()
root.withdraw()
file_path = filedialog.askopenfilename(initialdir="finalResults/complete_segmr/mr_unet_cv_ctthresh")
file_path = filedialog.askopenfilename(initialdir="saves/mr_unet_cv_unetpred")
file_path = file_path[:-9]
print(file_path)
......@@ -61,7 +61,7 @@ def main():
print('GT Type:', gt_type)
### Load test patient
if gt_type == "thresh" or gt_type == "ctthresh_gt":
if gt_type == "thresh" or gt_type == "ctthresh_gt" or gt_type == "_mr_ctunetpred":
img_path = "data/npy_thresh/"
else:
img_path = "data/npy/"
......@@ -88,6 +88,7 @@ def main():
if gt_type == "ctthresh_gt":
ytrue_img_list.append(elem)
ytrue_img_list.append(elem)
else:
for elem in full_list:
......@@ -105,6 +106,8 @@ def main():
if gt_type == "thresh":
ytrue_img_list.append(elem)
ytrue_img_list.append(elem)
#if gt_type =="_mr_ctunetpred":
# X_img_list.remove("P17_mr_T1.gipl.npy")
list.sort(X_img_list)
list.sort(GT_img_list)
list.sort(ytrue_img_list)
......@@ -442,4 +445,4 @@ def show_loss_plot(slice, loss_list):
if __name__ == "__main__":
main()
\ No newline at end of file
main()
......@@ -180,8 +180,10 @@ def main():
if npys3d:
filename = "P" + str(TP_num).zfill(2)
y_pred3d = cut_mr_ctunetpred(filename, y_pred3d)
ythresh3d = np.flip(cut_mr_ctunetpred(filename, GT_test))
#np.save(img_savepath + "P" + str(TP_num).zfill(2) + "_pred3D", y_pred3d)
np.save("data/npy_thresh/" + "P" + str(TP_num).zfill(2) + "_mr_ctunetpred.gipl", y_pred3d)
print(np.shape(ythresh3d[:,:,:,0]))
np.save("data/npy_thresh/" + "P" + str(TP_num).zfill(2) + "_mr_ctunetpred.gipl", ythresh3d[:,:,:,0])
print("ypred shape: ", np.shape(y_pred3d))
#plt.show()
#print(test_loss)
......@@ -351,4 +353,4 @@ def show_loss_plot(slice, loss_list):
if __name__ == "__main__":
main()
\ No newline at end of file
main()
......@@ -10,12 +10,32 @@ from scipy import ndimage
from datetime import datetime
img = np.load("data/npy_thresh/P01_250_thresh.gipl.npy")
print(np.shape(img))
inputimg = np.load("data/npy_thresh/P01_mr_T1.gipl.npy")
gt = np.load("data/npy_thresh/P01_mr_ctunetpred.gipl.npy")
ytrue = np.load("data/npy_thresh/P01_segmr.gipl.npy")
print(np.shape(inputimg))
print(np.shape(gt))
print(np.shape(ytrue))
for i in range(25):
fig = plt.figure()
fig.add_subplot(2, 3, 1)
plt.imshow(inputimg[i, :, :], cmap=plt.cm.bone)
plt.title('Input')
fig.add_subplot(2, 3, 2)
plt.imshow(gt[i,:,:], cmap=plt.cm.bone)
plt.title('gt')
fig.add_subplot(2, 3, 3)
plt.imshow(ytrue[i,:,:], cmap=plt.cm.bone)
plt.title('True')
plt.show()
'''print(np.shape(img))
img = np.flip(img)
#for i in range(np.shape(img)[0]):
plt.imshow(np.flipud(img[1,:, :]), cmap=plt.cm.bone)
plt.show()
plt.show()'''
'''mr_img = np.load("data/npy/P03_mr_T1.gipl.npy")
ct_img = np.load("data/npy/P03_mr_ctseg_gt.gipl.npy")
......@@ -82,4 +102,4 @@ val_pat_images = np.load(file_path)
for slice in val_pat_images:
plt.imshow(slice[:,:,0], cmap=plt.cm.bone)
plt.show()
'''
\ No newline at end of file
'''
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment