import numpy as np
import matplotlib.pyplot as plt


def input_main_menu():
    """Show the main menu and prompt until the user picks a valid option.

    Any input that is not the integer 1, 2 or 3 (including non-numeric
    text) is echoed back with a hint, and the prompt is repeated.

    Returns:
        int: The chosen option, always 1, 2 or 3.
    """
    print('')
    print('What would you like to do?')
    print('Type 1, 2, or 3 followed by enter:')
    print('1. Compute and display BMI from weight and height.')
    print('2. Compute and display a normal BMI range from height.')
    print('3. Quit')
    while True:
        s = input('Choice: ')
        try:
            i = int(s)
        except ValueError:
            # Non-numeric text is treated like any other invalid choice.
            i = None
        if i in (1, 2, 3):
            return i
        print('You typed ' + s)
        print('You should type 1, 2 or 3 followed by enter.')


def input_measure(measure_type, units, m_min, m_max):
    """Prompt until the user types a number within [m_min, m_max].

    Args:
        measure_type: Name of the quantity, shown in the prompt and in the
            error hint (e.g. 'height').
        units: Unit name shown in the error hint (e.g. 'meters').
        m_min: Smallest accepted value (inclusive).
        m_max: Largest accepted value (inclusive).

    Returns:
        float: The first typed value that parses as a float and lies
        within the bounds.
    """
    print('')
    while True:
        s = input('Type ' + measure_type + ' followed by enter: ')
        try:
            m = float(s)
        except ValueError:
            m = None
        # The chained comparison is False for NaN, so 'nan' is rejected
        # rather than slipping past two separate "out of range" checks.
        # No sentinel value is used, so ranges that include 0 also work.
        if m is not None and m_min <= m <= m_max:
            return m
        print('You typed ' + s)
        print('Type a valid ' + measure_type + ' in ' + units + '.')


def make_plot(height, weight=None):
    """Plot BMI boundary curves against height and mark the user's data.

    The red, magenta and green curves are the weights that give a BMI of
    30, 25 and 18.5 respectively (weight = BMI * height**2).

    Args:
        height: The user's height; the main program supplies it in meters.
        weight: The user's weight (kilograms when called from the main
            program). When given, the (height, weight) point is plotted and
            its BMI is shown in the legend. When None, the weight range for
            BMI 18.5 to 25 at this height is drawn as a vertical segment.
    """
    # Start a fresh figure so lines from a previous call can never end up in
    # this plot (e.g. on backends where show() does not close the figure).
    plt.figure()
    # Extend the curves so they always span the user's height; the input
    # accepts heights outside the default 1.3-2.2 window.
    h = np.linspace(min(1.3, height), max(2.2, height))
    # Labels are attached per line so the legend cannot drift out of sync
    # with the plotting order.
    plt.plot(h, 30*h**2, 'r', label='Obese to overweight')
    plt.plot(h, 25*h**2, 'm', label='Overweight to normal')
    plt.plot(h, 18.5*h**2, 'g', label='Normal to underweight')
    if weight is None:
        wmin = 18.5*height**2
        wmax = 25*height**2
        plt.plot([height, height], [wmin, wmax], 'k',
                 label=f'Normal range [{wmin:.2f} {wmax:.2f}]')
    else:
        plt.plot(height, weight, 'ko',
                 label=f'BMI {weight/(height**2):.2f}')
    plt.legend()
    plt.xlabel('Height')
    plt.ylabel('Weight')
    plt.title('BMI plot')
    plt.show()
    

def main():
    """Run the interactive BMI calculator until the user chooses to quit."""
    print('*************************')
    print('Welcome to BMI calculator')
    while True:
        choice = input_main_menu()
        if choice == 3:
            break
        # Both remaining options start by asking for the height.
        h = input_measure('height', 'meters', 1, 2.5)
        if choice == 1:
            w = input_measure('weight', 'kilograms', 40, 300)
            make_plot(h, w)
        else:
            make_plot(h)
    print('Bye!')


# Only start the interactive loop when run as a script, not on import.
if __name__ == '__main__':
    main()