Commit 908319e8 authored by Gijs Hendriksen's avatar Gijs Hendriksen

Include training loop for different models

parent 3e041ac5
.ipynb_checkpoints/
**/__pycache__/
.idea/
*.nii.gz
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -19,8 +19,6 @@ ipywidgets = "^7.5.1"
scikit-image = "^0.17.2"
tensorflow = "^2.2.0"
tensorflow-gpu = "^2.2.0"
segmentation-models = "^1.0.1"
category_encoders = "^2.2.2"
[tool.poetry.dev-dependencies]
......
This diff is collapsed.
......@@ -67,7 +67,7 @@ class Augmentation(ABC):
else:
raise ValueError('%s is not a valid applies_to value.' % applies_to)
@tf.function
# @tf.function
def __call__(self, x_batch, y_batch):
    """Make the augmentation directly callable; delegates to apply_with_chance."""
    return self.apply_with_chance(x_batch, y_batch)
......
......@@ -11,30 +11,40 @@ def border_mode_string_to_func(border_mode, fill=None):
return border_zero, None
elif border_mode == 'constant' or border_mode == 'fill':
if fill == [0, 0, 0]:
return border_zero, None
return border_constant, tf.convert_to_tensor(fill)
if fill is None:
raise ValueError('Fill cannot not be None when using border mode %s.' % border_mode)
return border_constant, tf.convert_to_tensor(fill, dtype=tf.float32)
raise ValueError('Unknown border mode: %s.' % border_mode)
def border_reflect(indices, image_size):
def reflect(indices, image_size):
    """Mirror out-of-range indices back into [0, image_size) (reflect border mode)."""
    # Equivalent to -|image_size - |i| - 1| + image_size - 1, rewritten around
    # the last valid index for readability.
    limit = image_size - 1
    return limit - tf.abs(limit - tf.abs(indices))
def border_repeat(indices, image_size):
def border_reflect(indices_x, indices_y, image_size_x, image_size_y):
    """Apply reflect border handling to the x and y index tensors independently."""
    reflected_x = reflect(indices_x, image_size_x)
    reflected_y = reflect(indices_y, image_size_y)
    return reflected_x, reflected_y
def repeat(indices, image_size):
    """Clamp indices into the valid range [0, image_size - 1] (edge-repeat border mode)."""
    last_valid = image_size - 1
    return tf.clip_by_value(indices, 0, last_valid)
def border_constant(indices, image_size):
cond = tf.math.logical_and(tf.math.less(indices, image_size), tf.math.greater_equal(indices, 0))
cond = tf.math.reduce_all(cond, axis=2, keepdims=True)
def border_repeat(indices_x, indices_y, image_size_x, image_size_y):
    """Apply edge-repeat (clamp) border handling to the x and y index tensors."""
    clamped_x = repeat(indices_x, image_size_x)
    clamped_y = repeat(indices_y, image_size_y)
    return clamped_x, clamped_y
def border_constant(indices_x, indices_y, image_size_x, image_size_y):
    """Redirect out-of-range index pairs to index 0 (constant-fill border mode).

    An (x, y) pair is kept only when BOTH coordinates are in range; otherwise
    both are replaced with 0. NOTE(review): the caller appears to substitute
    the fill value at the first pixel (see the `filler` logic in
    shift_rotate_shear_zoom) — confirm.

    Args:
        indices_x, indices_y: integer index tensors; assumed shape
            (batch, pixels, 1) given the axis=2 concat — TODO confirm.
        image_size_x, image_size_y: scalar image extents.

    Returns:
        Tuple of (indices_x, indices_y) with out-of-range pairs zeroed.
    """
    # Removed a leftover pre-refactor line that returned `tf.where(cond, ...)`
    # before `cond` was defined (NameError, and it made the code below
    # unreachable).
    cond_x = tf.math.logical_and(tf.math.less(indices_x, image_size_x), tf.math.greater_equal(indices_x, 0))
    cond_y = tf.math.logical_and(tf.math.less(indices_y, image_size_y), tf.math.greater_equal(indices_y, 0))
    # A pair is valid only if both its x and y conditions hold.
    cond = tf.math.reduce_all(tf.concat([cond_x, cond_y], axis=2), axis=2, keepdims=True)
    return tf.where(cond, indices_x, 0), tf.where(cond, indices_y, 0)
def border_zero(indices, image_size):
cond = tf.math.logical_and(tf.math.less(indices, image_size), tf.math.greater_equal(indices, 0))
cond = tf.math.reduce_all(cond, axis=2, keepdims=True)
def border_zero(indices_x, indices_y, image_size_x, image_size_y):
    """Map out-of-range index pairs to a large negative sentinel (zero border mode).

    Args:
        indices_x, indices_y: integer index tensors; assumed shape
            (batch, pixels, 1) given the axis=2 concat — TODO confirm.
        image_size_x, image_size_y: scalar image extents.

    Returns:
        Tuple of (indices_x, indices_y) where invalid pairs are replaced by
        a sentinel far outside the image.
    """
    cond_x = tf.math.logical_and(tf.math.less(indices_x, image_size_x), tf.math.greater_equal(indices_x, 0))
    cond_y = tf.math.logical_and(tf.math.less(indices_y, image_size_y), tf.math.greater_equal(indices_y, 0))
    cond = tf.math.reduce_all(tf.concat([cond_x, cond_y], axis=2), axis=2, keepdims=True)
    # Fix: the previous return referenced `image_size`, which no longer exists
    # in this signature (NameError at trace time). The pre-refactor sentinel
    # was -image_size * image_size, i.e. minus the total pixel count, so use
    # the x/y product here. A stray early return from the old implementation
    # was also removed.
    sentinel = -image_size_x * image_size_y
    return tf.where(cond, indices_x, sentinel), tf.where(cond, indices_y, sentinel)
def guarantee_rank_one(tensor):
if tf.rank(tensor) == 0:
......@@ -153,10 +163,9 @@ def shift_rotate_shear_zoom(batch, x_shifts=None, y_shifts=None, angles=None, x_
# Apply border function
indices = tf.cast(tf.math.round(indices), tf.int32)
indices_x, indices_y = indices[:, :, :1], indices[:, :, 1:]
indices_x = border_func(indices_x, image_size_x)
indices_y = border_func(indices_y, image_size_y)
indices_x, indices_y = border_func(indices_x, indices_y, image_size_x, image_size_y)
indices = tf.concat([indices_x, indices_y], axis=2)
# Calculate flat indices
indices = tf.reduce_sum(indices * [1, image_size_x], axis=2)
indices = tf.reshape(indices + image_total_size * tf.expand_dims(tf.range(batch_size), axis=1), (total_size,))
......@@ -167,7 +176,7 @@ def shift_rotate_shear_zoom(batch, x_shifts=None, y_shifts=None, angles=None, x_
filler = tf.tile([fill], [batch_size, 1])
filler -= batch[:, 0]
filler = tf.expand_dims(filler, axis=1)
filler = tf.concat([filler, tf.zeros((batch_size, image_total_size - 1, channel_size), tf.int32)], axis=1)
filler = tf.concat([filler, tf.zeros((batch_size, image_total_size - 1, channel_size), tf.float32)], axis=1)
batch += filler
# Apply affines to the batch
......
......@@ -65,12 +65,13 @@ class HSV_Hue(Augmentation):
self.max_delta = max_delta
def apply_x(self, x_batch, y_batch):
batch_size = tf.shape(x_batch)[0]
x_shape = tf.shape(x_batch)
target_shape = tf.concat([x_shape[:1], tf.ones(tf.size(x_shape) - 1, dtype=tf.int32)], axis=0)
deltas = tf.random.uniform((batch_size, 1, 1, 1), -self.max_delta, self.max_delta, dtype=tf.float32)
zeros = tf.zeros((batch_size, 1, 1, 2))
deltas = tf.random.uniform(target_shape, -self.max_delta, self.max_delta, dtype=tf.float32)
zeros = tf.zeros(target_shape, dtype=tf.float32)
deltas = tf.concat([deltas, zeros], axis=3)
deltas = tf.concat([deltas, zeros, zeros], axis=-1)
x_batch = tf.clip_by_value(tf.math.add(x_batch, deltas), 0, 1)
......@@ -82,12 +83,13 @@ class HSV_Saturation(Augmentation):
self.max_delta = max_delta
def apply_x(self, x_batch, y_batch):
batch_size = tf.shape(x_batch)[0]
x_shape = tf.shape(x_batch)
target_shape = tf.concat([x_shape[:1], tf.ones(tf.size(x_shape) - 1, dtype=tf.int32)], axis=0)
deltas = tf.random.uniform((batch_size, 1, 1, 1), -self.max_delta, self.max_delta, dtype=tf.float32)
zeros = tf.zeros((batch_size, 1, 1, 1))
deltas = tf.random.uniform(target_shape, -self.max_delta, self.max_delta, dtype=tf.float32)
zeros = tf.zeros(target_shape, dtype=tf.float32)
deltas = tf.concat([zeros, deltas, zeros], axis=3)
deltas = tf.concat([zeros, deltas, zeros], axis=-1)
x_batch = tf.clip_by_value(tf.math.add(x_batch, deltas), 0, 1)
......@@ -99,12 +101,13 @@ class HSV_Value(Augmentation):
self.max_delta = max_delta
def apply_x(self, x_batch, y_batch):
batch_size = tf.shape(x_batch)[0]
x_shape = tf.shape(x_batch)
target_shape = tf.concat([x_shape[:1], tf.ones(tf.size(x_shape) - 1, dtype=tf.int32)], axis=0)
deltas = tf.random.uniform((batch_size, 1, 1, 1), -self.max_delta, self.max_delta, dtype=tf.float32)
zeros = tf.zeros((batch_size, 1, 1, 2))
deltas = tf.random.uniform(target_shape, -self.max_delta, self.max_delta, dtype=tf.float32)
zeros = tf.zeros(target_shape, dtype=tf.float32)
deltas = tf.concat([zeros, deltas], axis=3)
deltas = tf.concat([zeros, zeros, deltas], axis=-1)
x_batch = tf.clip_by_value(tf.math.add(x_batch, deltas), 0, 1)
......@@ -116,12 +119,14 @@ class HSV_Hue_Saturation(Augmentation):
self.max_delta = max_delta
def apply_x(self, x_batch, y_batch):
batch_size = tf.shape(x_batch)[0]
x_shape = tf.shape(x_batch)
target_shape = tf.concat([x_shape[:1], tf.ones(tf.size(x_shape) - 1, dtype=tf.int32)], axis=0)
deltas = tf.random.uniform((batch_size, 1, 1, 2), -self.max_delta, self.max_delta, dtype=tf.float32)
zeros = tf.zeros((batch_size, 1, 1, 1))
deltas_a = tf.random.uniform(target_shape, -self.max_delta, self.max_delta, dtype=tf.float32)
deltas_b = tf.random.uniform(target_shape, -self.max_delta, self.max_delta, dtype=tf.float32)
zeros = tf.zeros(target_shape, dtype=tf.float32)
deltas = tf.concat([deltas, zeros], axis=3)
deltas = tf.concat([deltas_a, deltas_b, zeros], axis=-1)
x_batch = tf.clip_by_value(tf.math.add(x_batch, deltas), 0, 1)
......@@ -133,12 +138,14 @@ class HSV_Saturation_Value(Augmentation):
self.max_delta = max_delta
def apply_x(self, x_batch, y_batch):
batch_size = tf.shape(x_batch)[0]
x_shape = tf.shape(x_batch)
target_shape = tf.concat([x_shape[:1], tf.ones(tf.size(x_shape) - 1, dtype=tf.int32)], axis=0)
deltas = tf.random.uniform((batch_size, 1, 1, 2), -self.max_delta, self.max_delta, dtype=tf.float32)
zeros = tf.zeros((batch_size, 1, 1, 1))
deltas_a = tf.random.uniform(target_shape, -self.max_delta, self.max_delta, dtype=tf.float32)
deltas_b = tf.random.uniform(target_shape, -self.max_delta, self.max_delta, dtype=tf.float32)
zeros = tf.zeros(target_shape, dtype=tf.float32)
deltas = tf.concat([zeros, deltas], axis=3)
deltas = tf.concat([zeros, deltas_a, deltas_b], axis=-1)
x_batch = tf.clip_by_value(tf.math.add(x_batch, deltas), 0, 1)
......@@ -150,8 +157,15 @@ class HSV_Hue_Saturation_Value(Augmentation):
self.max_delta = max_delta
def apply_x(self, x_batch, y_batch):
    """Add an independent random delta to each of the three HSV channels.

    Each image in the batch gets three per-image scalar deltas drawn from
    U(-max_delta, max_delta), broadcast over the spatial dimensions, one per
    channel. The result is clipped to [0, 1]. Assumes the channel axis is
    last and has size 3 — TODO confirm against the data pipeline.
    """
    # Removed two leftovers of the pre-refactor implementation: an older
    # single `tf.random.uniform((batch, 1, 1, 3), ...)` line and a duplicated
    # return statement.
    x_shape = tf.shape(x_batch)
    # (batch, 1, 1, ..., 1): one delta per image, broadcastable over the rest.
    target_shape = tf.concat([x_shape[:1], tf.ones(tf.size(x_shape) - 1, dtype=tf.int32)], axis=0)
    deltas_a = tf.random.uniform(target_shape, -self.max_delta, self.max_delta, dtype=tf.float32)
    deltas_b = tf.random.uniform(target_shape, -self.max_delta, self.max_delta, dtype=tf.float32)
    deltas_c = tf.random.uniform(target_shape, -self.max_delta, self.max_delta, dtype=tf.float32)
    deltas = tf.concat([deltas_a, deltas_b, deltas_c], axis=-1)
    x_batch = tf.clip_by_value(tf.math.add(x_batch, deltas), 0, 1)
    return x_batch, y_batch
\ No newline at end of file
......@@ -18,8 +18,9 @@ class Keras(tf.keras.callbacks.Callback):
self.best_epoch = 0
self.train_losses = []
self.validation_losses = []
self.metric_values = []
self.best_metric_value = None
self.best_validation_metric_value = None
self.on_predict_batch_end = self.on_non_train_batch_end
self.on_predict_begin = self.on_non_train_begin
......@@ -43,9 +44,13 @@ class Keras(tf.keras.callbacks.Callback):
self.epoch_bar = tqdm(total=self.params.get('epochs'), unit='epoch')
if self.trainer.data.has_val and self.best_metric_value is None:
self.best_metric_value = self.trainer.validate()
self.metric_values.append(self.best_metric_value)
if self.trainer.data.has_val and self.best_validation_metric_value is None:
loss, metric_results = self.trainer.validate()
self.validation_losses.append(loss)
self.best_validation_metric_value = metric_results[0]
self.metric_values.append(metric_results)
def on_train_end(self, logs=None):
    """Keras callback hook: close the epoch progress bar once training finishes."""
    self.epoch_bar.close()
......@@ -63,11 +68,12 @@ class Keras(tf.keras.callbacks.Callback):
self.train_epochs += 1
if self.trainer.data.has_val:
metric_value = self.trainer.validate()
self.metric_values.append(metric_value)
loss, metric_results = self.trainer.validate()
self.validation_losses.append(loss)
self.metric_values.append(metric_results)
if metric_value > self.best_metric_value:
self.best_metric_value = metric_value
if metric_results[0] > self.best_validation_metric_value:
self.best_validation_metric_value = metric_results[0]
self.plot(offset=0)
......@@ -89,26 +95,36 @@ class Keras(tf.keras.callbacks.Callback):
if self.trainer.data.has_val:
plt.subplot(1, 2, 1)
plt.plot(xs[:len(self.train_losses)], self.train_losses, label='Batch loss', alpha=0.6)
# Plot training loss
plt.plot(xs[:len(self.train_losses)], self.train_losses, label='Training loss', alpha=0.6)
plt.title('Loss')
plt.xlabel('Epoch')
plt.xlim(0, self.train_epochs + offset)
if self.trainer.data.has_val:
mv_max = np.max(self.metric_values)
mv_max_epoch = np.argmax(self.metric_values)
# Plot validation loss
plt.plot(np.arange(self.train_epochs + 1), self.validation_losses, label='Validation loss')
plt.legend()
# Plot validation metrics
metric_values_np = np.array(self.metric_values, dtype=np.float32)
mv_max = np.max(metric_values_np[:, 0])
mv_max_epoch = np.argmax(metric_values_np[:, 0])
text_x = -80 if mv_max_epoch > 0 else mv_max_epoch + 10
text_y = 10 if mv_max < 0.7 else -20
# self.ax_right.redraw_in_frame()
plt.subplot(1, 2, 2)
plt.plot(np.arange(self.train_epochs + 1), self.metric_values, '-o', label=self.trainer.metric.label, linewidth=1)
for i, metric in enumerate(self.trainer.all_metrics):
plt.plot(np.arange(self.train_epochs + 1), metric_values_np[:, i], '-o', label=metric.label, linewidth=1)
plt.annotate('Best = %.4f' % mv_max, xy=(mv_max_epoch, mv_max), xytext=(text_x, text_y), textcoords='offset pixels', arrowprops=dict(arrowstyle='fancy'))
plt.title('Validation Score')
plt.xlabel('Epoch')
plt.ylabel(self.trainer.metric.label)
plt.xlim(0, self.train_epochs + offset)
plt.ylim(0, 1)
plt.legend()
self.plot_handle.update(fig)
plt.close()
......
......@@ -5,18 +5,25 @@ from .trainer import Trainer
from .classification_backbones import CLASSIFICATION_BACKBONES
class BackboneTrainer(Trainer):
def __init__(self, *args, backbone='efficientnet-b0', validation_metric='accuracy', additional_metrics=None,
             loss='categorical_crossentropy', weights=None, dropout_rate='default', learning_rate='default',
             **kwargs):
    """Set up data, metrics, model, optimizer and logger for a backbone trainer.

    Args:
        *args: either a single wfs.data.Container, or arguments forwarded to
            wfs.data.auto to build one.
        backbone: backbone name (looked up in CLASSIFICATION_BACKBONES) or a
            backbone object.
        validation_metric: primary metric (name or wfs.metric.Metric) used
            for model selection.
        additional_metrics: optional extra metrics, reported alongside the
            validation metric.
        loss: loss passed to model.compile.
        weights, dropout_rate, learning_rate: forwarded to model/optimizer
            construction.
    """
    # Removed a leftover pre-refactor `def __init__` line (the diff left two
    # interleaved signatures, which is a syntax error) and the old
    # `self.metric = ...` assignment.
    self.backbone = CLASSIFICATION_BACKBONES[backbone.lower()] if isinstance(backbone, str) else backbone
    # Fix: `isinstance(*args, ...)` raised TypeError unless len(args) == 1;
    # check the first positional argument explicitly instead.
    self.data = args[0] if args and isinstance(args[0], wfs.data.Container) else wfs.data.auto(*args, **kwargs)
    self.validation_metric = validation_metric if isinstance(validation_metric, wfs.metric.Metric) \
        else wfs.metric.Metric(validation_metric)
    # Fix: avoid a mutable default argument ([]); None is the sentinel.
    self.additional_metrics = []
    for additional_metric in (additional_metrics or []):
        self.additional_metrics.append(additional_metric if isinstance(additional_metric, wfs.metric.Metric)
                                       else wfs.metric.Metric(additional_metric))
    self.all_metrics = [self.validation_metric] + self.additional_metrics
    self._init_model(weights, dropout_rate, **kwargs)
    self.optimizer = self.backbone.build_optimizer(learning_rate, self.data.train, **kwargs)
    self.model.compile(self.optimizer, loss=loss)
    self.loss = self.model.loss
    self.logger = wfs.logger.Keras(self, **kwargs)
def train(self, epochs=1):
......@@ -27,13 +34,21 @@ class BackboneTrainer(Trainer):
y_pred = self.predict(self.data.val)
y_true = self.data.y_decode(self.data.val.y)
return self.metric(y_true, y_pred)
metric_results = [self.validation_metric(y_true, y_pred)]
for additional_metric in self.additional_metrics:
metric_results.append(additional_metric(y_true, y_pred))
return self.loss(y_true, y_pred), metric_results
def evaluate_test(self):
    """Evaluate the model on the test split.

    Returns:
        tuple: (loss, metric_results), where metric_results is a list with
        the validation metric's value first, followed by the additional
        metrics in order.
    """
    # Removed a leftover pre-refactor `return self.metric(...)` line: it
    # returned early (making the multi-metric code below unreachable) and
    # referenced `self.metric`, which no longer exists after the refactor to
    # validation_metric/additional_metrics.
    y_pred = self.predict(self.data.test)
    y_true = self.data.y_decode(self.data.test.y)
    metric_results = [self.validation_metric(y_true, y_pred)]
    for additional_metric in self.additional_metrics:
        metric_results.append(additional_metric(y_true, y_pred))
    return self.loss(y_true, y_pred), metric_results
def predict(self, data):
predictions = self.model.predict(data.out, verbose=0, callbacks=[self.logger.attach(data)])
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment