Commit 5f91dd79 authored by Gijs Hendriksen's avatar Gijs Hendriksen

Add plots to submission script

parent 0f9ec86f
......@@ -2,7 +2,9 @@ import argparse
import json
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import SimpleITK as sitk
from skimage import restoration
......@@ -28,6 +30,69 @@ CUTOFF_UPPER = 2000
DENOISE_METHOD = 'chambolle'
# C1 - L6
ANNOTATIONS = [f'C{i}' for i in range(1, 8)] + [f'Th{i}' for i in range(1, 13)] + [f'L{i}' for i in range(1, 7)]
def plot_prediction(image, predictions, destination, threshold=0.5):
# Style parameters
kwargs_image = dict(, aspect='auto')
kwargs_label_pred = dict(c='springgreen', marker='x', s=12)
kwargs_annotation_pred = dict(c='green', fontsize=12, fontweight='bold')
annotations = []
# Select labels with confidence above the threshold
label_xyz_pred = []
for i, label in enumerate(predictions):
if label[0] > threshold:
label_xyz_pred = np.array(label_xyz_pred)
fig, axes = plt.subplots(2, 1, sharex='all', sharey='all', gridspec_kw=dict(wspace=0))
# Plot X-Y slice of data
axes[0].imshow(image.max(axis=0), **kwargs_image)
for (x, y, txt) in zip(label_xyz_pred[:, 2], label_xyz_pred[:, 1], annotations):
axes[0].scatter(x, y, **kwargs_label_pred)
axes[0].text(x, y - 10, txt, **kwargs_annotation_pred)
axes[0].set_title('X/Y maximum intensity projection')
# Plot X-Z slice of data
axes[1].imshow(image.max(axis=1), **kwargs_image)
for (x, z, txt) in zip(label_xyz_pred[:, 2], label_xyz_pred[:, 0], annotations):
axes[1].scatter(x, z, **kwargs_label_pred)
axes[1].text(x, z - 10, txt, **kwargs_annotation_pred)
axes[1].set_title('X/Z maximum intensity projection')
print(f'[*] Saving prediction plot to {destination}')
def prediction_to_json(prediction, destination, threshold=0.5):
prediction_json = []
for label, (conf, z, y, x) in enumerate(prediction):
if conf > threshold:
'X': x.item(),
'Y': y.item(),
'Z': z.item(),
'label': label + 1,
print(f'[*] Saving prediction to {destination}')
with open(destination, 'w') as _file:
json.dump(prediction_json, _file)
def contrast_stretching(image, p0, pk, q0=None, qk=None):
if q0 is None:
......@@ -75,17 +140,16 @@ def load_image(path):
# Flip axes to standard directions
image = image[tuple(slice(None, None, int(f)) for f in flips)]
return image, spacings
def preprocess(image):
# Calculate scaling ratios
original_shape = image.shape
scale_ratios_image = np.divide(IMAGE_SHAPE, original_shape)
scale_ratios_image = np.divide(IMAGE_SHAPE, image.shape)
# Resize image
image = zoom(image, scale_ratios_image)
return image, original_shape, spacings
def preprocess(image):
# Constrast stretching
image = contrast_stretching(image, CUTOFF_LOWER, CUTOFF_UPPER, 0, 1)
......@@ -112,38 +176,13 @@ def preprocess(image):
return slices
def predict(model, scan_file, threshold=0.5):
if not os.path.exists(scan_file):
print('[!] File does not exist!')
elif not scan_file.endswith('.nii.gz'):
print('[!] File should be in .nii.gz format!')
print('[*] Loading image...', end=' ', flush=True)
start = time.time()
image, image_size, image_spacing = load_image(scan_file)
preprocessed = preprocess(image)
end = time.time()
print(f'done in {end - start:.02f}s')
result = model.predict(preprocessed).squeeze()
result[:, 1:] *= image_size * image_spacing
def predict(model, image):
preprocessed_image = preprocess(image)
predictions = []
prediction = model.predict(preprocessed_image).squeeze()
prediction[:, 1:] *= image.shape
for label, (conf, z, y, x) in enumerate(result):
if conf > threshold:
'X': x.item(),
'Y': y.item(),
'Z': z.item(),
'label': label + 1,
return predictions
return prediction
def main():
......@@ -151,6 +190,7 @@ def main():
parser.add_argument('model', help='The fully trained model file')
parser.add_argument('scan', help='The scan images we want to predict', nargs='+')
parser.add_argument('-t', '--threshold', help='Confidence threshold', type=float, default=0.5)
parser.add_argument('-p', '--plot', help='Whether to also plot the prediction on the image', action='store_true')
args = parser.parse_args()
......@@ -166,18 +206,28 @@ def main():
print(f'done in {end - start:.02f}s')
for scan in args.scan:
print(f'[*] Predicting {scan}')
for scan_file in args.scan:
if not os.path.exists(scan_file):
print(f'[!] File {scan_file} does not exist!')
elif not scan_file.endswith('.nii.gz'):
print(f'[!] File {scan_file} should be in .nii.gz format!')
print(f'[*] Predicting {scan_file}')
image, spacing = load_image(scan_file)
prediction = predict(model, scan, args.threshold)
# plot_prediction(image, np.array([[1, 0, 0, 0]]), scan_file.replace('.nii.gz', '.png'))
# continue
if prediction is not None:
destination = scan.replace('.nii.gz', '.json')
prediction = predict(model, image)
print(f'[*] Saving prediction to {destination}')
if args.plot:
plot_prediction(image, prediction, scan_file.replace('.nii.gz', '.png'), threshold=args.threshold)
with open(destination, 'w') as _file:
json.dump(prediction, _file)
prediction[:, 1:] *= spacing
prediction_to_json(prediction, scan_file.replace('.nii.gz', '.json'), threshold=args.threshold)
if __name__ == '__main__':
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment