SemiDiscreteOT 1.0
Semi-Discrete Optimal Transport Library
Loading...
Searching...
No Matches
dkm_utils.hpp
Go to the documentation of this file.
1#pragma once
2
#include "dkm.hpp"

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <fstream>
#include <iterator>
#include <numeric>
#include <regex>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
12
13namespace dkm {
14
15namespace details {
16
// Split a line on commas, making it simple to pull out the values we need.
// Returns the fields of `line` in their original order. The delimiter regex
// is compiled once and reused, instead of being rebuilt for every line.
inline std::vector<std::string> split_commas(const std::string& line) {
    static const std::regex comma(",");
    // Submatch index -1 selects the text *between* matches, i.e. the fields.
    return std::vector<std::string>(
        std::sregex_token_iterator(line.begin(), line.end(), comma, -1),
        std::sregex_token_iterator());
}
26
27}
28
38template <typename T, size_t N>
39std::vector<T> dist_to_center(const std::vector<std::array<T, N>>& points, const std::array<T, N>& center) {
40 std::vector<T> result(points.size());
41 std::transform(points.begin(), points.end(), result.begin(), [&center](const std::array<T, N>& p) {
42 return details::distance(p, center);
43 });
44 return result;
45}
46
47
56template <typename T, size_t N>
57T sum_dist(const std::vector<std::array<T, N>>& points, const std::array<T, N>& center) {
58 std::vector<T> distances = dist_to_center(points, center);
59 return std::accumulate(distances.begin(), distances.end(), T());
60}
61
62
/**
 * Collect the points assigned to one cluster.
 *
 * @param points All points.
 * @param labels Cluster label for each point; must have the same size as `points`.
 * @param label  Label of the cluster to extract.
 * @return Copies of the points whose label equals `label`, in input order.
 */
template <typename T, size_t N>
std::vector<std::array<T, N>> get_cluster(
    const std::vector<std::array<T, N>>& points, const std::vector<uint32_t>& labels, const uint32_t label) {
    assert(points.size() == labels.size() && "Points and labels have different sizes");
    std::vector<std::array<T, N>> members;
    // Walk the labels in lock-step with the points.
    auto label_it = labels.begin();
    for (const auto& point : points) {
        if (*label_it == label) {
            members.push_back(point);
        }
        ++label_it;
    }
    return members;
}
86
87
/**
 * Inertia of a clustering: for each cluster 0..k-1, the sum (via sum_dist) of
 * the distances from the cluster's points to its centroid, added together.
 *
 * @param points Points that were clustered.
 * @param means  Tuple of (centroids, per-point labels).
 * @param k      Number of clusters to include; must not exceed the number of centroids.
 * @return The accumulated inertia over clusters 0..k-1.
 */
template <typename T, size_t N>
T means_inertia(const std::vector<std::array<T, N>>& points,
                const std::tuple<std::vector<std::array<T, N>>, std::vector<uint32_t>>& means,
                uint32_t k) {
    // Bind by reference: unpacking through std::tie would deep-copy both vectors.
    const auto& centroids = std::get<0>(means);
    const auto& labels = std::get<1>(means);
    assert(k <= centroids.size() && "k exceeds the number of centroids");

    T inertia{T()};
    for (uint32_t i = 0; i < k; ++i) {
        const auto cluster = get_cluster(points, labels, i);
        inertia += sum_dist(cluster, centroids[i]);
    }
    return inertia;
}
113
114
/**
 * Run kmeans_lloyd several times and keep the result with the lowest inertia
 * (as computed by means_inertia).
 *
 * @param points Points to cluster.
 * @param k      Number of clusters.
 * @param n_init Total number of kmeans_lloyd runs. At least one run is always
 *               performed, so 0 behaves like 1.
 * @return Tuple of (centroids, per-point labels) from the best run.
 */
template <typename T, size_t N>
std::tuple<std::vector<std::array<T, N>>, std::vector<uint32_t>> get_best_means(
    const std::vector<std::array<T, N>>& points, uint32_t k, uint32_t n_init = 10) {
    auto best_means = kmeans_lloyd(points, k);
    auto best_inertia = means_inertia(points, best_means, k);

    // Start at 1: the first run is done above. Counting up to n_init (rather
    // than looping to n_init - 1) avoids unsigned wrap-around when n_init == 0.
    for (uint32_t run = 1; run < n_init; ++run) {
        auto curr_means = kmeans_lloyd(points, k);
        const auto curr_inertia = means_inertia(points, curr_means, k);
        if (curr_inertia < best_inertia) {
            best_inertia = curr_inertia;
            best_means = std::move(curr_means);  // curr_means is not used again
        }
    }
    return best_means;
}
142
149template <typename T, size_t N>
150size_t predict(const std::vector<std::array<T, N>>& centroids, const std::array<T, N>& query) {
151 T min = details::distance(centroids[0], query);
152 size_t index = 0;
153 for(size_t i = 1; i < centroids.size(); i++) {
154 auto dist = details::distance(centroids[i], query);
155 if (dist < min) {
156 min = dist;
157 index = i;
158 }
159 }
160 return index;
161}
162
168template <typename T, size_t N>
169std::vector<std::array<T, N>> load_csv(const std::string& path) {
170 std::ifstream file(path);
171 std::vector<std::array<T, N>> data;
172 for (auto it = std::istream_iterator<std::string>(file); it != std::istream_iterator<std::string>(); ++it) {
173 auto split = details::split_commas(*it);
174 assert(split.size() == N); // number of values must match rows in file
175 std::array<T, N> row;
176 std::transform(split.begin(), split.end(), row.begin(), [](const std::string& in) -> T {
177 return static_cast<T>(std::stod(in));
178 });
179 data.push_back(row);
180 }
181 return data;
182}
183
184} // namespace dkm
std::vector< std::string > split_commas(const std::string &line)
Definition dkm_utils.hpp:18
T distance(const std::array< T, N > &point_a, const std::array< T, N > &point_b)
Definition dkm.hpp:42
Definition dkm.hpp:20
std::tuple< std::vector< std::array< T, N > >, std::vector< uint32_t > > get_best_means(const std::vector< std::array< T, N > > &points, uint32_t k, uint32_t n_init=10)
std::vector< std::array< T, N > > get_cluster(const std::vector< std::array< T, N > > &points, const std::vector< uint32_t > &labels, const uint32_t label)
Definition dkm_utils.hpp:74
T sum_dist(const std::vector< std::array< T, N > > &points, const std::array< T, N > &center)
Definition dkm_utils.hpp:57
std::vector< T > dist_to_center(const std::vector< std::array< T, N > > &points, const std::array< T, N > &center)
Definition dkm_utils.hpp:39
T means_inertia(const std::vector< std::array< T, N > > &points, const std::tuple< std::vector< std::array< T, N > >, std::vector< uint32_t > > &means, uint32_t k)
Definition dkm_utils.hpp:99
std::vector< std::array< T, N > > load_csv(const std::string &path)
size_t predict(const std::vector< std::array< T, N > > &centroids, const std::array< T, N > &query)
std::tuple< std::vector< std::array< T, N > >, std::vector< uint32_t > > kmeans_lloyd(const std::vector< std::array< T, N > > &data, const clustering_parameters< T > &parameters)
Definition dkm.hpp:271