Commit 6a3aba14 authored by Robin's avatar Robin
Browse files

d

parent 3f1de1da
......@@ -179,7 +179,7 @@ def main():
#number_patients = number_patients * 5
img_path = "data/npy_thresh/"
specificmodels = [2]
specificmodels = [7]
if cross_val:
log = open("logs" + modelname + ".txt", "w+")
log.write(modelname + "\r")
......
......@@ -442,7 +442,7 @@ class BaseNetwork:
self.log.write('Loading Training Data ...'+"\r")
### initializing weights ###
weights = init_weights(manual_seed=np.sum(self.test_patients)*10)
weights = init_weights(manual_seed=np.sum(self.test_patients))
filter_multiplier = int(self.newSize[0] / 128)
......
......@@ -135,8 +135,8 @@ def train_unet(model, inputs, gt, weights, optimizer, filter_multiplier):
def init_weights(manual_seed):
### initializing weights ###
#initializer = tf.initializers.glorot_uniform(seed=manual_seed)
initializer = tf.keras.initializers.TruncatedNormal(seed=manual_seed)
initializer = tf.initializers.glorot_uniform(seed=manual_seed)
#initializer = tf.keras.initializers.TruncatedNormal(seed=manual_seed)
shapes = [ # filter_height, filter_width, in_channels, out_channels
# for conv2d_transpose: filter_height, filter_width, out_channels, in_channels
[3, 3, 1, 16],
......
......@@ -19,11 +19,11 @@ from nets.Unet import *
def main():
models = ["12", "34", "56", "78", "910", "1112", "1314", "1516"]
folder_path = "finalResults/complete_seg/ae_class_cv_seg/"
folder_path = "finalResults/complete_segmr/mr_unet_cv_ctthresh/"
#modeltype = "Cluster_"
#modeltype = "Class_"
modeltype = "Cluster_class_"
#modeltype = "Unet_"
#modeltype = "Cluster_class_"
modeltype = "Unet_"
all3dhdds = []
##################### U-Net #####################
......@@ -62,7 +62,7 @@ def main():
print('GT Type:', gt_type)
### Load test patient
if gt_type == "thresh" or gt_type == "ctthresh_gt":
if gt_type == "thresh" or gt_type == "ctthresh_gt" or gt_type == "_mr_ctunetpred":
img_path = "data/npy_thresh/"
else:
img_path = "data/npy/"
......@@ -106,6 +106,8 @@ def main():
if gt_type == "thresh":
ytrue_img_list.append(elem)
ytrue_img_list.append(elem)
#if gt_type =="_mr_ctunetpred":
#X_img_list.remove("P17_mr_T1.gipl.npy")
list.sort(X_img_list)
list.sort(GT_img_list)
list.sort(ytrue_img_list)
......@@ -242,11 +244,14 @@ def main():
pred3d = np.float(pred3d)
print(np.shape(pred3d))
print(type(pred3d))'''
hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter.Execute(sitk.GetImageFromArray(y_true3d), sitk.GetImageFromArray(y_pred3d))
hdd3d = hausdorff_distance_filter.GetHausdorffDistance()
print("3D HDD for patient", test_patients[j], ":", hdd3d)
all3dhdds.append(hdd3d)
try:
hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter.Execute(sitk.GetImageFromArray(y_true3d), sitk.GetImageFromArray(y_pred3d))
hdd3d = hausdorff_distance_filter.GetHausdorffDistance()
print("3D HDD for patient", test_patients[j], ":", hdd3d)
all3dhdds.append(hdd3d)
except:
pass
print("All 3d-HDDs: ", all3dhdds)
print(folder_path)
print(modeltype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment