Commit fccb1833 authored by Robin's avatar Robin
Browse files

d

parent 97e31dc9
......@@ -179,7 +179,7 @@ def main():
#number_patients = number_patients * 5
img_path = "data/npy_thresh/"
specificmodels = [0]
specificmodels = [6, 7]
if cross_val:
log = open("logs" + modelname + ".txt", "w+")
log.write(modelname + "\r")
......
......@@ -21,8 +21,8 @@ def main():
models = ["12", "34", "56", "78", "910", "1112", "1314", "1516"]
folder_path = "finalResults/complete_seg/ae_class_cv_seg/"
#modeltype = "Cluster_"
modeltype = "Class_"
#modeltype = "Cluster_class_"
#modeltype = "Class_"
modeltype = "Cluster_class_"
#modeltype = "Unet_"
##################### U-Net #####################
......@@ -183,6 +183,11 @@ def main():
image, y_true = features
y_true = onehotencode(y_true)
y_pred = Unet(image, weights, filter_multiplier, training=False)
if y_pred3d == []:
y_pred3d = y_pred[:,:,:,0].numpy()
else:
y_pred3d = np.append(y_pred[:,:,:,0].numpy(), y_pred3d, axis=0)
loss = dice_loss(y_pred, y_true)
loss = tf.make_ndarray(tf.make_tensor_proto(loss))
......@@ -210,7 +215,27 @@ def main():
#print(test_loss_hdd)
print("Hausdorff-Distance for P", test_patients[j],":", np.mean(test_loss_hdd))
print("Hausdorff-Distance2 for P", test_patients[j], ":", np.mean(test_loss_hdd2))
#### Hausdorff 3D:
'''ytrue = tf.convert_to_tensor(ytrue)
y_pred3d = tf.convert_to_tensor(y_pred3d)
print(tf.shape(ytrue))
y_true3d = np.squeeze(ytrue[:,:,:,0].numpy() > 0.5)
y_true3d = y_true.astype(np.float_)
print(np.shape(y_true3d))
print(type(y_true3d))
print(type(y_pred3d))
pred3d = np.squeeze(y_pred3d[:,:,:].numpy() > 0.5)
pred3d = pred3d.astype(np.float_)
print(np.shape(pred3d))
print(type(pred3d))
hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter.Execute(sitk.GetImageFromArray(y_true3d), sitk.GetImageFromArray(pred3d))
hdd3d = hausdorff_distance_filter.GetHausdorffDistance()
print("3D HDD for patient", test_patients[j], ":", hdd3d)'''
#####
if __name__ == "__main__":
main()
\ No newline at end of file
main()
......@@ -82,6 +82,7 @@ def main():
elif elem.endswith(gt_type + ".gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
GT_img_list.append(elem)
for elem in seg_list:
if elem.endswith("segmr.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
ytrue_img_list.append(elem)
......@@ -275,7 +276,6 @@ def main():
if npys3d:
#np.save(img_savepath + "P" + str(TP_num).zfill(2) + "_pred3D", y_pred3d)
np.save("data/npy_thresh/" + "P" + str(TP_num).zfill(2) + "_mr_ctunetpred.gipl", y_pred3d)
print("ypred shape: ", np.shape(y_pred3d))
#plt.show()
#print(test_loss)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment