Commit 7e0a8872 authored by sjromuel's avatar sjromuel
Browse files

d

parent 1f317f5e
......@@ -64,30 +64,15 @@ class mrt_unet(BaseNetwork):
X_img_list = []
GT_img_list = []
# thresh_img_list = []
if transfer:
for elem in full_list:
if elem.startswith('P' + str(transpats[0]).zfill(2)) or elem.startswith('P' + str(transpats[1]).zfill(2)):
if elem.endswith(self.gt_type+".gipl.npy"):
GT_img_list.append(elem[0:4]+"segmr"+elem[-9:])
elif elem.endswith("T1.gipl.npy"):
X_img_list.append(elem)
else:
if elem.endswith("T1.gipl.npy"):
X_img_list.append(elem)
if self.gt_type == "thresh":
X_img_list.append(elem)
X_img_list.append(elem)
elif elem.endswith(self.gt_type+".gipl.npy"):
GT_img_list.append(elem)
else:
for elem in full_list:
if elem.endswith("T1.gipl.npy"):
for elem in full_list:
if elem.endswith("T1.gipl.npy"):
X_img_list.append(elem)
if self.gt_type == "thresh":
X_img_list.append(elem)
if self.gt_type == "thresh":
X_img_list.append(elem)
X_img_list.append(elem)
elif elem.endswith(self.gt_type+".gipl.npy"):
GT_img_list.append(elem)
X_img_list.append(elem)
elif elem.endswith(self.gt_type+".gipl.npy"):
GT_img_list.append(elem)
list.sort(X_img_list)
list.sort(GT_img_list)
......@@ -116,6 +101,7 @@ class mrt_unet(BaseNetwork):
augmented_image = image
return augmented_image
#def main(argv):
def main():
############### Parameters ################
......
......@@ -440,7 +440,7 @@ class BaseNetwork:
self.log.write('Loading Training Data ...'+"\r")
### initializing weights ###
weights = init_weights(manual_seed=np.sum(self.test_patients)*2)
weights = init_weights(manual_seed=np.sum(self.test_patients))
filter_multiplier = int(self.newSize[0] / 128)
......
......@@ -42,13 +42,12 @@ def dice_loss(pred, gt, axis=(1, 2, 3)):
b = tf.reduce_sum(gt[:,:,:,0]) + tf.reduce_sum(pred[:,:,:,0])
loss1 = 1 - (a/(b+0.00001))
a = 2 * tf.reduce_sum(gt[:, :, :, 1] * pred[:, :, :, 1])
'''a = 2 * tf.reduce_sum(gt[:, :, :, 1] * pred[:, :, :, 1])
b = tf.reduce_sum(gt[:, :, :, 1]) + tf.reduce_sum(pred[:, :, :, 1])
loss2 = 1 - (a / (b + 0.00001))
loss = (loss1+loss2)/2
return loss
loss = (loss1+loss2)/2'''
return loss1
# weights -> filter/kernel
def Unet(x, weights, filter_multiplier, training=True):
......@@ -116,9 +115,9 @@ def Unet(x, weights, filter_multiplier, training=True):
def train_unet(model, inputs, gt, weights, optimizer, filter_multiplier):
with tf.GradientTape() as tape:
pred = model(inputs, weights, filter_multiplier, training=True)
'''fig = plt.figure()
fig = plt.figure()
fig.add_subplot(2, 3, 1)
plt.imshow(inputs[0, :, :, 0], cmap=plt.cm.bone)
plt.imshow(inputs[1, :, :, 0], cmap=plt.cm.bone)
plt.title('Input')
fig.add_subplot(2, 3, 2)
plt.imshow(pred[1,:,:,0], cmap=plt.cm.bone)
......@@ -126,7 +125,7 @@ def train_unet(model, inputs, gt, weights, optimizer, filter_multiplier):
fig.add_subplot(2, 3, 3)
plt.imshow(gt[1,:,:,0], cmap=plt.cm.bone)
plt.title('True')
plt.show()'''
plt.show()
current_loss = dice_loss(pred, gt, axis=(1, 2, 3))
grads = tape.gradient(current_loss, weights)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment