Skip to content
Snippets Groups Projects
Commit 7e06a89e authored by s183919's avatar s183919
Browse files
parents f87485b8 c1176491
No related branches found
No related tags found
No related merge requests found
# -*- coding: utf-8 -*-
# Text Color Chooser (Logistic Regression) - Prediction Only: Main
import pygame
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D
from random import randint
from strToFile import stringToFile
from random import seed
from sklearn.linear_model import LogisticRegression
# init
pygame.init()
# resolution
width = 800
height = 600
# make the pygame window
display = pygame.display.set_mode((width, height))
# set framerate
clock = pygame.time.Clock()
clock.tick(60)
# keyboard pictures and predictor
leftPic = pygame.image.load('left.png')
rightPic = pygame.image.load('right.png')
upPic = pygame.image.load('up.png')
predPic = pygame.image.load('predictor.png')
iconPic = pygame.image.load('icon.png')
downPic = pygame.image.load('down.png')
# set window icon
pygame.display.set_icon(iconPic)
# title of app
pygame.display.set_caption('Text Color Chooser (Logistic Regression) v. 1.0 - Prediction - Training Data Sample Size: 384')
# set fonts
smallfont = pygame.font.SysFont('comicsansms', 25)
mediumfont = pygame.font.SysFont('comicsansms', 50)
largefont = pygame.font.SysFont('comicsansms', 80)
# matrixes of color, index 0 should be deleted
points = np.reshape(np.array([[0, 0, 0],[255,255,255]]), (-1, 3))
white = np.reshape(np.array([0, 0, 0]), (-1, 3))
black = np.reshape(np.array([255, 255, 255]), (-1, 3))
# initial arrays for correct and total predictions for vstack
corcnt = np.array([1])
incnt = np.array([1])
# ////////////////////////////////////////////////////////////
# write on screen
# ////////////////////////////////////////////////////////////
def text(msg,color,font,x,y):
screen_text = font.render(msg, True, color)
display.blit(screen_text, [x, y])
# program loop
number = 0
while True:
# random color
color = np.array([randint(0,255), randint(0,255), randint(0,255)])
number += 1
# color background
display.fill((123, 123, 123))
# display information
text('Total Predictions: '+str(number-2), (255,255,255), smallfont, width / 1.5 - -135 / 2, 210)
# draw color boxes
pygame.draw.rect(display, color, [width / 3 - 200 / 2, 300, 200, 200])
pygame.draw.rect(display, color, [width / 1.5 - 200 / 2, 300, 200, 200])
pygame.draw.rect(display, (37,37,37), [width / 3 - 220 / 2 - 4, 75, 220, 40])
# blit pics of arrows
display.blit(leftPic, (width / 3 - 63 / 2, 512))
display.blit(rightPic, (width / 1.5 - 63 / 2, 512))
display.blit(downPic, ((width / 3 - 63 / 2), 125))
# text
text('Text', (255,255,255), largefont, width / 3 - 200 / 2 + 20, 330)
text('Text', (0,0,0), largefont, width / 1.5 - 200 / 2 + 20, 330)
text('Save Prediction Data', (255,255,255), smallfont, width / 3 - 216 / 2, 75)
text('RGB Color Value: ',(255,255,255), smallfont, 5, 200)
text(str(color), (255,255,255), smallfont, 5, 230)
text('press tab to show data', (255,255,255), smallfont, width - 275, 10)
# convert RGB values to between 0 and 1
z = color/255
# training data (sample size: 384)
X = [[0.729411764705882, 0.647058823529412, 0.945098039215686], [0.482352941176471, 0.443137254901961, 0.0862745098039216], [0.643137254901961, 0.525490196078431, 0.313725490196078], [0.462745098039216, 0.933333333333333, 0.501960784313725], [0.588235294117647, 0.168627450980392, 0.768627450980392], [0.854901960784314, 0.635294117647059, 0.462745098039216], [0.580392156862745, 0.588235294117647, 0.305882352941176], [0.0666666666666667, 0.717647058823529, 0.0901960784313725], [0.133333333333333, 0.627450980392157, 0.0862745098039216], [0.16078431372549, 0.415686274509804, 0.125490196078431], [0.219607843137255, 0.52156862745098, 0.431372549019608], [0.556862745098039, 0.525490196078431, 0.611764705882353], [0.227450980392157, 0.741176470588235, 0.286274509803922], [0.176470588235294, 0.376470588235294, 0.466666666666667], [0.752941176470588, 0.0705882352941176, 0.541176470588235], [0.572549019607843, 0.227450980392157, 0.588235294117647], [0.862745098039216, 0.549019607843137, 0.654901960784314], [0.945098039215686, 0.490196078431373, 0.274509803921569], [0.749019607843137, 0.556862745098039, 0.780392156862745], [0.654901960784314, 0.937254901960784, 0.796078431372549], [0.890196078431372, 0.525490196078431, 0.101960784313725], [0.768627450980392, 0.670588235294118, 0.2], [0.905882352941176, 0.443137254901961, 0.654901960784314], [0.43921568627451, 0.56078431372549, 0.0509803921568627], [0.235294117647059, 0.552941176470588, 0.717647058823529], [0.105882352941176, 0.858823529411765, 0.788235294117647], [0.341176470588235, 0.101960784313725, 0.0823529411764706], [0.780392156862745, 0.745098039215686, 0.180392156862745], [0.741176470588235, 0.392156862745098, 0.976470588235294], [0.517647058823529, 0.258823529411765, 1], [0.854901960784314, 0.580392156862745, 0.101960784313725], [0.858823529411765, 0.211764705882353, 0.737254901960784], [0.968627450980392, 0.2, 0.419607843137255], [0.698039215686274, 0.380392156862745, 0.517647058823529], [0.905882352941176, 0.0196078431372549, 0.988235294117647], [0.372549019607843, 0.419607843137255, 0.0470588235294118], [0.945098039215686, 0.635294117647059, 0.27843137254902], [0.47843137254902, 0.427450980392157, 0.729411764705882], [0.188235294117647, 0.976470588235294, 0.172549019607843], [0.309803921568627, 0.662745098039216, 0.56078431372549], [0.745098039215686, 0.0196078431372549, 0.215686274509804], [0.00784313725490196, 0.0313725490196078, 0.345098039215686], [0.341176470588235, 0.411764705882353, 0.850980392156863], [0.611764705882353, 0.254901960784314, 0.52156862745098], [0.266666666666667, 0.858823529411765, 0.925490196078431], [0.980392156862745, 0.647058823529412, 0.933333333333333], [0.447058823529412, 0.137254901960784, 0.635294117647059], [0.333333333333333, 0.603921568627451, 0.890196078431372], [0.670588235294118, 0.909803921568627, 0.462745098039216], [0.831372549019608, 0.16078431372549, 0.619607843137255], [0.192156862745098, 0.0784313725490196, 0.231372549019608], [0.713725490196078, 0.529411764705882, 0.976470588235294], [0.203921568627451, 0.749019607843137, 0.980392156862745], [0.87843137254902, 0.368627450980392, 0.356862745098039], [0.913725490196078, 0.952941176470588, 0.498039215686275], [0.0666666666666667, 0.929411764705882, 0.266666666666667], [0.784313725490196, 0.486274509803922, 0.152941176470588], [0.117647058823529, 0.741176470588235, 0.894117647058824], [0.776470588235294, 0.466666666666667, 0.533333333333333], [0.133333333333333, 0.254901960784314, 0.494117647058824], [0.635294117647059, 0.294117647058824, 0.882352941176471], [0.627450980392157, 0.470588235294118, 0.47843137254902], [0.592156862745098, 0.815686274509804, 0.145098039215686], [0.290196078431373, 0.333333333333333, 0.262745098039216], [0.0392156862745098, 0.552941176470588, 0.847058823529412], [0.725490196078431, 0.835294117647059, 0.164705882352941], [0.619607843137255, 0.882352941176471, 0.945098039215686], [0.36078431372549, 0.807843137254902, 0.0823529411764706], [0.498039215686275, 0.941176470588235, 0.509803921568627], [0.572549019607843, 0.733333333333333, 0.72156862745098], [0.0745098039215686, 0.113725490196078, 0.623529411764706], [0.533333333333333, 0.325490196078431, 0.00784313725490196], [0.752941176470588, 0.811764705882353, 0.917647058823529], [0.180392156862745, 0.2, 0.941176470588235], [0.27843137254902, 0.909803921568627, 0.156862745098039], [0.133333333333333, 0.650980392156863, 0.258823529411765], [0.250980392156863, 0.588235294117647, 0.435294117647059], [0.898039215686275, 0.729411764705882, 0.294117647058824], [0.631372549019608, 0.0392156862745098, 0.717647058823529], [0.564705882352941, 0.392156862745098, 0.541176470588235], [0.317647058823529, 0.596078431372549, 0.4], [0.133333333333333, 0.274509803921569, 0.474509803921569], [0.749019607843137, 0.164705882352941, 0.713725490196078], [0.305882352941176, 0.749019607843137, 0.533333333333333], [0.145098039215686, 0.647058823529412, 0.356862745098039], [0.925490196078431, 0.243137254901961, 0.972549019607843], [0.890196078431372, 0.286274509803922, 0.749019607843137], [0.227450980392157, 0.509803921568627, 0.929411764705882], [0.823529411764706, 0.686274509803922, 0.631372549019608], [0.0901960784313725, 0.125490196078431, 0.113725490196078], [0.423529411764706, 0.498039215686275, 0.266666666666667], [0.780392156862745, 0.682352941176471, 0.176470588235294], [0.00392156862745098, 0.682352941176471, 0.607843137254902], [0.725490196078431, 0.556862745098039, 0.215686274509804], [0.443137254901961, 0.0666666666666667, 0.823529411764706], [0.713725490196078, 0.980392156862745, 0.882352941176471], [0.988235294117647, 0.36078431372549, 0.83921568627451], [0.486274509803922, 0.792156862745098, 0.235294117647059], [0.482352941176471, 0.603921568627451, 0.447058823529412], [0.592156862745098, 0.152941176470588, 0.733333333333333]]
# data classes (0 = black, 1 = white)
Y = [0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1]
logreg = LogisticRegression(C=1e5, solver='lbfgs', multi_class='multinomial')
# fit model to training data and classes
logreg.fit(X, Y)
# please deactivate prediction display to prevent placebo effect during experiment
if len(points) > 1:
if logreg.predict(z.reshape(1,-1)) == 1:
display.blit(predPic, (width / 3 - 80 / 2, 200))
elif logreg.predict(z.reshape(1,-1)) == 0:
display.blit(predPic, (width / 1.5 - 80 / 2, 200))
# update display
pygame.display.update()
text('Correct Predictions: '+str(np.sum(corcnt)-2), (255,255,255), smallfont, width / 1.5 - -135 / 2, 230)
text('Ratio: '+str((np.sum(corcnt)-2)/(number-2)), (255,255,255), smallfont, width / 1.5 - -135 / 2, 190)
loop = True
while loop:
# check for input to exit
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()
quit()
# if prediction correct, append array with integer val. 1 to array of correct predictions
if event.type == pygame.KEYDOWN:
if (event.key == pygame.K_LEFT and logreg.predict(z.reshape(1,-1)) == 1):
corcnt = np.vstack((np.array([1]), corcnt))
loop = False
elif (event.key == pygame.K_RIGHT and logreg.predict(z.reshape(1,-1)) == 0):
corcnt = np.vstack((np.array([1]), corcnt))
loop = False
elif event.key == pygame.K_LEFT: # if prediction incorrect, append to array of incorrect predictions
incnt = np.vstack((np.array([1]), incnt))
loop = False
elif event.key == pygame.K_RIGHT:
incnt = np.vstack((np.array([1]), incnt))
loop = False
# check if button presses are down or up, if they are, they save or import data
if event.key == pygame.K_DOWN:
# check if previous file exists
try:
os.remove("Prediction Data (Logistic Regression).txt")
except:
pass
# create a file for prediction data
stringToFile('Correct Predictions: '+str(np.sum(corcnt)-2)+'\nTotal Predictions: '+str(number-2)+'\nRatio: '+str((np.sum(corcnt)-2)/(np.sum(incnt)+np.sum(corcnt)-3)), "Prediction Data (Logistic Regression).txt")
text("Data Successfully Saved to File", (255,255,255), smallfont, 10, 10)
pygame.display.update()
if event.key == pygame.K_TAB:
x = np.asarray(X)
y = np.asarray(Y)
fig = pyplot.figure()
ax = Axes3D(fig)
ax.set_facecolor((1.0, 0.47, 0.42))
for n in range(len(x)):
if y[n] == 0:
ax.scatter(x[n,0], x[n,1], x[n,2], c="black")
elif y[n] == 1:
ax.scatter(x[n,0], x[n,1], x[n,2], c="white")
ax.set_xlabel('R', fontsize = 15)
ax.set_ylabel('B', fontsize = 15)
ax.set_zlabel('G', fontsize = 15)
pyplot.show()
......@@ -36,17 +36,17 @@ downPic = pygame.image.load('down.png')
pygame.display.set_icon(iconPic)
# title of app
pygame.display.set_caption('Nearest Centroid Text Color Chooser 2 - Prediction')
pygame.display.set_caption('Nearest Centroid Text Color Chooser 3 - Prediction')
# set fonts
smallfont = pygame.font.SysFont('comicsansms', 25)
mediumfont = pygame.font.SysFont('comicsansms', 50)
largefont = pygame.font.SysFont('comicsansms', 80)
# matrixes of color, index 0 should be deleted
points = np.reshape(np.array([randint(0,255), randint(0,255), randint(0,255)]), (-1, 3))
white = np.reshape(np.array([randint(0,255), randint(0,255), randint(0,255)]), (-1, 3))
black = np.reshape(np.array([randint(0,255), randint(0,255), randint(0,255)]), (-1, 3))
# white text looks best on black background and reversed
points = np.reshape(np.array([[0, 0, 0],[255,255,255]]), (-1, 3))
white = np.reshape(np.array([0, 0, 0]), (-1, 3))
black = np.reshape(np.array([255, 255, 255]), (-1, 3))
# initial arrays for correct and total predictions for vstack
corcnt = np.array([1])
......@@ -98,8 +98,8 @@ while True:
blackcent = centroid(black)
y = np.reshape(color, (-1, 3))
# deactivated prediction display to prevent placebo effect during experiment
"""
# please deactivate prediction display to prevent placebo effect during experiment
if len(points) > 1:
if scipy.spatial.distance.euclidean(y, whitecent) < scipy.spatial.distance.euclidean(y, blackcent):
display.blit(predPic, (width / 3 - 80 / 2, 200))
......@@ -108,7 +108,7 @@ while True:
else:
display.blit(predPic, (width / 1.5 - 80 / 2, 200))
display.blit(predPic, (width / 3 - 80 / 2, 200))
"""
# update display
pygame.display.update()
......
# -*- coding: utf-8 -*-
# Neural Network-powered Text Color Chooser - Prediction Only: Main
import pygame
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D
from random import randint
from strToFile import stringToFile
from random import seed
from sklearn.neural_network import MLPClassifier
# init
pygame.init()
# resolution
width = 800
height = 600
# make the pygame window
display = pygame.display.set_mode((width, height))
# set framerate
clock = pygame.time.Clock()
clock.tick(60)
# keyboard pictures and predictor
leftPic = pygame.image.load('left.png')
rightPic = pygame.image.load('right.png')
predPic = pygame.image.load('predictor.png')
iconPic = pygame.image.load('icon.png')
downPic = pygame.image.load('down.png')
# set window icon
pygame.display.set_icon(iconPic)
# title of app
pygame.display.set_caption('Neural Network-powered Text Color Chooser v. 1.2 - Prediction - Training Data Sample Size: 384')
# set fonts
smallfont = pygame.font.SysFont('comicsansms', 25)
mediumfont = pygame.font.SysFont('comicsansms', 50)
largefont = pygame.font.SysFont('comicsansms', 80)
# matrixes of color, index 0 should be deleted
points = np.reshape(np.array([[0, 0, 0],[255,255,255]]), (-1, 3))
white = np.reshape(np.array([0, 0, 0]), (-1, 3))
black = np.reshape(np.array([255, 255, 255]), (-1, 3))
# initial arrays for correct and total predictions for vstack
corcnt = np.array([1])
incnt = np.array([1])
# ////////////////////////////////////////////////////////////
# write on screen
# ////////////////////////////////////////////////////////////
def text(msg,color,font,x,y):
screen_text = font.render(msg, True, color)
display.blit(screen_text, [x, y])
# program loop
number = 0
while True:
# random color
color = np.array([randint(0,255), randint(0,255), randint(0,255)])
number += 1
# color background
display.fill((123, 123, 123))
# display information
text('Total Predictions: '+str(number-2), (255,255,255), smallfont, width / 1.5 - -135 / 2, 210)
# draw color boxes
pygame.draw.rect(display, color, [width / 3 - 200 / 2, 300, 200, 200])
pygame.draw.rect(display, color, [width / 1.5 - 200 / 2, 300, 200, 200])
pygame.draw.rect(display, (37,37,37), [width / 3 - 220 / 2 - 4, 75, 220, 40])
# blit pics of arrows
display.blit(leftPic, (width / 3 - 63 / 2, 512))
display.blit(rightPic, (width / 1.5 - 63 / 2, 512))
display.blit(downPic, ((width / 3 - 63 / 2), 125))
# text
text('Text', (255,255,255), largefont, width / 3 - 200 / 2 + 20, 330)
text('Text', (0,0,0), largefont, width / 1.5 - 200 / 2 + 20, 330)
text('Save Prediction Data', (255,255,255), smallfont, width / 3 - 216 / 2, 75)
text('RGB Color Value: ',(255,255,255), smallfont, 5, 200)
text(str(color), (255,255,255), smallfont, 5, 230)
text('press tab to show data', (255,255,255), smallfont, width - 275, 10)
# convert RGB values to between 0 and 1
z = color/255
# training data (sample size: 384)
X = [[0.729411764705882, 0.647058823529412, 0.945098039215686], [0.482352941176471, 0.443137254901961, 0.0862745098039216], [0.643137254901961, 0.525490196078431, 0.313725490196078], [0.462745098039216, 0.933333333333333, 0.501960784313725], [0.588235294117647, 0.168627450980392, 0.768627450980392], [0.854901960784314, 0.635294117647059, 0.462745098039216], [0.580392156862745, 0.588235294117647, 0.305882352941176], [0.0666666666666667, 0.717647058823529, 0.0901960784313725], [0.133333333333333, 0.627450980392157, 0.0862745098039216], [0.16078431372549, 0.415686274509804, 0.125490196078431], [0.219607843137255, 0.52156862745098, 0.431372549019608], [0.556862745098039, 0.525490196078431, 0.611764705882353], [0.227450980392157, 0.741176470588235, 0.286274509803922], [0.176470588235294, 0.376470588235294, 0.466666666666667], [0.752941176470588, 0.0705882352941176, 0.541176470588235], [0.572549019607843, 0.227450980392157, 0.588235294117647], [0.862745098039216, 0.549019607843137, 0.654901960784314], [0.945098039215686, 0.490196078431373, 0.274509803921569], [0.749019607843137, 0.556862745098039, 0.780392156862745], [0.654901960784314, 0.937254901960784, 0.796078431372549], [0.890196078431372, 0.525490196078431, 0.101960784313725], [0.768627450980392, 0.670588235294118, 0.2], [0.905882352941176, 0.443137254901961, 0.654901960784314], [0.43921568627451, 0.56078431372549, 0.0509803921568627], [0.235294117647059, 0.552941176470588, 0.717647058823529], [0.105882352941176, 0.858823529411765, 0.788235294117647], [0.341176470588235, 0.101960784313725, 0.0823529411764706], [0.780392156862745, 0.745098039215686, 0.180392156862745], [0.741176470588235, 0.392156862745098, 0.976470588235294], [0.517647058823529, 0.258823529411765, 1], [0.854901960784314, 0.580392156862745, 0.101960784313725], [0.858823529411765, 0.211764705882353, 0.737254901960784], [0.968627450980392, 0.2, 0.419607843137255], [0.698039215686274, 0.380392156862745, 0.517647058823529], [0.905882352941176, 0.0196078431372549, 0.988235294117647], [0.372549019607843, 0.419607843137255, 0.0470588235294118], [0.945098039215686, 0.635294117647059, 0.27843137254902], [0.47843137254902, 0.427450980392157, 0.729411764705882], [0.188235294117647, 0.976470588235294, 0.172549019607843], [0.309803921568627, 0.662745098039216, 0.56078431372549], [0.745098039215686, 0.0196078431372549, 0.215686274509804], [0.00784313725490196, 0.0313725490196078, 0.345098039215686], [0.341176470588235, 0.411764705882353, 0.850980392156863], [0.611764705882353, 0.254901960784314, 0.52156862745098], [0.266666666666667, 0.858823529411765, 0.925490196078431], [0.980392156862745, 0.647058823529412, 0.933333333333333], [0.447058823529412, 0.137254901960784, 0.635294117647059], [0.333333333333333, 0.603921568627451, 0.890196078431372], [0.670588235294118, 0.909803921568627, 0.462745098039216], [0.831372549019608, 0.16078431372549, 0.619607843137255], [0.192156862745098, 0.0784313725490196, 0.231372549019608], [0.713725490196078, 0.529411764705882, 0.976470588235294], [0.203921568627451, 0.749019607843137, 0.980392156862745], [0.87843137254902, 0.368627450980392, 0.356862745098039], [0.913725490196078, 0.952941176470588, 0.498039215686275], [0.0666666666666667, 0.929411764705882, 0.266666666666667], [0.784313725490196, 0.486274509803922, 0.152941176470588], [0.117647058823529, 0.741176470588235, 0.894117647058824], [0.776470588235294, 0.466666666666667, 0.533333333333333], [0.133333333333333, 0.254901960784314, 0.494117647058824], [0.635294117647059, 0.294117647058824, 0.882352941176471], [0.627450980392157, 0.470588235294118, 0.47843137254902], [0.592156862745098, 0.815686274509804, 0.145098039215686], [0.290196078431373, 0.333333333333333, 0.262745098039216], [0.0392156862745098, 0.552941176470588, 0.847058823529412], [0.725490196078431, 0.835294117647059, 0.164705882352941], [0.619607843137255, 0.882352941176471, 0.945098039215686], [0.36078431372549, 0.807843137254902, 0.0823529411764706], [0.498039215686275, 0.941176470588235, 0.509803921568627], [0.572549019607843, 0.733333333333333, 0.72156862745098], [0.0745098039215686, 0.113725490196078, 0.623529411764706], [0.533333333333333, 0.325490196078431, 0.00784313725490196], [0.752941176470588, 0.811764705882353, 0.917647058823529], [0.180392156862745, 0.2, 0.941176470588235], [0.27843137254902, 0.909803921568627, 0.156862745098039], [0.133333333333333, 0.650980392156863, 0.258823529411765], [0.250980392156863, 0.588235294117647, 0.435294117647059], [0.898039215686275, 0.729411764705882, 0.294117647058824], [0.631372549019608, 0.0392156862745098, 0.717647058823529], [0.564705882352941, 0.392156862745098, 0.541176470588235], [0.317647058823529, 0.596078431372549, 0.4], [0.133333333333333, 0.274509803921569, 0.474509803921569], [0.749019607843137, 0.164705882352941, 0.713725490196078], [0.305882352941176, 0.749019607843137, 0.533333333333333], [0.145098039215686, 0.647058823529412, 0.356862745098039], [0.925490196078431, 0.243137254901961, 0.972549019607843], [0.890196078431372, 0.286274509803922, 0.749019607843137], [0.227450980392157, 0.509803921568627, 0.929411764705882], [0.823529411764706, 0.686274509803922, 0.631372549019608], [0.0901960784313725, 0.125490196078431, 0.113725490196078], [0.423529411764706, 0.498039215686275, 0.266666666666667], [0.780392156862745, 0.682352941176471, 0.176470588235294], [0.00392156862745098, 0.682352941176471, 0.607843137254902], [0.725490196078431, 0.556862745098039, 0.215686274509804], [0.443137254901961, 0.0666666666666667, 0.823529411764706], [0.713725490196078, 0.980392156862745, 0.882352941176471], [0.988235294117647, 0.36078431372549, 0.83921568627451], [0.486274509803922, 0.792156862745098, 0.235294117647059], [0.482352941176471, 0.603921568627451, 0.447058823529412], [0.592156862745098, 0.152941176470588, 0.733333333333333]]
# data classes (0 = black, 1 = white)
Y = [0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1]
clf = MLPClassifier(solver='lbfgs', alpha=1e-5, hidden_layer_sizes=(5, 2), random_state=1)
# fit model to training data and classes
clf.fit(X, Y)
# please deactivate prediction display to prevent placebo effect during experiment
if len(points) > 1:
if clf.predict(z.reshape(1,-1)) == 1:
display.blit(predPic, (width / 3 - 80 / 2, 200))
elif clf.predict(z.reshape(1,-1)) == 0:
display.blit(predPic, (width / 1.5 - 80 / 2, 200))
# update display
pygame.display.update()
text('Correct Predictions: '+str(np.sum(corcnt)-2), (255,255,255), smallfont, width / 1.5 - -135 / 2, 230)
text('Ratio: '+str((np.sum(corcnt)-2)/(number-2)), (255,255,255), smallfont, width / 1.5 - -135 / 2, 190)
loop = True
while loop:
# check for input to exit
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()
quit()
# if prediction correct, append array with integer val. 1 to array of correct predictions
if event.type == pygame.KEYDOWN:
if (event.key == pygame.K_LEFT and clf.predict(z.reshape(1,-1)) == 1):
corcnt = np.vstack((np.array([1]), corcnt))
loop = False
elif (event.key == pygame.K_RIGHT and clf.predict(z.reshape(1,-1)) == 0):
corcnt = np.vstack((np.array([1]), corcnt))
loop = False
elif event.key == pygame.K_LEFT: # if prediction incorrect, append to array of incorrect predictions
incnt = np.vstack((np.array([1]), incnt))
loop = False
elif event.key == pygame.K_RIGHT:
incnt = np.vstack((np.array([1]), incnt))
loop = False
# check if button presses are down or up, if they are, they save or import data
if event.key == pygame.K_DOWN:
# check if previous file exists
try:
os.remove("Prediction Data (Neural Network).txt")
except:
pass
# create a file for prediction data
stringToFile('Correct Predictions: '+str(np.sum(corcnt)-2)+'\nTotal Predictions: '+str(number-2)+'\nRatio: '+str((np.sum(corcnt)-2)/(np.sum(incnt)+np.sum(corcnt)-3)), "Prediction Data (Neural Network).txt")
text("Data Successfully Saved to File", (255,255,255), smallfont, 10, 10)
pygame.display.update()
if event.key == pygame.K_TAB:
x = np.asarray(X)
y = np.asarray(Y)
fig = pyplot.figure()
ax = Axes3D(fig)
ax.set_facecolor((1.0, 0.47, 0.42))
for n in range(len(x)):
if y[n] == 0:
ax.scatter(x[n,0], x[n,1], x[n,2], c="black")
elif y[n] == 1:
ax.scatter(x[n,0], x[n,1], x[n,2], c="white")
ax.set_xlabel('R', fontsize = 15)
ax.set_ylabel('B', fontsize = 15)
ax.set_zlabel('G', fontsize = 15)
pyplot.show()
\ No newline at end of file
......@@ -36,17 +36,17 @@ downPic = pygame.image.load('down.png')
pygame.display.set_icon(iconPic)
# title of app
pygame.display.set_caption('Nearest Neighbor Text Color Chooser 2 - Prediction')
pygame.display.set_caption('Nearest Neighbor Text Color Chooser 2.5 - Prediction')
# set fonts
smallfont = pygame.font.SysFont('comicsansms', 25)
mediumfont = pygame.font.SysFont('comicsansms', 50)
largefont = pygame.font.SysFont('comicsansms', 80)
# matrixes of color, index 0 should be deleted
points = np.array([[1000000,1000000,1000000]])
white = np.array([[1000000,1000000,1000000]])
black = np.array([[1000000,1000000,1000000]])
# white text looks best on black background and reversed
points = np.reshape(np.array([[0, 0, 0],[255,255,255]]), (-1, 3))
white = np.reshape(np.array([0, 0, 0]), (-1, 3))
black = np.reshape(np.array([255, 255, 255]), (-1, 3))
# initial arrays for correct and total predictions for vstack
corcnt = np.array([1])
......@@ -97,8 +97,8 @@ while True:
y = np.reshape(color, (-1, 3))
prediction = nn(points, y)
# deactivated prediction display to prevent placebo effect during experiment
"""
# please deactivate prediction display to prevent placebo effect during experiment
if len(points) > 1:
for n in white:
if np.array_equal(n, prediction):
......@@ -106,7 +106,7 @@ while True:
for n in black:
if np.array_equal(n, prediction):
display.blit(predPic, (width / 1.5 - 80 / 2, 200))
"""
# update display
pygame.display.update()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment