#include <cmath>
#include <ctime>
#include <iostream>
#include <limits>
#include <random>
#include <stdexcept>
#include <utility>
#include <vector>

// Global generator used to place the random points and the initial centroids.
// It is seeded from std::random_device, so every run differs; call rng.seed(...)
// before constructing a KMeans to obtain a reproducible data set.
std::random_device rd;
std::mt19937 rng(rd());

/*
 * Lloyd's k-means clustering over randomly generated 2-D integer points.
 *
 * Points have integer coordinates drawn uniformly from [0, bound] (inclusive).
 * Centroids are kept in floating point so that the cluster means computed on
 * each update are not truncated to integers.
 */
class KMeans {
private:
    std::vector< std::pair<int, int> > points;
    std::vector< std::pair<double, double> > centroids;
    std::vector<int> clusters;                   /* clusters[i]: centroid index of points[i] */
    int cluster_tot;                             /* Total number of clusters */
    int points_tot;                              /* Total number of points */
    static constexpr int max_iter = 10 * 1000;   /* Max number of iterations */
    static constexpr double threshold = 1e-3;    /* Stop when the score changes by less than this */

    double compute_score();
    void reassign();
    void recompute();

public:
    KMeans(int bound, int cluster_tot, int points_tot);

    int numPoints() const { return points_tot; }
    int getCluster(int i) const { return clusters[i]; }
    std::vector<int> getClusters() const { return clusters; }
    std::pair<int, int> getPoint(int i) const { return points[i]; }

    /* Alternate assignment/update steps until the score stabilises or max_iter is reached. */
    void run();
};

/*
 * Build a random problem instance.
 *
 * @param bound        Coordinates of points and initial centroids are drawn from [0, bound].
 * @param cluster_tot_ Number of clusters (must be > 0).
 * @param points_tot_  Number of points (must be >= 0).
 * @throws std::invalid_argument on a non-positive cluster count, a negative point
 *         count, or a negative bound.
 */
KMeans::KMeans(int bound, int cluster_tot_, int points_tot_)
    : cluster_tot(cluster_tot_), points_tot(points_tot_)
{
    if (cluster_tot <= 0) {
        throw std::invalid_argument("KMeans: cluster_tot must be positive");
    }
    if (points_tot < 0) {
        throw std::invalid_argument("KMeans: points_tot must be non-negative");
    }
    if (bound < 0) {
        // uniform_int_distribution(a, b) requires a <= b.
        throw std::invalid_argument("KMeans: bound must be non-negative");
    }

    std::uniform_int_distribution<int> uni(0, bound);

    points.reserve(points_tot);
    clusters.reserve(points_tot);
    for (int i = 0; i < points_tot; i++) {
        // Draw x before y explicitly: the evaluation order of function arguments
        // (e.g. make_pair(uni(rng), uni(rng))) is unspecified in C++.
        const int x = uni(rng);
        const int y = uni(rng);
        points.emplace_back(x, y);
        // Arbitrary round-robin starting labels; run() replaces them with a
        // nearest-centroid assignment before the first update.
        clusters.push_back(i % cluster_tot);
    }

    centroids.reserve(cluster_tot);
    for (int k = 0; k < cluster_tot; k++) {
        const int x = uni(rng);
        const int y = uni(rng);
        centroids.emplace_back(x, y);
    }
}

/* Squared Euclidean distance; same ordering as euclid() but without the sqrt. */
static double squared_euclid(std::pair<double, double> p1, std::pair<double, double> p2)
{
    const double dx = p1.first - p2.first;
    const double dy = p1.second - p2.second;
    return dx * dx + dy * dy;
}

/* Euclidean distance between two 2-D points (integer pairs convert implicitly). */
double euclid(std::pair<double, double> p1, std::pair<double, double> p2)
{
    return std::sqrt(squared_euclid(p1, p2));
}

/* Sum over all points of the (non-squared) Euclidean distance to their assigned centroid. */
double KMeans::compute_score()
{
    double score = 0.0;
    for (int i = 0; i < points_tot; i++) {
        score += euclid(points[i], centroids[clusters[i]]);
    }
    return score;
}

void KMeans::run()
{
    // Start from +infinity rather than 0 so the loop cannot stop before the first
    // reassign(): the constructor's round-robin labels are not a nearest-centroid
    // assignment, and a near-zero initial score would otherwise end the run at once.
    double current_score = std::numeric_limits<double>::infinity();
    int iter = 0;
    while (true) {
        const double sc = compute_score();
        if (std::abs(sc - current_score) < threshold || iter >= max_iter) {
            std::cout << iter << " iterations" << std::endl;
            break;
        }
        current_score = sc;
        reassign();
        recompute();
        iter++;
    }
}

/* Reassign each point to its nearest centroid (ties go to the lowest index). */
void KMeans::reassign()
{
    for (int i = 0; i < points_tot; i++) {
        double dmin = std::numeric_limits<double>::infinity();
        int kmin = 0;
        for (int k = 0; k < cluster_tot; k++) {
            // Squared distance has the same argmin as the distance itself.
            const double d = squared_euclid(points[i], centroids[k]);
            if (d < dmin) {
                kmin = k;
                dmin = d;
            }
        }
        clusters[i] = kmin;
    }
}

/* Move every non-empty cluster's centroid to the mean of its points. */
void KMeans::recompute()
{
    std::vector<int> clusters_size(cluster_tot, 0);
    std::vector<long long> clusters_x(cluster_tot, 0);
    std::vector<long long> clusters_y(cluster_tot, 0);

    for (int i = 0; i < points_tot; i++) {
        const int k = clusters[i];
        clusters_size[k]++;
        clusters_x[k] += points[i].first;
        clusters_y[k] += points[i].second;
    }

    for (int k = 0; k < cluster_tot; k++) {
        // An empty cluster keeps its previous centroid (avoids dividing by zero).
        if (clusters_size[k] == 0) {
            continue;
        }
        const double n = static_cast<double>(clusters_size[k]);
        centroids[k] = std::make_pair(static_cast<double>(clusters_x[k]) / n,
                                      static_cast<double>(clusters_y[k]) / n);
    }
}

int main(int, char **)
{
    const int WINDOW_SIZE = 1500;

    // The distribution bound is inclusive, so WINDOW_SIZE - 1 keeps every point
    // on a valid pixel of a WINDOW_SIZE x WINDOW_SIZE window.
    KMeans kmeans = KMeans(WINDOW_SIZE - 1, 200, 10000);

    std::cout << "BEGIN" << std::endl;
    const std::clock_t start = std::clock();
    kmeans.run();
    const double duration = (std::clock() - start) / (double) CLOCKS_PER_SEC;
    std::cout << " # DURATION: " << duration << "s" << std::endl;

    /* Optional SDL visualisation (disabled; needs the SDL header and library):
    if (SDL_Init(SDL_INIT_EVERYTHING) != 0) {
        std::cerr << "SDL_Init error: " << SDL_GetError() << std::endl;
        return 1;
    }
    SDL_Window *window;
    SDL_Renderer *renderer;
    SDL_Event event;
    SDL_CreateWindowAndRenderer(WINDOW_SIZE, WINDOW_SIZE, 0, &window, &renderer);
    while (true) {
        SDL_SetRenderDrawColor(renderer, 10, 10, 10, 0);
        SDL_RenderClear(renderer);
        for (int i = 0; i < kmeans.numPoints(); i++) {
            int c = kmeans.getCluster(i);
            std::pair<int, int> coord = kmeans.getPoint(i);
            SDL_SetRenderDrawColor(renderer, c, 255 - c, c, 255);
            SDL_RenderDrawPoint(renderer, coord.first, coord.second);
        }
        SDL_RenderPresent(renderer);
        if (SDL_PollEvent(&event) && event.type == SDL_MOUSEBUTTONDOWN)
            break;
    }
    SDL_DestroyRenderer(renderer);
    SDL_DestroyWindow(window);
    SDL_Quit();
    */

    std::cout << "END" << std::endl;
    return 0;
}