Commit 276b29e0 authored by sjromuel's avatar sjromuel
Browse files

d

parent 6eca8d37
...@@ -115,7 +115,7 @@ def Unet(x, weights, filter_multiplier, training=True): ...@@ -115,7 +115,7 @@ def Unet(x, weights, filter_multiplier, training=True):
def train_unet(model, inputs, gt, weights, optimizer, filter_multiplier): def train_unet(model, inputs, gt, weights, optimizer, filter_multiplier):
with tf.GradientTape() as tape: with tf.GradientTape() as tape:
pred = model(inputs, weights, filter_multiplier, training=True) pred = model(inputs, weights, filter_multiplier, training=True)
fig = plt.figure() '''fig = plt.figure()
fig.add_subplot(2, 3, 1) fig.add_subplot(2, 3, 1)
plt.imshow(inputs[1, :, :, 0], cmap=plt.cm.bone) plt.imshow(inputs[1, :, :, 0], cmap=plt.cm.bone)
plt.title('Input') plt.title('Input')
...@@ -125,7 +125,7 @@ def train_unet(model, inputs, gt, weights, optimizer, filter_multiplier): ...@@ -125,7 +125,7 @@ def train_unet(model, inputs, gt, weights, optimizer, filter_multiplier):
fig.add_subplot(2, 3, 3) fig.add_subplot(2, 3, 3)
plt.imshow(gt[1,:,:,0], cmap=plt.cm.bone) plt.imshow(gt[1,:,:,0], cmap=plt.cm.bone)
plt.title('True') plt.title('True')
plt.show() plt.show()'''
current_loss = dice_loss(pred, gt, axis=(1, 2, 3)) current_loss = dice_loss(pred, gt, axis=(1, 2, 3))
grads = tape.gradient(current_loss, weights) grads = tape.gradient(current_loss, weights)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment