Commit bfa5a3d6 authored by sjromuel

d

parent 12fe920b
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import argparse
#from utils.dataLoader import *
from utils.other_functions import *
from nets.Unet import *
from nets.BaseNetwork import BaseNetwork
from skimage import transform
import datetime
import os
import random
from scipy import ndimage
class mrt_unet(BaseNetwork):
    def __init__(self,
                 test_patients=(1, 2),
                 val_patients=(3, 4, 5, 6),
                 number_patients=4,
                 img_path="../data/npy/",
                 shrink_data=True,
                 newSize=(512, 512),
                 lr=1e-4,
                 optimizer=tf.optimizers.Adam(learning_rate=1e-4),
                 batch_size=1,
                 num_epochs=4,
                 augment=False,
                 verbose=True,
                 save_weights=True,
                 save_path="saves/mrt_unet_",
                 remote=False,
                 gt_type="segmr",
                 e=0,
                 continue_training=False):
        super().__init__(test_patients=test_patients,
                         val_patients=val_patients,
                         number_patients=number_patients,
                         img_path=img_path,
                         shrink_data=shrink_data,
                         newSize=newSize,
                         lr=lr,
                         optimizer=optimizer,
                         batch_size=batch_size,
                         num_epochs=num_epochs,
                         augment=augment,
                         verbose=verbose,
                         save_weights=save_weights,
                         save_path=save_path,
                         remote=remote,
                         e=e,
                         gt_type=gt_type,
                         continue_training=continue_training)
    def get_data_lists(self, transfer=False):
        # Splits the directory listing into an input list (T1 images) and a
        # ground-truth list, matched by file-name suffix.
        full_list = os.listdir(self.img_path)
        # Doubling the patient range lets indices wrap around to patient 1.
        vallist = list(range(1, self.number_patients + 1)) + list(range(1, self.number_patients + 1))
        # Transfer patients are the two patients following the last validation patient.
        transpats = [vallist[self.val_patients[-1]], vallist[self.val_patients[-1] + 1]]
        #print("transfer patients: ", transpats)
        X_img_list = []
        GT_img_list = []
        # thresh_img_list = []
        if transfer:
            for elem in full_list:
                if elem.startswith('P' + str(transpats[0]).zfill(2)) or elem.startswith('P' + str(transpats[1]).zfill(2)):
                    if elem.endswith(self.gt_type + ".gipl.npy"):
                        GT_img_list.append(elem[0:4] + "segmr" + elem[-9:])
                    elif elem.endswith("T1.gipl.npy"):
                        X_img_list.append(elem)
                else:
                    if elem.endswith("T1.gipl.npy"):
                        X_img_list.append(elem)
                        if self.gt_type == "thresh":
                            # Three threshold ground truths exist per patient, so the
                            # T1 image is listed three times to keep both lists aligned.
                            X_img_list.append(elem)
                            X_img_list.append(elem)
                    elif elem.endswith(self.gt_type + ".gipl.npy"):
                        GT_img_list.append(elem)
        else:
            for elem in full_list:
                if elem.endswith("T1.gipl.npy"):
                    X_img_list.append(elem)
                    if self.gt_type == "thresh":
                        X_img_list.append(elem)
                        X_img_list.append(elem)
                elif elem.endswith(self.gt_type + ".gipl.npy"):
                    GT_img_list.append(elem)
        X_img_list.sort()
        GT_img_list.sort()
        #print(X_img_list)
        #print(GT_img_list)
        return X_img_list, GT_img_list
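    # Hedged illustration of the doubled-list trick above (the value 8 for
    # number_patients is assumed purely for the example): concatenating the
    # patient range with itself makes indexing wrap past the last patient, e.g.
    #   vallist = list(range(1, 9)) * 2    # [1..8, 1..8]
    #   vallist[8], vallist[9]             # -> (1, 2): indices past patient 8 wrap back to 1, 2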
    def augment_slice(self, image, image_name, augmentation_type='normal'):
        if augmentation_type == 'fliplr':
            augmented_image = np.fliplr(image)
        elif augmentation_type == 'flipud':
            augmented_image = np.flipud(image)
        elif augmentation_type == 'rotate':
            # Note: the angle is drawn per call, so an image and its mask that
            # are augmented by separate calls do not share the same rotation.
            angle = random.randint(-5, 5)
            if 'segmr' in image_name or 'thresh' in image_name or 'ctfgt' in image_name:
                # Label volumes: nearest-neighbour (order=0) so no new label values appear.
                augmented_image = ndimage.rotate(image, angle, reshape=False, mode='nearest', order=0)
            elif 'T1' in image_name:
                augmented_image = ndimage.rotate(image, angle, reshape=False)
            else:
                # Fallback for unrecognised file names.
                augmented_image = image
        elif augmentation_type == 'normal':
            augmented_image = image
        else:
            augmented_image = image
        return augmented_image
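    # Hedged usage note: label volumes ('segmr', 'thresh', 'ctfgt' in the file
    # name) are rotated with order=0 / mode='nearest' so interpolation cannot
    # invent new label values, while 'T1' intensity images use the default
    # spline order. A minimal standalone sketch (stand-in arrays, fixed angle):
    #   image = np.random.rand(64, 64)
    #   mask = (image > 0.5).astype(np.uint8)
    #   rot_img = ndimage.rotate(image, 4, reshape=False)
    #   rot_mask = ndimage.rotate(mask, 4, reshape=False, mode='nearest', order=0)
    #   assert set(np.unique(rot_mask)) <= {0, 1}   # mask stays binary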
#def main(argv):
def main():
    ############### Parameters ################
    test_patients = (1, 2)
    val_patients = (3, 4, 5, 6)
    number_patients = 17
    img_path = "../data/npy/"
    shrink_data = True  # needs debugging -> currently only works if True. See autoencoder_main.py
    newSize = (512, 512)  # only used if shrink_data = True
    lr = 1e-4
    optimizer = tf.optimizers.Adam(learning_rate=lr)
    batch_size = 5
    num_epochs = 50
    augment = ['normal', 'fliplr', 'flipud', 'rotate']
    verbose = True
    save_weights = True
    save_path = "saves/"
    remote = False
    cross_val = True
    continue_training = False
    e = 0
    gt_type = "ctseg_gt"
    modelname = "mr_unet_cv_segmr"
    ###############################
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--test_patients', nargs='+')
    parser.add_argument('-v', '--val_patients', nargs='+')
    parser.add_argument('-e', '--num_epochs')
    parser.add_argument('-p', '--number_patients')
    parser.add_argument('-m', '--modelname')
    parser.add_argument('-n', '--newSize')
    parser.add_argument('-a', '--augment', action='store_true')
    parser.add_argument('-r', '--remote', action='store_true')
    parser.add_argument('-c', '--continue_training', action='store_true')
    parser.add_argument('-g', '--gt_type')
    args = parser.parse_args()
    if args.test_patients:
        # argparse yields strings; convert to ints so patient ids match the integer defaults above
        test_patients = tuple(int(p) for p in args.test_patients)
    if args.val_patients:
        val_patients = tuple(int(p) for p in args.val_patients)
    if args.num_epochs:
        num_epochs = int(args.num_epochs)
    if args.number_patients:
        number_patients = int(args.number_patients)
    if args.augment:
        augment = ['normal', 'fliplr', 'flipud', 'rotate']
    if args.remote:
        remote = True
    if args.continue_training:
        continue_training = True
    if args.newSize:
        newSize = (int(args.newSize), int(args.newSize))
    if args.gt_type:
        gt_type = args.gt_type
    if args.modelname:
        modelname = args.modelname
    save_path = save_path + modelname + "/"
    if remote:
        save_path = "/home/pv/saves/" + modelname + "/"
        img_path = "/home/pv/data/npy/"
        if gt_type == "thresh":
            img_path = "/home/pv/data/npy_thresh/"
            #number_patients = number_patients * 5
        if not os.path.exists("/home/pv/saves/" + modelname):
            os.makedirs("/home/pv/saves/" + modelname)
    else:
        if not os.path.exists("saves/" + modelname):
            os.makedirs("saves/" + modelname)
        if gt_type == "thresh":
            #number_patients = number_patients * 5
            img_path = "data/npy_thresh/"
    #specificmodels = [6]
    if cross_val:
        log = open("logs" + modelname + ".txt", "w+")
        log.write(modelname + "\r")
        log.write("Start Cross Validation Training \r")
        log.close()
        print("Start Cross Validation Training")
        #for validation_round in specificmodels:
        for validation_round in range(number_patients // 2):
            log = open("logs" + modelname + ".txt", "a+")
            # Each fold tests on two consecutive patients; the four validation
            # patients follow them, wrapping around via the doubled list.
            test_patients = (2 * validation_round + 1, 2 * validation_round + 2)
            vallist = list(range(1, number_patients + 1)) + list(range(1, number_patients + 1))
            val_patients = (vallist[2 * validation_round + 2], vallist[2 * validation_round + 3],
                            vallist[2 * validation_round + 4], vallist[2 * validation_round + 5])
            if verbose:
                print("--" * 50)
                print('Model round', validation_round + 1, 'of', number_patients // 2)
                print("Train Model with Test Patients", test_patients, 'and Validation Patients', val_patients)
                print("--" * 50)
            log.write("--" * 50 + "\r")
            log.write('Model round ' + str(validation_round + 1) + ' of ' + str(number_patients // 2) + "\r")
            log.write("Train Model with Test Patients " + str(test_patients) + "\r" + 'and Validation Patients ' + str(val_patients) + "\r")
            log.write("--" * 50 + "\r")
            log.close()
            my_nn = mrt_unet(test_patients=test_patients,
                             val_patients=val_patients,
                             number_patients=number_patients,
                             img_path=img_path,
                             shrink_data=shrink_data,
                             newSize=newSize,
                             lr=lr,
                             optimizer=optimizer,
                             batch_size=batch_size,
                             num_epochs=num_epochs,
                             augment=augment,
                             verbose=verbose,
                             save_weights=save_weights,
                             save_path=save_path + "TPs" + str(test_patients[0]) + str(test_patients[1]),
                             remote=remote,
                             e=e,
                             continue_training=continue_training,
                             gt_type=gt_type)
            my_nn.train_standard_unet()
    else:
        if continue_training:
            # Restore the training configuration saved alongside the weights.
            [test_patients,
             val_patients,
             number_patients,
             img_path,
             shrink_data,
             newSize,
             lr,
             batch_size,
             num_epochs,
             e,
             augment,
             save_path,
             gt_type,
             filter_multiplier] = np.load(save_path + "params" + ".npy", allow_pickle=True)
        my_nn = BaseNetwork(test_patients=test_patients,
                            val_patients=val_patients,
                            number_patients=number_patients,
                            img_path=img_path,
                            shrink_data=shrink_data,
                            newSize=newSize,
                            lr=lr,
                            optimizer=optimizer,
                            batch_size=batch_size,
                            num_epochs=num_epochs,
                            augment=augment,
                            verbose=verbose,
                            save_weights=save_weights,
                            save_path=save_path,
                            remote=remote,
                            e=e,
                            continue_training=continue_training,
                            gt_type=gt_type)
        my_nn.train_standard_unet()
if __name__ == "__main__":
    main()

#if __name__ == "__main__":
#    main(sys.argv[1:])
'''
ToDo:
- on the Windows PC --> don't use git inside the venv
- plot problem in run_trained_model
- data augmentation --> problem: dataset might get too large?
'''
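For reference, a minimal sketch of the params.npy round-trip that the continue_training branch relies on. This is hedged: the field order is read off the unpacking above, the values and the "saves/mymodel/" path are made up, and how BaseNetwork actually writes this file is not shown in this commit.

import os
import numpy as np

os.makedirs("saves/mymodel", exist_ok=True)
# Hypothetical values, in the order unpacked by the continue_training branch.
params = [(1, 2), (3, 4, 5, 6), 17, "../data/npy/", True, (512, 512),
          1e-4, 5, 50, 0, ['normal'], "saves/mymodel/", "segmr", 2]
np.save("saves/mymodel/params.npy", np.array(params, dtype=object))
loaded = np.load("saves/mymodel/params.npy", allow_pickle=True)
assert loaded[2] == 17   # number_patients survives the round-trip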
@@ -502,7 +502,7 @@ def main():
     optimizer = tf.optimizers.Adam(learning_rate=1e-4)
     newSize = (512, 512)
     save_weights = True
-    tasks = ['encode', 'cluster', 'train']  # encode for loading slices and computing latent vectors
+    tasks = ['encode', 'cluster']  # encode for loading slices and computing latent vectors
     num_epochs = 50
     showplots = False
     remote = False
@@ -551,16 +551,16 @@ def main():
     if not os.path.exists("saves/" + modelname):
         os.makedirs("saves/" + modelname)
-    specificmodels = [3, 4, 6]
-    #specificmodels = [4, 5, 6, 7]
+    #specificmodels = [3, 4, 6]
+    specificmodels = [0, 1, 2, 3, 4, 5, 6, 7]
     if cross_val:
         log = open("logs_clusternet" + modelname + ".txt", "w+")
         log.write(modelname + "\r")
         log.write("Start Cross Validation Training \r")
         log.close()
         print("Start Cross Validation Training")
-        #for validation_round in specificmodels:
-        for validation_round in range(number_patients//2):
+        for validation_round in specificmodels:
+        #for validation_round in range(number_patients//2):
             #log = open("logs_clusternet" + modelname + ".txt", "a+")
             test_patients = (2*validation_round+1, 2*validation_round+2)
             vallist = list(range(1, number_patients+1)) + list(range(1, number_patients+1))
......
import numpy as np

# Per-patient slice ranges applied along axis 0 to crop each volume to the
# slices covered by the CT ground truth; patients not listed are kept unchanged.
crop_ranges = {
    "P01": slice(2, 27),
    "P02": slice(5, None),
    "P03": slice(10, 29),
    "P04": slice(9, None),
    "P05": slice(5, 31),
    "P06": slice(7, None),
    "P07": slice(2, None),
    "P08": slice(3, None),
    "P09": slice(1, None),
    "P11": slice(2, None),
    "P12": slice(None, 36),
    "P17": slice(6, None),
}

gt_path = "data/npy/"
threshs = ["100", "250", "400"]

for patientnumber in range(17):
    patient_id = "P" + str(patientnumber + 1).zfill(2)
    filename = patient_id + "_seg.gipl.npy"
    print(filename)
    patient_array = np.load(gt_path + filename)
    corr_patient_array = patient_array[crop_ranges.get(patient_id, slice(None)), :, :]
    np.save(gt_path + patient_id + "mr_ctseg_gt.gipl", corr_patient_array)

thresh_gt_path = "data/npy_thresh/"
for patientnumber in range(17):
    for thresh in threshs:
        patient_id = "P" + str(patientnumber + 1).zfill(2)
        filename = patient_id + "_" + thresh + "_thresh.gipl.npy"
        print(filename)
        patient_array = np.load(thresh_gt_path + filename)
        corr_patient_array = patient_array[crop_ranges.get(patient_id, slice(None)), :, :]
        np.save(thresh_gt_path + patient_id + "mr_ctthresh" + thresh + "_gt.gipl", corr_patient_array)
\ No newline at end of file
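A small sanity check of what the per-patient cropping does along the slice axis. The shape here is illustrative only, not the real volume size:

import numpy as np

vol = np.zeros((40, 512, 512))                  # stand-in volume with 40 slices
print(vol[slice(2, 27), :, :].shape)            # (25, 512, 512): P01 keeps slices 2..26
print(vol[slice(5, None), :, :].shape)          # (35, 512, 512): P02 drops the first 5 slices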