Skip to content
Snippets Groups Projects
Commit bdc28bf3 authored by s183897's avatar s183897 :ice_skate:
Browse files

Kmeans version

parent d9f19f47
No related branches found
No related tags found
No related merge requests found
import pygame
import os
import time
import numpy as np
import scipy
from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D
from random import randint
from sklearn.cluster import KMeans
# init
pygame.init()
# resolution
width = 800
height = 600
# make the pygame window
display = pygame.display.set_mode((width, height))
# set framerate
clock = pygame.time.Clock()
clock.tick(60)
# keyboard pictures and predictor
leftPic = pygame.image.load('left.png')
rightPic = pygame.image.load('right.png')
downPic = pygame.image.load('down.png')
upPic = pygame.image.load('up.png')
predPic = pygame.image.load('predictor.png')
iconPic = pygame.image.load('icon.png')
# set window icon
pygame.display.set_icon(iconPic)
# title of app
pygame.display.set_caption('Color Classification')
# set fonts
smallfont = pygame.font.SysFont('comicsansms', 25)
mediumfont = pygame.font.SysFont('comicsansms', 50)
largefont = pygame.font.SysFont('comicsansms', 80)
# matrixes of color, index 0 should be deleted
points = np.array([[1000000,1000000,1000000]])
white = np.array([[1000000,1000000,1000000]])
black = np.array([[1000000,1000000,1000000]])
# write on screen
def text(msg,color,font,x,y):
screen_text = font.render(msg, True, color)
display.blit(screen_text, [x, y])
# K-means cluster
def km(X):
kmeans = KMeans(n_clusters=1)
kmeans.fit(X)
c = kmeans.cluster_centers_
return c
# outputs a matrix to a .txt file
def matrixToFile(matrix, fileName):
f = open(fileName, "w")
for n in matrix:
for i in range(len(n)):
if i == len(n) - 1:
f.write(str(n[i]))
else:
f.write(str(n[i]) + " ")
f.write("\n")
f.close()
# program loop
while True:
# random color
color = np.array([randint(0,255), randint(0,255), randint(0,255)])
# color background
display.fill((123, 123, 123))
# draw color boxes
pygame.draw.rect(display, color, [width / 3 - 200 / 2, 300, 200, 200])
pygame.draw.rect(display, color, [width / 1.5 - 200 / 2, 300, 200, 200])
pygame.draw.rect(display, (37,37,37), [width / 3 - 220 / 2 - 4, 75, 220, 40])
pygame.draw.rect(display, (37,37,37), [width / 1.5 - 245 / 2 - 4, 75, 245, 40])
# blit pics of arrows
display.blit(leftPic, (width / 3 - 63 / 2, 512))
display.blit(rightPic, (width / 1.5 - 63 / 2, 512))
display.blit(downPic, ((width / 3 - 63 / 2), 125))
display.blit(upPic, ((width / 1.5 - 63 / 2), 125))
# text
text('text', (255,255,255), largefont, width / 3 - 200 / 2 + 20, 330)
text('text', (0,0,0), largefont, width / 1.5 - 200 / 2 + 20, 330)
text('save training data', (255,255,255), smallfont, width / 3 - 216 / 2, 75)
text('import training data', (255,255,255), smallfont, width / 1.5 - 241 / 2, 75)
text('Color value: ',(255,255,255), smallfont, 5, 200)
text(str(color), (255,255,255), smallfont, 5, 230)
text('press tab to show data', (255,255,255), smallfont, width - 275, 10)
# make prediction, display prediction above color box
whitecluster = km(white)
blackcluster = km(black)
y = np.reshape(color, (-1, 3))
if len(points) > 1:
if scipy.spatial.distance.euclidean(y, whitecluster) < scipy.spatial.distance.euclidean(y, blackcluster):
display.blit(predPic, (width / 3 - 80 / 2, 200))
elif scipy.spatial.distance.euclidean(y, whitecluster) > scipy.spatial.distance.euclidean(y, blackcluster):
display.blit(predPic, (width / 1.5 - 80 / 2, 200))
else:
display.blit(predPic, (width / 1.5 - 80 / 2, 200))
display.blit(predPic, (width / 3 - 80 / 2, 200))
# update display
pygame.display.update()
loop = True
while loop:
# check for input to exit
for event in pygame.event.get():
if event.type == pygame.QUIT:
pygame.quit()
quit()
# check if white or black is best color and append to corresponding list
if event.type == pygame.KEYDOWN:
if event.key == pygame.K_LEFT:
white = np.vstack((white, color))
points = np.vstack((points, color))
loop = False
if event.key == pygame.K_RIGHT:
black = np.vstack((black, color))
points = np.vstack((points, color))
loop = False
# check if button presses are down or up, if they are, they save or import data
if event.key == pygame.K_DOWN:
# check if previous files exist
try:
os.remove("points.txt")
os.remove("white.txt")
os.remove("black.txt")
except:
pass
# make a file for the points matrix
matrixToFile(points, "points.txt")
# make a file for the white matrix
matrixToFile(white, "white.txt")
# make a file for the black matrix
matrixToFile(black, "black.txt")
text("Data successfully saved to files", (255,255,255), smallfont, 10, 10)
pygame.display.update()
if event.key == pygame.K_UP:
try:
points = np.loadtxt("points.txt")
white = np.loadtxt("white.txt")
black = np.loadtxt("black.txt")
text("Data successfully imported", (255,255,255), smallfont, 10, 10)
pygame.display.update()
except:
text("Data failed to be imported", (255,255,255), mediumfont, 10, 10)
pygame.display.update()
if event.key == pygame.K_TAB:
fig = pyplot.figure()
ax = Axes3D(fig)
ax.set_facecolor((1.0, 0.47, 0.42))
for n in range(1,len(white)):
ax.scatter(white[n,0],white[n,1],white[n,2], color = "white")
for n in range(1,len(black)):
ax.scatter(black[n,0],black[n,1],black[n,2], color = "black")
ax.scatter(color[0], color[1], color[2], color = "blue")
pyplot.show()
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment