Commit 2b1af9fe authored by sjromuel's avatar sjromuel
Browse files

f

parent 6850bf01
......@@ -20,11 +20,11 @@ spec:
nvidia.com/gpu: 1
volumeMounts:
- mountPath: home/pv/
name: p17crossval
name: p17head
volumes:
- name: p17crossval
- name: p17head
persistentVolumeClaim:
claimName: p17crossval
claimName: p17head
restartPolicy: Never
# kubectl cp sjromuel-ma/p16pv-deployment-657fb656c5-nljbv:/home/pv/saves/ saves/
......
......@@ -20,11 +20,11 @@ spec:
nvidia.com/gpu: 1
volumeMounts:
- mountPath: home/pv/
name: p17crossval
name: p17head
volumes:
- name: p17crossval
- name: p17head
persistentVolumeClaim:
claimName: p17crossval
claimName: p17head
restartPolicy: Never
# kubectl cp sjromuel-ma/p16pv-deployment-657fb656c5-nljbv:/home/pv/saves/ saves/
......
......@@ -20,11 +20,11 @@ spec:
nvidia.com/gpu: 1
volumeMounts:
- mountPath: home/pv/
name: p17crossval
name: p17head
volumes:
- name: p17crossval
- name: p17head
persistentVolumeClaim:
claimName: p17crossval
claimName: p17head
restartPolicy: Never
# kubectl cp sjromuel-ma/p16pv-deployment-657fb656c5-nljbv:/home/pv/saves/ saves/
......
......@@ -194,9 +194,14 @@ def main():
#specificmodels = [6]
if cross_val:
log = open("logs" + modelname + ".txt", "w+")
log.write(modelname + "\r")
log.write("Start Cross Validation Training \r")
log.close()
print("Start Cross Validation Training")
#for validation_round in specificmodels:
for validation_round in range(number_patients//2):
log = open("logs" + modelname + ".txt", "a+")
test_patients = (2*validation_round+1, 2*validation_round+2)
vallist= list(range(1, number_patients+1)) + list(range(1, number_patients+1))
val_patients = (vallist[2*validation_round+2], vallist[2*validation_round+3],vallist[2*validation_round+4],vallist[2*validation_round+5])
......@@ -205,6 +210,11 @@ def main():
print('Model round', validation_round+1, 'of', number_patients//2)
print("Train Model with Test Patients ", test_patients, 'and Validation Patients', val_patients)
print("--" * 50)
log.write("--" * 50 + "\r")
log.write('Model round' + str(validation_round + 1) + 'of' + str(number_patients // 2) + "\r")
log.write("Train Model with Test Patients " + str(test_patients) + "\r" + 'and Validation Patients' + str(val_patients) + "\r")
log.write("--" * 50 + "\r")
log.close()
my_nn = mrt_unet(test_patients=test_patients,
val_patients=val_patients,
number_patients=number_patients,
......
......@@ -67,26 +67,44 @@ def main():
img_path = "../data/npy/"
full_list = os.listdir(img_path)
seg_list = os.listdir("../data/npy/")
X_img_list = []
GT_img_list = []
ytrue_img_list = []
# thresh_img_list = []
for elem in full_list:
if elem.endswith("ct.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
X_img_list.append(elem)
if gt_type == "thresh":
X_img_list.append(elem)
if "mr" in save_path:
for elem in full_list:
if elem.endswith("T1.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
X_img_list.append(elem)
elif elem.endswith(gt_type+".gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
GT_img_list.append(elem)
for elem in seg_list:
if elem.endswith("seg.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
ytrue_img_list.append(elem)
if gt_type == "thresh":
if gt_type == "thresh":
X_img_list.append(elem)
X_img_list.append(elem)
elif elem.endswith(gt_type + ".gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
GT_img_list.append(elem)
for elem in seg_list:
if elem.endswith("segmr.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
ytrue_img_list.append(elem)
if gt_type == "thresh":
ytrue_img_list.append(elem)
ytrue_img_list.append(elem)
else:
for elem in full_list:
if elem.endswith("ct.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
X_img_list.append(elem)
if gt_type == "thresh":
X_img_list.append(elem)
X_img_list.append(elem)
elif elem.endswith(gt_type+".gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
GT_img_list.append(elem)
for elem in seg_list:
if elem.endswith("seg.gipl.npy") and (elem.startswith('P' + str(test_patients[0]).zfill(2)) or elem.startswith('P' + str(test_patients[1]).zfill(2))):
ytrue_img_list.append(elem)
if gt_type == "thresh":
ytrue_img_list.append(elem)
ytrue_img_list.append(elem)
list.sort(X_img_list)
list.sort(GT_img_list)
list.sort(ytrue_img_list)
......@@ -171,7 +189,7 @@ def main():
#y_pred = tf.reshape(y_pred, (1, 512, 512, 2))
fig = plt.figure()
fig.add_subplot(1, 4, 1)
plt.title("ct_img")
plt.title("Input_img")
plt.axis('off')
plt.imshow(image[0,:,:,0], cmap=plt.cm.bone)
fig.add_subplot(1, 4, 2)
......@@ -197,22 +215,24 @@ def main():
loss = tf.make_ndarray(tf.make_tensor_proto(loss))
test_loss.append(loss)
#print(test_loss)
y_true_np = np.squeeze(y_true[0, :, :, 0].numpy() > 0)
y_true_np = y_true_np.astype(np.float_)
pred_np = np.squeeze(y_pred[0, :, :, 0].numpy() > 0)
pred_np = pred_np.astype(np.float_)
hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter.Execute(sitk.GetImageFromArray(y_true_np), sitk.GetImageFromArray(pred_np))
test_loss_hdd.append(hausdorff_distance_filter.GetHausdorffDistance())
hausdorff_distance_filter2 = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter2.Execute(sitk.GetImageFromArray(pred_np), sitk.GetImageFromArray(y_true_np))
test_loss_hdd2.append(hausdorff_distance_filter2.GetHausdorffDistance())
try:
y_true_np = np.squeeze(y_true[0, :, :, 0].numpy() > 0)
y_true_np = y_true_np.astype(np.float_)
pred_np = np.squeeze(y_pred[0, :, :, 0].numpy() > 0)
pred_np = pred_np.astype(np.float_)
hausdorff_distance_filter = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter.Execute(sitk.GetImageFromArray(y_true_np), sitk.GetImageFromArray(pred_np))
test_loss_hdd.append(hausdorff_distance_filter.GetHausdorffDistance())
hausdorff_distance_filter2 = sitk.HausdorffDistanceImageFilter()
hausdorff_distance_filter2.Execute(sitk.GetImageFromArray(pred_np), sitk.GetImageFromArray(y_true_np))
test_loss_hdd2.append(hausdorff_distance_filter2.GetHausdorffDistance())
except:
pass
# save images:
matplotlib.image.imsave(img_savepath + "P" + str(TP_num).zfill(2) +"_" + str(counter).zfill(2)+"_ctimg.png", image[0,:,:,0],
matplotlib.image.imsave(img_savepath + "P" + str(TP_num).zfill(2) +"_" + str(counter).zfill(2)+"_Inputimg.png", image[0,:,:,0],
cmap=plt.cm.bone)
matplotlib.image.imsave(img_savepath + "P" + str(TP_num).zfill(2) +"_" + str(counter).zfill(2) + "_ytrue.png",
y_true[0, :, :, 0],
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment