Commit 919ea390 authored by Gijs Hendriksen's avatar Gijs Hendriksen

Create submission script

parent 43dbf36f
import argparse
import json
import os
import time
import numpy as np
import SimpleITK as sitk
from skimage import restoration
from scipy.ndimage import zoom
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow.keras.models import load_model
tf.get_logger().setLevel('ERROR')
IMAGE_SHAPE = (128, 128, 256)
# Constrast-stretching parameters
CUTOFF_LOWER = 0
CUTOFF_UPPER = 2000
# Denoising parameters
DENOISE_METHOD = 'chambolle'
DENOISE_STRENGTH = 0.02
def contrast_stretching(image, p0, pk, q0=None, qk=None):
if q0 is None:
q0 = p0
if qk is None:
qk = pk
return np.clip(q0 + (qk - q0) * (image - p0) / (pk - p0), q0, qk)
def denoise(image, method, strength):
if method == 'chambolle':
return restoration.denoise_tv_chambolle(image, strength)
if method == 'bregman':
return restoration.denoise_tv_bregman(image, 1 / strength)
raise NotImplementedError('Denoise method `%s` is not implemented.')
def load_image(path):
# Load image
file_image = sitk.ReadImage(path)
# Build inverse affine from file
affine = np.linalg.inv(np.array(file_image.GetDirection()).reshape(3, 3))
# Extract axes swaps
swaps = np.argmax(np.abs(affine), axis=0).astype(np.int)
# Extract axes flips and correct them
flips = np.sum(affine.round(), axis=0).astype(np.int)
flips[2] *= -1
# Extract spacings from file and correct them
spacings = np.array(file_image.GetSpacing())[swaps]
# Read image data from file
image = sitk.GetArrayViewFromImage(file_image).astype(np.float32)
# Swap axes to standard orientation (zyx)
image = np.transpose(np.transpose(image, (2, 1, 0)), swaps)
# Flip axes to standard directions
image = image[tuple(slice(None, None, int(f)) for f in flips)]
# Calculate scaling ratios
original_shape = image.shape
scale_ratios_image = np.divide(IMAGE_SHAPE, original_shape)
# Resize image
image = zoom(image, scale_ratios_image)
return image, original_shape, spacings
def preprocess(image):
# Constrast stretching
image = contrast_stretching(image, CUTOFF_LOWER, CUTOFF_UPPER, 0, 1)
# Denoising
image = denoise(image, DENOISE_METHOD, DENOISE_STRENGTH)
# Maximum intensity projections
image_xy = image.max(axis=0).astype(np.float32)
image_xz = image.max(axis=1).astype(np.float32)
# Build slices
slices = np.stack([image_xy, image_xz], axis=-1).astype(np.float32)
# Convert to three channels
slices = np.repeat(np.expand_dims(slices, axis=-1), 3, -1)
# Normalize ImageNet
slices -= [0.485, 0.456, 0.406]
slices /= [0.229, 0.224, 0.225]
# Add empty dimension to fit network
slices = np.expand_dims(slices, 0)
return slices
def predict(model, scan_file, threshold=0.5):
if not os.path.exists(scan_file):
print('[!] File does not exist!')
return
elif not scan_file.endswith('.nii.gz'):
print('[!] File should be in .nii.gz format!')
return
print('[*] Loading image...', end=' ', flush=True)
start = time.time()
image, image_size, image_spacing = load_image(scan_file)
preprocessed = preprocess(image)
end = time.time()
print(f'done in {end - start:.02f}s')
result = model.predict(preprocessed).squeeze()
result[:, 1:] *= image_size * image_spacing
predictions = []
for label, (conf, z, y, x) in enumerate(result):
if conf > threshold:
predictions.append({
'X': x.item(),
'Y': y.item(),
'Z': z.item(),
'label': label + 1,
})
return predictions
def main():
parser = argparse.ArgumentParser(description='Prediction for the VerSe 2019 task')
parser.add_argument('model', help='The fully trained model file')
parser.add_argument('scan', help='The scan images we want to predict', nargs='+')
parser.add_argument('-t', '--threshold', help='Confidence threshold', type=float, default=0.5)
args = parser.parse_args()
if not os.path.exists(args.model):
print(f'[!] The model {args.model} does not exist!')
return
print('[*] Loading model...', end=' ', flush=True)
start = time.time()
model = load_model(args.model, compile=False)
end = time.time()
print(f'done in {end - start:.02f}s')
for scan in args.scan:
print(f'[*] Predicting {scan}')
prediction = predict(model, scan, args.threshold)
if prediction is not None:
destination = scan.replace('.nii.gz', '.json')
print(f'[*] Saving prediction to {destination}')
with open(destination, 'w') as _file:
json.dump(prediction, _file)
if __name__ == '__main__':
main()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment