Skip to content
Snippets Groups Projects
Commit ab8ad738 authored by s183897's avatar s183897 :ice_skate:
Browse files

Used nearest centroid algorithm, renamed accordingly

parent 5001d509
No related branches found
No related tags found
No related merge requests found
import numpy as np
from sklearn.cluster import KMeans
# function that finds cluster of data points using K-means algorithm
def km(X):
# function that finds centroid of pre-labeled cluster of data points
def centroid(X):
kmeans = KMeans(n_clusters=1)
kmeans.fit(X)
c = kmeans.cluster_centers_
......
# -*- coding: utf-8 -*-
# K-means Text Color Chooser: Main
# Nearest Centroid Text Color Chooser: Main
import pygame
import os
import numpy as np
......@@ -7,7 +7,7 @@ import scipy
from matplotlib import pyplot
from mpl_toolkits.mplot3d import Axes3D
from random import randint
from K_means import km
from centroid import centroid
from matrixToFile import matrixToFile
# init
......@@ -36,7 +36,7 @@ iconPic = pygame.image.load('icon.png')
pygame.display.set_icon(iconPic)
# title of app
pygame.display.set_caption('K-means Text Color Chooser 1.1.1')
pygame.display.set_caption('Nearest Centroid Text Color Chooser 1.2')
# set fonts
smallfont = pygame.font.SysFont('comicsansms', 25)
......@@ -86,13 +86,13 @@ while True:
text('press tab to show data', (255,255,255), smallfont, width - 275, 10)
# make prediction, display prediction above color box
whitecluster = km(white)
blackcluster = km(black)
whitecent = centroid(white)
blackcent = centroid(black)
y = np.reshape(color, (-1, 3))
if len(points) > 1:
if scipy.spatial.distance.euclidean(y, whitecluster) < scipy.spatial.distance.euclidean(y, blackcluster):
if scipy.spatial.distance.euclidean(y, whitecent) < scipy.spatial.distance.euclidean(y, blackcent):
display.blit(predPic, (width / 3 - 80 / 2, 200))
elif scipy.spatial.distance.euclidean(y, whitecluster) > scipy.spatial.distance.euclidean(y, blackcluster):
elif scipy.spatial.distance.euclidean(y, whitecent) > scipy.spatial.distance.euclidean(y, blackcent):
display.blit(predPic, (width / 1.5 - 80 / 2, 200))
else:
display.blit(predPic, (width / 1.5 - 80 / 2, 200))
......@@ -161,9 +161,9 @@ while True:
for n in range(1,len(black)):
ax.scatter(black[n,0],black[n,1],black[n,2], color = "black")
ax.scatter(color[0], color[1], color[2], color = "blue")
# display clusters
ax.scatter(whitecluster[:, 0], whitecluster[:, 1], whitecluster[:, 2], marker='*', c='white', s=500)
ax.scatter(blackcluster[:, 0], blackcluster[:, 1], blackcluster[:, 2], marker='*', c='black', s=500)
# display centroids
ax.scatter(whitecent[:, 0], whitecent[:, 1], whitecent[:, 2], marker='*', c='white', s=500)
ax.scatter(blackcent[:, 0], blackcent[:, 1], blackcent[:, 2], marker='*', c='black', s=500)
# axis labels
ax.set_xlabel('R', fontsize=15)
ax.set_ylabel('G', fontsize=15)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment