Skip to content
GitLab
Menu
Projects
Groups
Snippets
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
sjromuel
Masterarbeit
Commits
f66f6ab0
Commit
f66f6ab0
authored
Nov 24, 2020
by
sjromuel
Browse files
d
parent
404926dc
Changes
1
Hide whitespace changes
Inline
Side-by-side
RobinNet/rerunModels.py
0 → 100644
View file @
f66f6ab0
import
numpy
as
np
import
tensorflow
as
tf
import
matplotlib.pyplot
as
plt
import
tkinter
as
tk
import
os
import
matplotlib
from
tkinter
import
filedialog
from
skimage
import
transform
import
SimpleITK
as
sitk
import
argparse
#import os
#import pydot
#from graphviz import Digraph
#import shutil
#from tensorflow.keras import layers, models
#from utils.dataLoader import *
from
utils.other_functions
import
*
from
nets.Unet
import
*
def
main
():
models
=
[
"12"
,
"34"
,
"56"
,
"78"
,
"910"
,
"1112"
,
"1314"
,
"1516"
]
folder_path
=
"finalResults/complete_seg/ae_class_cv_seg/"
#modeltype = "Cluster_"
modeltype
=
"Class_"
#modeltype = "Cluster_class_"
#modeltype = "Unet_"
##################### U-Net #####################
for
fold
in
models
:
file_path
=
folder_path
+
"TPs"
+
fold
+
modeltype
### read out files ###
weights
=
np
.
load
(
file_path
+
"model.npy"
,
allow_pickle
=
True
)
[
test_patients
,
val_patients
,
number_patients
,
img_path
,
shrink_data
,
newSize
,
lr
,
batch_size
,
num_epochs
,
e
,
augment
,
save_path
,
gt_type
,
filter_multiplier
]
=
np
.
load
(
file_path
+
"params.npy"
,
allow_pickle
=
True
)
# autoencoder_model__e100_switchclass1024_nohiddenclusternet needs to comment out gt_type and val_patients
print
(
'Training Parameters:'
)
print
(
'-----------------'
)
print
(
'Number of Patients: '
,
number_patients
)
print
(
'Number of epochs: '
,
num_epochs
)
print
(
'Test Patient number: '
,
test_patients
)
print
(
'Image Size: '
,
newSize
)
print
(
'Filter Multiplier: '
,
filter_multiplier
)
print
(
'Data Augmentation: '
,
augment
)
print
(
'Learning rate: '
,
lr
)
print
(
'Image Path: '
,
img_path
)
print
(
'Save Path: '
,
save_path
)
print
(
'GT Type:'
,
gt_type
)
### Load test patient
if
gt_type
==
"thresh"
or
gt_type
==
"ctthresh_gt"
:
img_path
=
"data/npy_thresh/"
else
:
img_path
=
"data/npy/"
full_list
=
os
.
listdir
(
img_path
)
seg_list
=
os
.
listdir
(
"data/npy/"
)
X_img_list
=
[]
GT_img_list
=
[]
ytrue_img_list
=
[]
# thresh_img_list = []
if
"mr"
in
save_path
:
for
elem
in
full_list
:
if
elem
.
endswith
(
"T1.gipl.npy"
)
and
(
elem
.
startswith
(
'P'
+
str
(
test_patients
[
0
]).
zfill
(
2
))
or
elem
.
startswith
(
'P'
+
str
(
test_patients
[
1
]).
zfill
(
2
))):
X_img_list
.
append
(
elem
)
if
gt_type
==
"ctthresh_gt"
:
X_img_list
.
append
(
elem
)
X_img_list
.
append
(
elem
)
elif
elem
.
endswith
(
gt_type
+
".gipl.npy"
)
and
(
elem
.
startswith
(
'P'
+
str
(
test_patients
[
0
]).
zfill
(
2
))
or
elem
.
startswith
(
'P'
+
str
(
test_patients
[
1
]).
zfill
(
2
))):
GT_img_list
.
append
(
elem
)
for
elem
in
seg_list
:
if
elem
.
endswith
(
"segmr.gipl.npy"
)
and
(
elem
.
startswith
(
'P'
+
str
(
test_patients
[
0
]).
zfill
(
2
))
or
elem
.
startswith
(
'P'
+
str
(
test_patients
[
1
]).
zfill
(
2
))):
ytrue_img_list
.
append
(
elem
)
if
gt_type
==
"ctthresh_gt"
:
ytrue_img_list
.
append
(
elem
)
ytrue_img_list
.
append
(
elem
)
else
:
for
elem
in
full_list
:
if
elem
.
endswith
(
"ct.gipl.npy"
)
and
(
elem
.
startswith
(
'P'
+
str
(
test_patients
[
0
]).
zfill
(
2
))
or
elem
.
startswith
(
'P'
+
str
(
test_patients
[
1
]).
zfill
(
2
))):
X_img_list
.
append
(
elem
)
if
gt_type
==
"thresh"
:
X_img_list
.
append
(
elem
)
X_img_list
.
append
(
elem
)
elif
elem
.
endswith
(
gt_type
+
".gipl.npy"
)
and
(
elem
.
startswith
(
'P'
+
str
(
test_patients
[
0
]).
zfill
(
2
))
or
elem
.
startswith
(
'P'
+
str
(
test_patients
[
1
]).
zfill
(
2
))):
GT_img_list
.
append
(
elem
)
for
elem
in
seg_list
:
if
elem
.
endswith
(
"seg.gipl.npy"
)
and
(
elem
.
startswith
(
'P'
+
str
(
test_patients
[
0
]).
zfill
(
2
))
or
elem
.
startswith
(
'P'
+
str
(
test_patients
[
1
]).
zfill
(
2
))):
ytrue_img_list
.
append
(
elem
)
if
gt_type
==
"thresh"
:
ytrue_img_list
.
append
(
elem
)
ytrue_img_list
.
append
(
elem
)
list
.
sort
(
X_img_list
)
list
.
sort
(
GT_img_list
)
list
.
sort
(
ytrue_img_list
)
print
(
"Input Image List"
,
X_img_list
)
print
(
"GT Image List"
,
GT_img_list
)
print
(
"True Segmentation Image List"
,
ytrue_img_list
)
for
j
in
range
(
2
):
if
gt_type
==
"thresh"
or
gt_type
==
"ctthresh_gt"
:
X_img_npys
=
np
.
load
(
img_path
+
X_img_list
[
j
*
3
])
GT_img_npys
=
np
.
load
(
img_path
+
GT_img_list
[
j
*
3
])
ytrue_img_npys
=
np
.
load
(
img_path
+
ytrue_img_list
[
j
*
3
])
print
(
GT_img_list
[
j
*
3
])
print
(
GT_img_list
[
j
*
3
+
1
])
print
(
GT_img_list
[
j
*
3
+
2
])
X_img_npys
=
np
.
append
(
X_img_npys
,
np
.
load
(
img_path
+
X_img_list
[
j
*
3
+
1
]),
axis
=
0
)
GT_img_npys
=
np
.
append
(
GT_img_npys
,
np
.
load
(
img_path
+
GT_img_list
[
j
*
3
+
1
]),
axis
=
0
)
ytrue_img_npys
=
np
.
append
(
ytrue_img_npys
,
np
.
load
(
img_path
+
ytrue_img_list
[
j
*
3
+
1
]),
axis
=
0
)
X_img_npys
=
np
.
append
(
X_img_npys
,
np
.
load
(
img_path
+
X_img_list
[
j
*
3
+
2
]),
axis
=
0
)
GT_img_npys
=
np
.
append
(
GT_img_npys
,
np
.
load
(
img_path
+
GT_img_list
[
j
*
3
+
2
]),
axis
=
0
)
ytrue_img_npys
=
np
.
append
(
ytrue_img_npys
,
np
.
load
(
img_path
+
ytrue_img_list
[
j
*
3
+
2
]),
axis
=
0
)
print
(
"Input shape: "
,
np
.
shape
(
X_img_npys
))
print
(
"GT shape: "
,
np
.
shape
(
GT_img_npys
))
print
(
"True Segm shape: "
,
np
.
shape
(
ytrue_img_npys
))
else
:
X_img_npys
=
np
.
load
(
img_path
+
X_img_list
[
j
])
GT_img_npys
=
np
.
load
(
img_path
+
GT_img_list
[
j
])
ytrue_img_npys
=
np
.
load
(
img_path
+
ytrue_img_list
[
j
])
print
(
"Input shape: "
,
np
.
shape
(
X_img_npys
))
print
(
"GT shape: "
,
np
.
shape
(
GT_img_npys
))
print
(
"True Segm shape: "
,
np
.
shape
(
ytrue_img_npys
))
X_img_npys
=
transform
.
resize
(
X_img_npys
,
(
X_img_npys
.
shape
[
0
],
newSize
[
0
],
newSize
[
1
]),
order
=
0
,
preserve_range
=
True
,
mode
=
'constant'
,
anti_aliasing
=
False
,
anti_aliasing_sigma
=
None
)
GT_img_npys
=
transform
.
resize
(
GT_img_npys
,
(
GT_img_npys
.
shape
[
0
],
newSize
[
0
],
newSize
[
1
]),
order
=
0
,
preserve_range
=
True
,
mode
=
'constant'
,
anti_aliasing
=
False
,
anti_aliasing_sigma
=
None
)
ytrue_img_npys
=
transform
.
resize
(
ytrue_img_npys
,
(
ytrue_img_npys
.
shape
[
0
],
newSize
[
0
],
newSize
[
1
]),
order
=
0
,
preserve_range
=
True
,
mode
=
'constant'
,
anti_aliasing
=
False
,
anti_aliasing_sigma
=
None
)
X_test
=
np
.
reshape
(
X_img_npys
,
(
X_img_npys
.
shape
[
0
],
X_img_npys
.
shape
[
1
],
X_img_npys
.
shape
[
2
],
1
))
GT_test
=
np
.
reshape
(
GT_img_npys
,
(
GT_img_npys
.
shape
[
0
],
GT_img_npys
.
shape
[
1
],
GT_img_npys
.
shape
[
2
],
1
))
ytrue
=
np
.
reshape
(
ytrue_img_npys
,
(
ytrue_img_npys
.
shape
[
0
],
ytrue_img_npys
.
shape
[
1
],
ytrue_img_npys
.
shape
[
2
],
1
))
test_dataset
=
tf
.
data
.
Dataset
.
from_tensor_slices
((
X_test
,
ytrue
))
test_dataset
=
test_dataset
.
batch
(
batch_size
=
1
)
print
(
test_patients
)
TP_num
=
test_patients
[
j
]
###################################################################################
detailed_images
=
False
npys3d
=
True
###################################################################################
test_loss
=
[]
test_loss_hdd
=
[]
test_loss_hdd2
=
[]
y_pred3d
=
[]
for
features
in
test_dataset
:
image
,
y_true
=
features
y_true
=
onehotencode
(
y_true
)
y_pred
=
Unet
(
image
,
weights
,
filter_multiplier
,
training
=
False
)
loss
=
dice_loss
(
y_pred
,
y_true
)
loss
=
tf
.
make_ndarray
(
tf
.
make_tensor_proto
(
loss
))
test_loss
.
append
(
loss
)
#print(test_loss)
try
:
y_true_np
=
np
.
squeeze
(
y_true
[
0
,
:,
:,
0
].
numpy
()
>
0.5
)
y_true_np
=
y_true_np
.
astype
(
np
.
float_
)
pred_np
=
np
.
squeeze
(
y_pred
[
0
,
:,
:,
0
].
numpy
()
>
0.5
)
pred_np
=
pred_np
.
astype
(
np
.
float_
)
hausdorff_distance_filter
=
sitk
.
HausdorffDistanceImageFilter
()
hausdorff_distance_filter
.
Execute
(
sitk
.
GetImageFromArray
(
y_true_np
),
sitk
.
GetImageFromArray
(
pred_np
))
test_loss_hdd
.
append
(
hausdorff_distance_filter
.
GetHausdorffDistance
())
hausdorff_distance_filter2
=
sitk
.
HausdorffDistanceImageFilter
()
hausdorff_distance_filter2
.
Execute
(
sitk
.
GetImageFromArray
(
pred_np
),
sitk
.
GetImageFromArray
(
y_true_np
))
test_loss_hdd2
.
append
(
hausdorff_distance_filter2
.
GetHausdorffDistance
())
except
:
pass
#plt.show()
#print(test_loss)
print
(
"TestLoss Mean for P"
,
test_patients
[
j
],
": "
,
np
.
mean
(
test_loss
))
#print(test_loss_hdd)
print
(
"Hausdorff-Distance for P"
,
test_patients
[
j
],
":"
,
np
.
mean
(
test_loss_hdd
))
print
(
"Hausdorff-Distance2 for P"
,
test_patients
[
j
],
":"
,
np
.
mean
(
test_loss_hdd2
))
#####
if
__name__
==
"__main__"
:
main
()
\ No newline at end of file
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment