Skip to content
Snippets Groups Projects
Commit 6055426d authored by Jakob Ambsdorf's avatar Jakob Ambsdorf
Browse files

add support for arguments, make CUDA optional

parent 0f3ecae7
Branches
No related tags found
No related merge requests found
......@@ -3,7 +3,9 @@ test.py: Modified version of the original example.py file.
This file runs classification on the examples images.
'''
import argparse
import glob
import matplotlib.pyplot as plt
from matplotlib.pyplot import imread
import numpy as np
......@@ -25,20 +27,7 @@ which provides a theano+lasagne implementation.
'''
print(disclaimer)
# Configuration
network_name = 'SN64' # 'SN16', 'SN32' pr 'SN64'
display_images = False # Whether or not to show the images during inference
GPU_NR = 0 # Choose the device number of your GPU
# If you provide the original lasagne parameter file, it will be converted to a
# pytorch state_dict and saved as *.pth.
# In this repository, the converted parameters are already provided.
weights = True
# weights = ('/local/ball4916/dphil/SonoNet/SonoNet-weights/SonoNet{}.npz'
# .format(network_name[2:]))
# Other parameters
# Configuration parameters
crop_range = [(115, 734), (81, 874)] # [(top, bottom), (left, right)]
input_size = [224, 288]
image_path = './example_images/*.tiff'
......@@ -94,22 +83,36 @@ def prepare_inputs():
return input_list
def main():
def main(args):
# convert weights parameter for sononet
if args.weights == 'default':
weights = True
elif args.weights == 'none':
weights = False
elif args.weights == 'random':
weights = 0
else:
weights = args.weights
print('Loading network')
net = sononet.SonoNet(network_name, weights=weights)
net = sononet.SonoNet(args.network_name, weights=weights)
net.eval()
if args.cuda:
print('Moving to GPU:')
torch.cuda.device(GPU_NR)
torch.cuda.device(args.GPU_NR)
print(torch.cuda.get_device_name(torch.cuda.current_device()))
net.cuda()
print("\nPredictions using {}:".format(network_name))
print("\nPredictions using {}:".format(args.network_name))
input_list = prepare_inputs()
for image, file_name in zip(input_list, glob.glob(image_path)):
# Run inference
if args.cuda:
x = Variable(torch.from_numpy(image).cuda())
else:
x = Variable(torch.from_numpy(image))
outputs = net(x)
confidence, prediction = torch.max(outputs.data, 1)
......@@ -119,10 +122,22 @@ def main():
.format(label_names[prediction[0]],
confidence[0], true_label))
if display_images:
if args.display_images:
plt.imshow(np.squeeze(image), cmap='gray')
plt.show()
if __name__ == '__main__':
main()
# argparse
parser = argparse.ArgumentParser(description='SonoNet')
parser.add_argument('--cuda', action='store_true', default=False)
parser.add_argument('--network_name', type=str, default='SN64', help='SN16, SN32 or SN64')
parser.add_argument('--display_images', type=bool, default=False, help='Whether or not to show the images during inference')
parser.add_argument('--GPU_NR', type=int, default=0, help='Choose the device number of your GPU')
parser.add_argument('--weights', type=str, default='default', help='Select weight initialization. \
"default": Load weights from default *.pth weight file. "none": No weights are initialized. \
"random": Standard random weight initialization. Or: Pass the path to your own weight file')
args = parser.parse_args()
main(args)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment