Commit 91eca17d authored by sjromuel's avatar sjromuel
Browse files

d

parent a9c7211b
......@@ -33,7 +33,8 @@ class mrt_unet(BaseNetwork):
remote=False,
gt_type="segmr",
e=0,
continue_training=False):
continue_training=False,
modelname="mr_unet_cv_segmr"):
super().__init__( test_patients = test_patients,
val_patients = val_patients,
......@@ -52,7 +53,8 @@ class mrt_unet(BaseNetwork):
remote = remote,
e = e,
gt_type = gt_type,
continue_training = continue_training)
continue_training = continue_training,
modelname = modelname)
def get_data_lists(self, transfer=False): # splits data list in base__img_list, seg_img_list, thresh_img_list
full_list = os.listdir(self.img_path)
......@@ -232,7 +234,8 @@ def main():
remote=remote,
e=e,
continue_training=continue_training,
gt_type=gt_type)
gt_type=gt_type,
modelname=modelname)
my_nn.train_standard_unet()
......@@ -271,7 +274,8 @@ def main():
remote=remote,
e=e,
continue_training=continue_training,
gt_type=gt_type)
gt_type=gt_type,
modelname=modelname)
my_nn.train_standard_unet()
......
......@@ -33,7 +33,8 @@ class mrt_unet(BaseNetwork):
remote=False,
gt_type="segmr",
e=0,
continue_training=False):
continue_training=False,
modelname=mr_unet_cv_ctseg):
super().__init__( test_patients = test_patients,
val_patients = val_patients,
......@@ -52,7 +53,8 @@ class mrt_unet(BaseNetwork):
remote = remote,
e = e,
gt_type = gt_type,
continue_training = continue_training)
continue_training = continue_training,
modelname = modelname)
def get_data_lists(self, transfer=False): # splits data list in base__img_list, seg_img_list, thresh_img_list
full_list = os.listdir(self.img_path)
......@@ -217,7 +219,8 @@ def main():
remote=remote,
e=e,
continue_training=continue_training,
gt_type=gt_type)
gt_type=gt_type,
modelname=modelname)
my_nn.train_standard_unet()
......@@ -256,7 +259,8 @@ def main():
remote=remote,
e=e,
continue_training=continue_training,
gt_type=gt_type)
gt_type=gt_type,
modelname=modelname)
my_nn.train_standard_unet()
......
......@@ -33,7 +33,8 @@ class mrt_unet(BaseNetwork):
remote=False,
gt_type="segmr",
e=0,
continue_training=False):
continue_training=False,
modelname="mr_unet_cv_ctthresh"):
super().__init__( test_patients = test_patients,
val_patients = val_patients,
......@@ -52,7 +53,8 @@ class mrt_unet(BaseNetwork):
remote = remote,
e = e,
gt_type = gt_type,
continue_training = continue_training)
continue_training = continue_training,
modelname = modelname)
def get_data_lists(self, transfer=False): # splits data list in base__img_list, seg_img_list, thresh_img_list
full_list = os.listdir(self.img_path)
......@@ -122,7 +124,7 @@ def main():
continue_training = False
e = 0
gt_type = "ctthresh_gt"
modelname = "mr_unet_cv_ctseg"
modelname = "mr_unet_cv_ctthresh"
###############################
......@@ -217,7 +219,8 @@ def main():
remote=remote,
e=e,
continue_training=continue_training,
gt_type=gt_type)
gt_type=gt_type,
modelname=modelname)
my_nn.train_standard_unet()
......@@ -256,7 +259,8 @@ def main():
remote=remote,
e=e,
continue_training=continue_training,
gt_type=gt_type)
gt_type=gt_type,
modelname=modelname)
my_nn.train_standard_unet()
......
......@@ -105,7 +105,7 @@ class dbscan_clustering(BaseNetwork):
if not pat.startswith('P' + str(self.test_patients[0]).zfill(2)):
if not pat.startswith('P' + str(self.test_patients[1]).zfill(2)):
new_img_list.append(pat)
return img_list
return new_img_list
def loadimg_frompatientslices(self, patientslice):
image = np.load(self.img_path + str(patientslice[0][0]))
......
......@@ -104,7 +104,7 @@ class dbscan_clustering(BaseNetwork):
if not pat.startswith('P' + str(self.test_patients[0]).zfill(2)):
if not pat.startswith('P' + str(self.test_patients[1]).zfill(2)):
new_img_list.append(pat)
return img_list
return new_img_list
def loadimg_frompatientslices(self, patientslice):
image = np.load(self.img_path + str(patientslice[0][0]))
......
......@@ -104,7 +104,7 @@ class dbscan_clustering(BaseNetwork):
if not pat.startswith('P' + str(self.test_patients[0]).zfill(2)):
if not pat.startswith('P' + str(self.test_patients[1]).zfill(2)):
new_img_list.append(pat)
return img_list
return new_img_list
def loadimg_frompatientslices(self, patientslice):
image = np.load(self.img_path + str(patientslice[0][0]))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment