SemiDiscreteOT 1.0
Semi-Discrete Optimal Transport Library
Loading...
Searching...
No Matches
OptimalTransportPlan.cc
Go to the documentation of this file.
3
4namespace fs = std::filesystem;
5
7
8template <int spacedim>
10 : ParameterAcceptor("OptimalTransportPlan"),
11 distance_function(euclidean_distance<spacedim>)
12{
13 strategy = create_strategy(strategy_name);
14}
15
16template <int spacedim>
18 const std::vector<Point<spacedim>>& points,
19 const std::vector<double>& density)
20{
21 AssertDimension(points.size(), density.size());
22 source_points = points;
23 source_density = density;
24}
25
26template <int spacedim>
28 const std::vector<Point<spacedim>>& points,
29 const std::vector<double>& density)
30{
31 AssertDimension(points.size(), density.size());
32 target_points = points;
33 target_density = density;
34}
35
36template <int spacedim>
38 const Vector<double>& potential,
39 const double regularization_param)
40{
41 AssertDimension(potential.size(), target_points.size());
42 transport_potential = potential;
43 epsilon = regularization_param;
44}
45
46template <int spacedim>
48{
49 truncation_radius = radius;
50}
51
52template <int spacedim>
54{
55 Assert(strategy, ExcMessage("No strategy selected"));
56 Assert(!source_points.empty(), ExcMessage("Source measure not set"));
57 Assert(!target_points.empty(), ExcMessage("Target measure not set"));
58 Assert(transport_potential.size() > 0, ExcMessage("Transport potential not set"));
59
60 strategy->compute_map(distance_function, source_points, source_density,
61 target_points, target_density,
62 transport_potential, epsilon,
63 truncation_radius);
64}
65
66template <int spacedim>
67void OptimalTransportPlan<spacedim>::save_map(const std::string& output_dir) const
68{
69 Assert(strategy, ExcMessage("No strategy selected"));
70 fs::create_directories(output_dir);
71 strategy->save_results(output_dir);
72}
73
74template <int spacedim>
75void OptimalTransportPlan<spacedim>::set_strategy(const std::string& strategy_name)
76{
77 strategy = create_strategy(strategy_name);
78}
79
80template <int spacedim>
82{
83 return {"modal", "barycentric"};
84}
85
86template <int spacedim>
87std::unique_ptr<MapApproximationStrategy<spacedim>>
89{
90 if (name == "modal")
91 return std::make_unique<ModalStrategy<spacedim>>();
92 else if (name == "barycentric")
93 return std::make_unique<BarycentricStrategy<spacedim>>();
94 else
95 throw std::runtime_error("Unknown strategy: " + name);
96}
97
98// Implementation of ModalStrategy
99template <int spacedim>
101 const std::function<double(const Point<spacedim>&, const Point<spacedim>&)> distance_function,
102 const std::vector<Point<spacedim>>& source_points,
103 const std::vector<double>& source_density,
104 const std::vector<Point<spacedim>>& target_points,
105 const std::vector<double>& target_density,
106 const Vector<double>& potential,
107 const double regularization_param,
108 const double truncation_radius)
109{
110 using IndexedPoint = std::pair<Point<spacedim>, std::size_t>;
111 using RTreeParams = boost::geometry::index::rstar<8>;
112 using RTree = boost::geometry::index::rtree<IndexedPoint, RTreeParams>;
113
114 // Build RTree for target points
115 std::vector<IndexedPoint> indexed_points;
116 indexed_points.reserve(target_points.size());
117 for (std::size_t i = 0; i < target_points.size(); ++i) {
118 indexed_points.emplace_back(target_points[i], i);
119 }
120 RTree target_rtree(indexed_points.begin(), indexed_points.end());
121
122 // For each source point, find the target point that maximizes the score
123 this->mapped_points.resize(source_points.size());
124 this->transport_density.resize(source_points.size());
125
126 // Store source points for displacement computation
127 this->source_points = source_points;
128
129 // Determine if we use truncation or consider all points
130 const bool use_truncation = (truncation_radius > 0.0);
131
132 for (std::size_t i = 0; i < source_points.size(); ++i) {
133 const Point<spacedim>& x = source_points[i];
134 double max_score = -std::numeric_limits<double>::infinity();
135 std::size_t best_idx = 0;
136
137 // Set of points to consider (all or truncated)
138 std::vector<IndexedPoint> candidates;
139
140 if (use_truncation) {
141 // Use truncation radius to limit points to consider
142 target_rtree.query(
143 boost::geometry::index::satisfies([&x, &distance_function, truncation_radius](const IndexedPoint& p) {
144 return distance_function(x, p.first) < truncation_radius;
145 }),
146 std::back_inserter(candidates)
147 );
148
149 // If no points within truncation radius, fall back to nearest neighbor
150 if (candidates.empty()) {
151 target_rtree.query(boost::geometry::index::nearest(x, 1), std::back_inserter(candidates));
152 }
153 } else {
154 // Consider all target points
155 candidates.reserve(target_points.size());
156 for (std::size_t j = 0; j < target_points.size(); ++j) {
157 candidates.push_back(indexed_points[j]);
158 }
159 }
160
161 // Compute scores and find maximum
162 for (const auto& candidate : candidates) {
163 const Point<spacedim>& y = candidate.first;
164 const std::size_t j = candidate.second;
165
166 // Compute squared distance
167 double squared_dist = std::pow(distance_function(x, y), 2);
168
169 // Compute score: potential - c(x,y) + regularization_param * log(target_density)
170 double log_term = 0.0;
171 if (target_density[j] > 0) {
172 log_term = regularization_param * std::log(target_density[j]);
173 } else {
174 // If target density is zero or negative, use negative infinity for log term
175 log_term = -std::numeric_limits<double>::infinity();
176 }
177
178 double score = potential[j] - 0.5 * squared_dist + log_term;
179
180 if (score > max_score) {
181 max_score = score;
182 best_idx = j;
183 }
184 }
185
186 this->mapped_points[i] = target_points[best_idx];
187 this->transport_density[i] = source_density[i];
188 }
189}
190
191template <int spacedim>
192void ModalStrategy<spacedim>::save_results(const std::string& output_dir) const
193{
194 // Save text files for backward compatibility
195 Utils::write_vector(this->mapped_points, output_dir + "/mapped_points", "txt");
196 Utils::write_vector(this->transport_density, output_dir + "/transport_density", "txt");
197
198 // 1. Source triangulation with displacement field using utility function
199 Utils::write_points_with_displacement_vtk<spacedim>(
200 this->source_points,
201 this->mapped_points,
202 this->transport_density,
203 output_dir + "/source_triangulation_with_displacement.vtk",
204 "Source triangulation with displacement vectors",
205 "source_density"
206 );
207
208 // 2. Mapped points with density scalars using utility function
209 Utils::write_points_with_density_vtk<spacedim>(
210 this->mapped_points,
211 this->transport_density,
212 output_dir + "/mapped_points_with_density.vtk",
213 "Mapped points with density values",
214 "transport_density"
215 );
216}
217
218// Implementation of BarycentricStrategy
219template <int spacedim>
221 const std::function<double(const Point<spacedim>&, const Point<spacedim>&)> distance_function,
222 const std::vector<Point<spacedim>>& source_points,
223 const std::vector<double>& source_density,
224 const std::vector<Point<spacedim>>& target_points,
225 const std::vector<double>& target_density,
226 const Vector<double>& potential,
227 const double regularization_param,
228 const double truncation_radius)
229{
230 using IndexedPoint = std::pair<Point<spacedim>, std::size_t>;
231 using RTreeParams = boost::geometry::index::rstar<8>; // TODO: set as hyperparameter ?
232 using RTree = boost::geometry::index::rtree<IndexedPoint, RTreeParams>;
233
234 // Build RTree for target points
235 std::vector<IndexedPoint> indexed_points;
236 indexed_points.reserve(target_points.size());
237 for (std::size_t i = 0; i < target_points.size(); ++i) {
238 indexed_points.emplace_back(target_points[i], i);
239 }
240 RTree target_rtree(indexed_points.begin(), indexed_points.end());
241
242 this->mapped_points.resize(source_points.size());
243 this->transport_density.resize(source_points.size());
244
245 // Store source points for displacement computation
246 this->source_points = source_points;
247
248 // Determine if we use truncation or consider all points
249 const bool use_truncation = (truncation_radius > 0.0);
250
251 // For each source point, compute barycentric interpolation
252 for (std::size_t i = 0; i < source_points.size(); ++i) {
253 const Point<spacedim>& x = source_points[i];
254 Point<spacedim> weighted_sum;
255
256 // Determine which target points to consider
257 std::vector<IndexedPoint> candidates;
258
259 if (use_truncation) {
260 // Use truncation radius to limit points to consider
261 target_rtree.query(
262 boost::geometry::index::satisfies([&x, &distance_function, truncation_radius](const IndexedPoint& p) {
263 return distance_function(x, p.first) < truncation_radius;
264 }),
265 std::back_inserter(candidates)
266 );
267
268 // If no points within truncation radius, fall back to nearest neighbor
269 if (candidates.empty()) {
270 target_rtree.query(boost::geometry::index::nearest(x, 1), std::back_inserter(candidates));
271 const Point<spacedim>& nearest = candidates[0].first;
272 this->mapped_points[i] = nearest;
273 this->transport_density[i] = source_density[i];
274 continue;
275 }
276 } else {
277 // Consider all target points
278 candidates.reserve(target_points.size());
279 for (std::size_t j = 0; j < target_points.size(); ++j) {
280 candidates.push_back(indexed_points[j]);
281 }
282 }
283
284 // Log-sum-exp trick for numerical stability
285 // First compute all exponent terms and find maximum
286 std::vector<double> exponent_terms;
287 exponent_terms.reserve(candidates.size());
288 double max_exponent = -std::numeric_limits<double>::infinity();
289
290 for (const auto& candidate : candidates) {
291 const Point<spacedim>& y = candidate.first;
292 const std::size_t j = candidate.second;
293
294 double squared_dist = std::pow(distance_function(x, y), 2);
295 double exponent = (potential[j] - 0.5 * squared_dist) / regularization_param;
296
297 exponent_terms.push_back(exponent);
298 max_exponent = std::max(max_exponent, exponent);
299 }
300
301 // Compute weighted sum using the log-sum-exp trick
302 double sum_exp = 0.0;
303 for (std::size_t k = 0; k < candidates.size(); ++k) {
304 const auto& candidate = candidates[k];
305 const Point<spacedim>& y = candidate.first;
306 const std::size_t j = candidate.second;
307
308 // Subtract max_exponent for numerical stability
309 double stable_exp = std::exp(exponent_terms[k] - max_exponent);
310 double weight = target_density[j] * stable_exp;
311
312 weighted_sum += weight * y;
313 sum_exp += weight;
314 }
315
316 // Normalize the weighted sum
317 if (sum_exp > 0) {
318 this->mapped_points[i] = weighted_sum / sum_exp;
319 } else {
320 // If all weights are zero, map to nearest target point
321 std::vector<IndexedPoint> nearest;
322 target_rtree.query(boost::geometry::index::nearest(x, 1), std::back_inserter(nearest));
323 this->mapped_points[i] = nearest[0].first;
324 }
325
326 this->transport_density[i] = source_density[i];
327 }
328}
329
330template <int spacedim>
331void BarycentricStrategy<spacedim>::save_results(const std::string& output_dir) const
332{
333 // Save text files for backward compatibility
334 Utils::write_vector(this->mapped_points, output_dir + "/mapped_points", "txt");
335 Utils::write_vector(this->transport_density, output_dir + "/transport_density", "txt");
336
337 // 1. Source triangulation with displacement field using utility function
338 Utils::write_points_with_displacement_vtk<spacedim>(
339 this->source_points,
340 this->mapped_points,
341 this->transport_density,
342 output_dir + "/source_triangulation_with_displacement.vtk",
343 "Source triangulation with displacement vectors (barycentric)",
344 "source_density"
345 );
346
347 // 2. Mapped points with density scalars using utility function
348 Utils::write_points_with_density_vtk<spacedim>(
349 this->mapped_points,
350 this->transport_density,
351 output_dir + "/mapped_points_with_density.vtk",
352 "Mapped points with density values (barycentric)",
353 "transport_density"
354 );
355}
356
357// Explicit instantiation
358template class OptimalTransportPlan<2>;
359template class OptimalTransportPlan<3>;
360template class ModalStrategy<2>;
361template class ModalStrategy<3>;
362template class BarycentricStrategy<2>;
363template class BarycentricStrategy<3>;
364
365} // namespace OptimalTransportPlanSpace
double euclidean_distance(const Point< spacedim > a, const Point< spacedim > b)
Computes the Euclidean distance between two points.
Definition Distance.h:24
Barycentric interpolation strategy for map approximation.
void compute_map(const std::function< double(const Point< spacedim > &, const Point< spacedim > &)> distance_function, const std::vector< Point< spacedim > > &source_points, const std::vector< double > &source_density, const std::vector< Point< spacedim > > &target_points, const std::vector< double > &target_density, const Vector< double > &potential, const double regularization_param, const double truncation_radius) override
Computes the transport map.
void save_results(const std::string &output_dir) const override
Saves the results to a file.
Modal strategy for map approximation.
void compute_map(const std::function< double(const Point< spacedim > &, const Point< spacedim > &)> distance_function, const std::vector< Point< spacedim > > &source_points, const std::vector< double > &source_density, const std::vector< Point< spacedim > > &target_points, const std::vector< double > &target_density, const Vector< double > &potential, const double regularization_param, const double truncation_radius) override
Computes the transport map.
void save_results(const std::string &output_dir) const override
Saves the results to a file.
A class for computing and managing optimal transport map approximations.
void compute_map()
Compute the optimal transport map approximation using the current strategy.
OptimalTransportPlan(const std::string &strategy_name="modal")
Constructor taking an optional strategy name.
static std::unique_ptr< MapApproximationStrategy< spacedim > > create_strategy(const std::string &name)
static std::vector< std::string > get_available_strategies()
Get available strategy names.
void set_strategy(const std::string &strategy_name)
Change the approximation strategy.
void set_target_measure(const std::vector< Point< spacedim > > &points, const std::vector< double > &density)
Set the target measure data.
void set_source_measure(const std::vector< Point< spacedim > > &points, const std::vector< double > &density)
Set the source measure data.
void set_truncation_radius(double radius)
Set the truncation radius for map computation. For each source point, target points at a distance of at least this radius are not considered; a radius <= 0 disables truncation.
std::unique_ptr< MapApproximationStrategy< spacedim > > strategy
void set_potential(const Vector< double > &potential, const double regularization_param=0.0)
Set the optimal transport potential.
void save_map(const std::string &output_dir) const
Save the computed transport map to files.
void write_vector(const VectorContainer &points, const std::string &filepath, const std::string &fileMode="txt")
Write a vector container to a file in binary or text format.
Definition utils.h:56