Commit da94ca5a authored by sjjsmuel's avatar sjjsmuel

more augmentation

parent 24423894
......@@ -20,10 +20,10 @@ def flip(x, label, size):
return x, label
def color(x: tf.Tensor, label, size):
    """Apply random photometric augmentation to image tensor *x*.

    Randomly jitters hue, saturation, brightness and contrast using the
    post-commit parameter values. (The diff view interleaved the old and
    new parameter lines, which would have applied every augmentation
    twice; only one pass per augmentation is intended.)

    Args:
        x: image tensor; assumed to hold float values in [0, 1] since the
           pipeline later clips with tf.clip_by_value(x, 0, 1) — TODO confirm.
        label: class label, passed through unchanged.
        size: target (width, height) tuple; unused here but kept so all
              augmentation functions share one signature.

    Returns:
        Tuple of (augmented image, unchanged label).
    """
    x = tf.image.random_hue(x, 0.1)
    x = tf.image.random_saturation(x, 0.5, 1.7)
    x = tf.image.random_brightness(x, 0.2)
    x = tf.image.random_contrast(x, 0.5, 1.5)
    return x, label
def rotate(x: tf.Tensor, label, size):
......@@ -31,7 +31,7 @@ def rotate(x: tf.Tensor, label, size):
def zoom(x: tf.Tensor, label, size):
# Generate 14 crop settings, ranging from a 1% to 40% crop.
scales = list(np.arange(0.7, 1.0, 0.01))
scales = list(np.arange(0.6, 1.0, 0.03))
boxes = np.zeros((len(scales), 4))
for i, scale in enumerate(scales):
......@@ -48,8 +48,8 @@ def zoom(x: tf.Tensor, label, size):
choice = tf.random.uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
# Apply cropping only 30% of the time; when choice < 0.7 the image is returned unchanged.
return tf.cond(choice < 0.7, lambda: x, lambda: random_crop(x)), label
# Apply cropping only 20% of the time; when choice < 0.8 the image is returned unchanged.
return tf.cond(choice < 0.8, lambda: x, lambda: random_crop(x)), label
class DataLoader(object):
......@@ -63,7 +63,7 @@ class DataLoader(object):
self.IMG_HEIGHT = img_height
self.CHANNELS = channels
self.AUGMENT = augment
self.AUGMENTATIONS = [flip, color, zoom, rotate]
self.AUGMENTATIONS = [flip, color, zoom]
self.classes = [item.name for item in data_path.glob('*') if item.name != '.DS_Store']
self.n_classes = len(self.classes)
......@@ -180,14 +180,14 @@ class DataLoader(object):
self.dataset_1 = self.dataset_1.cache()
if self.AUGMENT:
for f in self.AUGMENTATIONS:
self.dataset_1 = self.dataset_1.map(lambda x,y: tf.cond(tf.random.uniform([], 0, 1) > 0.6, lambda: f(x,y, (self.IMG_WIDTH, self.IMG_HEIGHT)), lambda: (x,y)),
self.dataset_1 = self.dataset_1.map(lambda x,y: tf.cond(tf.random.uniform([], 0, 1) > 0.1, lambda: f(x,y, (self.IMG_WIDTH, self.IMG_HEIGHT)), lambda: (x,y)),
num_parallel_calls=self.NR_THREADS)
self.dataset_1 = self.dataset_1.map(lambda x,y: (tf.clip_by_value(x, 0, 1),y))
if self.dataset_2:
self.dataset_2 = self.dataset_2.map(self.process_path, num_parallel_calls=self.NR_THREADS)
self.dataset_2 = self.dataset_2.cache()
for f in self.AUGMENTATIONS:
self.dataset_2 = self.dataset_2.map(lambda x, y: tf.cond(tf.random.uniform([], 0, 1) > 0.6, lambda: f(x, y, (self.IMG_WIDTH, self.IMG_HEIGHT)),
self.dataset_2 = self.dataset_2.map(lambda x, y: tf.cond(tf.random.uniform([], 0, 1) > 0.1, lambda: f(x, y, (self.IMG_WIDTH, self.IMG_HEIGHT)),
lambda: (x, y)), num_parallel_calls=self.NR_THREADS)
self.dataset_2 = self.dataset_2.map(lambda x, y: (tf.clip_by_value(x, 0, 1), y))
......
......@@ -2,7 +2,7 @@ from tensorflow.keras import regularizers
from network.NetworkBase import NetworkBase
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, Flatten, AveragePooling2D
from tensorflow.keras.layers import Input, Dense, GlobalAveragePooling2D, Dropout, Flatten, AveragePooling2D, GaussianNoise
from tensorflow.keras.applications.vgg16 import VGG16
import pathlib
......@@ -25,19 +25,20 @@ class VGG_16(NetworkBase):
print('Given weigths-file for base ResNet not found. (Missing: {})'.format(self.WEIGHTS_PATH))
input_tensor = Input(shape=(self.IMG_WIDTH, self.IMG_HEIGHT, self.CHANNELS))
input_tensor = GaussianNoise(stddev=0.01)(input_tensor)
base_model = VGG16(weights=weights, input_tensor=input_tensor, include_top=False)
if shouldSave:
base_model.save('input/vgg_base_model.h5')
x = GlobalAveragePooling2D()(base_model.output)
x = Dense(256, activation='relu')(x)
x = Flatten()(base_model.output)
x = Dense(1024, activation='relu')(x)
x = Dense(256, activation='relu')(x)
out = Dense(self.NUM_CLASSES, activation='softmax', name='probs')(x)
model = Model(base_model.input, out)
for layer in model.layers[:15]:
for layer in model.layers[:16]:
layer.trainable = False
return model
import pathlib
import numpy as np
import matplotlib.pylab as plt
from classifier.DataLoader import DataLoader
from helpers.DataLoader import DataLoader
'''
Tool to visually inspect the (augmented) output of the DataLoader.

This script should be run in a local environment, not inside a Docker
container, to ensure that plt.show() can actually display the created images.
'''
class_index_map = {'caries': 0, 'no_caries': 1}
index_class_map = {}
for element in class_index_map:
index_class_map[class_index_map[element]] = element
train_dir = pathlib.Path('input/training_data_mini')
batch_size = 5
img_width = 1200
img_height = 1200
img_width = 500
img_height = 500
channels = 3
batch_count_dataset1 = 5
batch_count_dataset1 = 9
train_loader = DataLoader(data_path=train_dir,
batch_size=batch_size,
batch_count_dataset1=batch_count_dataset1,
should_size_dataset1=batch_count_dataset1,
img_width=img_width,
img_height=img_height,
channels=channels,
......@@ -34,14 +46,17 @@ def plot_images(dataset, n_images, samples_per_image):
plt.imshow(output)
plt.show()
#plot_images(train_dataset, 1, dataset1_size)
count = 0
for (img, label) in train_dataset:
for batch_img in img:
for batch_img, batch_label in train_dataset.as_numpy_iterator():
for i in range(len(batch_label)):
img = batch_img[i]
label = batch_label[i]
count += 1
print(batch_img.shape)
plt.imshow(batch_img)
print(img.shape)
label = index_class_map[np.argmax(label)]
plt.title(label)
plt.imshow(img)
plt.show()
print("new batch")
print("{} image(s) processed".format(count))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment