Commit 137b996a authored by sjromuel's avatar sjromuel
Browse files

d

parent 9e7186eb
......@@ -21,7 +21,7 @@ def main():
root = tk.Tk()
root.withdraw()
file_path = filedialog.askopenfilename(initialdir="saves/mr_unet_cv_ctthresh")
file_path = filedialog.askopenfilename(initialdir="saves/unet_cv_thresh")
file_path = file_path[:-9]
print(file_path)
......@@ -61,10 +61,8 @@ def main():
print('GT Type:', gt_type)
### Load test patient
if gt_type == "thresh" or gt_type == "ctthresh_gt":
img_path = "data/npy_thresh/"
else:
img_path = "data/npy/"
img_path = "data/npy_thresh/"
full_list = os.listdir(img_path)
seg_list = os.listdir("data/npy/")
......@@ -72,16 +70,15 @@ def main():
GT_img_list = []
ytrue_img_list = []
# thresh_img_list = []
if "mr" in save_path:
for elem in full_list:
if elem.endswith("ct.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
X_img_list.append(elem)
for elem in full_list:
if elem.endswith("ct.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
X_img_list.append(elem)
elif elem.endswith("250_thresh.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
GT_img_list.append(elem)
for elem in seg_list:
if elem.endswith("seg.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
ytrue_img_list.append(elem)
elif elem.endswith("250_thresh.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
GT_img_list.append(elem)
for elem in seg_list:
if elem.endswith("seg.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
ytrue_img_list.append(elem)
list.sort(X_img_list)
......@@ -98,6 +95,7 @@ def main():
os.makedirs(img_savepath)
for j in range(2):
print(img_path + X_img_list[j])
X_img_npys = np.load(img_path + X_img_list[j])
GT_img_npys = np.load(img_path + GT_img_list[j])
ytrue_img_npys = np.load(img_path + ytrue_img_list[j])
......@@ -181,6 +179,8 @@ def main():
if npys3d:
y_pred3d = np.flip(y_pred3d)
filename = "P" + str(TP_num).zfill(2)
y_pred3d = cut_mr_ctunetpred(filename, y_pred3d)
#np.save(img_savepath + "P" + str(TP_num).zfill(2) + "_pred3D", y_pred3d)
np.save("data/npy_thresh/" + "P" + str(TP_num).zfill(2) + "_mr_ctunetpred.gipl", y_pred3d)
print("ypred shape: ", np.shape(y_pred3d))
......
......@@ -95,6 +95,48 @@ def save_val_prediction(pred, save_path, val_patient):
pred = np.append(npyload, pred, axis=0)
np.save(str(valsavepath)+str(val_patient), pred)
def cut_mr_ctunetpred(filename, patient_array):
if filename.startswith('P01'):
corr_patient_array = patient_array[2:27, :, :]
elif filename.startswith('P02'):
corr_patient_array = patient_array[5:, :, :]
elif filename.startswith('P03'):
corr_patient_array = patient_array[10:29, :, :]
elif filename.startswith('P04'):
corr_patient_array = patient_array[9:, :, :]
elif filename.startswith('P05'):
corr_patient_array = patient_array[5:31, :, :]
elif filename.startswith('P06'):
corr_patient_array = patient_array[7:, :, :]
elif filename.startswith('P07'):
corr_patient_array = patient_array[2:, :, :]
elif filename.startswith('P08'):
corr_patient_array = patient_array[3:, :, :]
elif filename.startswith('P09'):
corr_patient_array = patient_array[1:, :, :]
#elif filename.startswith('P11'):
# corr_patient_array = patient_array[:, :, :]
elif filename.startswith('P12'):
corr_patient_array = patient_array[:36, :, :]
elif filename.startswith('P17'):
corr_patient_array = patient_array[6:, :, :]
else:
corr_patient_array = patient_array
corr_patient_array = np.flip(corr_patient_array)
return corr_patient_array
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment