SemiDiscreteOT 1.0
Semi-Discrete Optimal Transport Library
Loading...
Searching...
No Matches
SotSolver.cc
Go to the documentation of this file.
2
// Constructor remnant: the declarator line is not part of this listing, only
// the member-initializer list and the body are visible.
// Derives the rank count and this rank's id from `comm`, restricts `pcout`
// output to rank 0, seeds both distance thresholds with 1e-1, and zeroes the
// cached scalars (covering radius, minimum target density, C_global).
3template <int dim, int spacedim>
5 : mpi_communicator(comm)
6 , n_mpi_processes(Utilities::MPI::n_mpi_processes(comm))
7 , this_mpi_process(Utilities::MPI::this_mpi_process(comm))
8 , pcout(std::cout, this_mpi_process == 0)
9 , current_distance_threshold(1e-1)
10 , effective_distance_threshold(1e-1)
11 , current_potential(nullptr)
12 , current_epsilon(1.0)
13 , global_functional(0.0)
14 , gradient()
15 , covering_radius(0.0)
16 , min_target_density(0.0)
17 , C_global(0.0)
18{
// Default metric is the Euclidean distance. Only `distance_function` is
// assigned in this body; the gradient and exponential-map callbacks are not.
19 distance_function = [](const Point<spacedim> x, const Point<spacedim> y) { return euclidean_distance<spacedim>(x, y); };
20}
21
// Selects the distance callbacks by name (declarator line missing from this
// listing; the visible parameter is `distance_name_`).
// Stores the name in `distance_name`, then installs the distance, its
// gradient and its exponential map for either "euclidean" or "spherical".
// Throws std::invalid_argument for any other name. Note that `distance_name`
// has already been overwritten by the time the exception is thrown.
22template <int dim, int spacedim>
24 const std::string &distance_name_)
25{
26 distance_name = distance_name_;
27 if (distance_name == "euclidean") {
28 pcout << "Using Euclidean distance function." << std::endl;
29 distance_function = [](const Point<spacedim> x, const Point<spacedim> y) { return euclidean_distance<spacedim>(x, y); };
30 distance_function_gradient = [](const Point<spacedim> x, const Point<spacedim> y) { return euclidean_distance_gradient<spacedim>(x, y); };
31 distance_function_exponential_map = [](const Point<spacedim> x, const Vector<double> v) { return euclidean_distance_exp_map<spacedim>(x, v); };
32
33 } else if (distance_name == "spherical") {
34 pcout << "Using Spherical distance function." << std::endl;
35 distance_function = [](const Point<spacedim> x, const Point<spacedim> y) { return spherical_distance<spacedim>(x, y); };
36 distance_function_gradient = [](const Point<spacedim> x, const Point<spacedim> y) { return spherical_distance_gradient<spacedim>(x, y); };
37 distance_function_exponential_map = [](const Point<spacedim> x, const Vector<double> v) { return spherical_distance_exp_map<spacedim>(x, v); };
38
39 } else {
40 throw std::invalid_argument("Unknown distance function: " + distance_name);
41 }
42}
43
// Replaces `source_measure` with a SourceMeasure built from the given DoF
// handler, mapping, finite element, density vector and quadrature order
// (declarator line missing from this listing).
// NOTE(review): the arguments are taken by reference; whether SourceMeasure
// copies them or only stores pointers is not visible here, so the caller
// presumably has to keep them alive - confirm against SourceMeasure.
44template <int dim, int spacedim>
46 const DoFHandler<dim, spacedim>& dof_handler,
47 const Mapping<dim, spacedim>& mapping,
48 const FiniteElement<dim, spacedim>& fe,
49 const LinearAlgebra::distributed::Vector<double, MemorySpace::Host>& source_density,
50 const unsigned int quadrature_order)
51{
52 source_measure = SourceMeasure(dof_handler, mapping, fe, source_density, quadrature_order);
53}
54
// Replaces `target_measure` with a TargetMeasure built from the target point
// cloud and its per-point density (declarator line missing from this listing).
// No size check between `target_points` and `target_density` happens here;
// validate_measures() performs that comparison later.
55template <int dim, int spacedim>
57 const std::vector<Point<spacedim>>& target_points,
58 const Vector<double>& target_density)
59{
60 target_measure = TargetMeasure(target_points, target_density);
61}
62
// Caches the solver parameters and mirrors their `epsilon` into
// `current_epsilon` (declarator and parameter lines missing from this
// listing; `params` is the only name the body reads).
63template <int dim, int spacedim>
66{
67 // Store parameters for later use
68 current_params = params;
69 current_epsilon = params.epsilon;
70}
71
// Checks that both measures are usable (declarator line missing from this
// listing; call sites name it validate_measures()).
// Returns false, after printing a message through `pcout`, when
//   - the source measure has no DoF handler,
//   - the target point list is empty, or
//   - the target density and point list differ in size;
// returns true otherwise. Never throws by itself; callers turn a false result
// into std::runtime_error.
72template <int dim, int spacedim>
74{
75 if (!source_measure.dof_handler) {
76 pcout << "Error: Source measure not set up" << std::endl;
77 return false;
78 }
79 if (target_measure.points.empty()) {
80 pcout << "Error: Target measure not set up" << std::endl;
81 return false;
82 }
83 if (target_measure.density.size() != target_measure.points.size()) {
84 pcout << "Error: Target density size mismatch" << std::endl;
85 return false;
86 }
87 return true;
88}
89
// Convenience overload of solve(): installs the given source and target
// measures as the solver's current measures, then forwards to
// solve(potentials, params). Declarator and `params` parameter lines are
// missing from this listing.
// The measures are copied into the members, so any previously configured
// measures are overwritten.
90template <int dim, int spacedim>
92 Vector<double>& potentials,
93 const SourceMeasure& source,
94 const TargetMeasure& target,
96{
97 // Set up measures
98 source_measure = source;
99 target_measure = target;
100
101 // Call main solve method
102 solve(potentials, params);
103}
104
// Main optimization entry point (declarator and `params` parameter lines are
// missing from this listing; the body is invoked elsewhere as
// solve(potentials, params)).
//
// Steps, in order:
//   1. validate the measures (throws std::runtime_error on failure),
//   2. cache the parameters and the thread limit,
//   3. cache the covering radius (inflated by 10%) and the minimum target
//      density,
//   4. size `potentials` and the member `gradient` to the number of targets,
//   5. build a VerboseSolverControl and run SolverBFGS on
//      evaluate_functional().
// `potentials` is updated in place by the BFGS solver.
// A SolverControl::NoConvergence is logged and rethrown.
//
// NOTE(review): `current_potential` is reset to nullptr only on the success
// path at the end; when the catch block rethrows, the member keeps pointing
// at the caller's `potentials`.
105template <int dim, int spacedim>
107 Vector<double>& potentials,
109{
110 if (!validate_measures()) {
111 throw std::runtime_error("Invalid measures configuration");
112 }
113
114 // Store parameters
115 current_params = params;
116 current_epsilon = params.epsilon;
117
118 // Set up parallel environment
// n_threads == 0 means "choose automatically": split the available cores
// evenly across MPI ranks, but never go below one thread.
119 unsigned int n_threads = params.n_threads;
120 if (n_threads == 0) {
121 n_threads = std::max(1U, MultithreadInfo::n_cores() / n_mpi_processes);
122 }
123 MultithreadInfo::set_thread_limit(n_threads);
124
125 pcout << "Parallel Configuration:" << std::endl
126 << " MPI Processes: " << n_mpi_processes
127 << " (Current rank: " << this_mpi_process << ")" << std::endl
128 << " Threads per process: " << n_threads << std::endl
129 << " Total parallel units: " << n_threads * n_mpi_processes << std::endl;
130
131 // Compute and cache the covering radius and minimum target density
// The cached radius carries a 1.1 safety factor on top of the computed value.
// The min_element dereference is safe because validate_measures() above
// rejected an empty target set and a density/points size mismatch.
132 covering_radius = compute_covering_radius() * 1.1;
133 min_target_density = *std::min_element(target_measure.density.begin(), target_measure.density.end());
134
135 pcout << "Covering radius (R0): " << covering_radius << std::endl
136 << " (Maximum distance from any source cell center to the nearest target point)" << std::endl;
137
138 // Print log-sum-exp status if using small entropy
139 if (params.epsilon < 1e-2) {
140 pcout << "Small entropy detected (ε = " << params.epsilon << ")" << std::endl;
141 pcout << " Log-Sum-Exp trick: " << (params.use_log_sum_exp_trick ? "enabled" : "disabled") << std::endl;
142 if (!params.use_log_sum_exp_trick && params.epsilon < 1e-4) {
143 pcout << " \033[1;33mWARNING: Using very small entropy without Log-Sum-Exp trick may cause numerical instability\033[0m" << std::endl;
144 }
145 }
146
147 // Initialize potentials if needed
// A correctly sized `potentials` is kept as the initial guess; only a size
// mismatch triggers reinit().
148 if (potentials.size() != target_measure.points.size()) {
149 potentials.reinit(target_measure.points.size());
150 }
151
152 // Initialize gradient member variable
153 gradient.reinit(potentials.size());
154
155 // Configure Solver Control based on selected tolerance type
// In component-wise mode the scalar tolerance handed to the solver control is
// the smallest positive normalized double - presumably so that the scalar
// residual test never decides convergence and the per-component test does;
// confirm in VerboseSolverControl.
156 bool use_componentwise = (params.solver_control_type == "componentwise");
157 double solver_ctrl_tolerance;
158 if (use_componentwise) {
159 solver_ctrl_tolerance = std::numeric_limits<double>::min();
160 } else {
161 solver_ctrl_tolerance = params.tolerance;
162 }
163
164 solver_control = std::make_unique<VerboseSolverControl>(
165 params.max_iterations,
166 solver_ctrl_tolerance,
167 use_componentwise,
168 pcout
169 );
170
171 if (!params.verbose_output) {
172 solver_control->log_history(false);
173 solver_control->log_result(false);
174 }
175
// `solver_control` was created as a VerboseSolverControl just above, so the
// cast is expected to succeed; the AssertThrow guards against a mismatch.
176 auto* verbose_control = dynamic_cast<VerboseSolverControl*>(solver_control.get());
177 AssertThrow(verbose_control != nullptr, ExcInternalError());
178
179 verbose_control->set_gradient(gradient);
180
181 if (use_componentwise) {
182 verbose_control->set_target_measure(target_measure.density, params.tolerance);
183 }
184
185 try {
186 Timer timer;
187 timer.start();
188
189 // Create and run BFGS solver
190 SolverBFGS<Vector<double>> solver(*solver_control);
191 current_potential = &potentials;
192
193 solver.solve(
194 [this](const Vector<double>& w, Vector<double>& grad) {
195 return this->evaluate_functional(w, grad);
196 },
197 potentials
198 );
199
200 timer.stop();
201
202 pcout << Color::green << Color::bold << "Optimization completed:" << std::endl
203 << " Time taken: " << timer.wall_time() << " seconds" << std::endl
204 << " Iterations: " << solver_control->last_step() << std::endl
205 << " Final function value: " << solver_control->last_value() << Color::reset << std::endl;
206
207 } catch (SolverControl::NoConvergence& exc) {
208 pcout << "Warning: Optimization did not converge" << std::endl
209 << " Iterations: " << exc.last_step << std::endl
210 << " Residual: " << exc.last_residual << std::endl;
211 throw;
212 }
213
214 // Reset solver state
215 current_potential = nullptr;
216}
217
// Evaluates the functional and its gradient at `potentials` (declarator line
// missing from this listing; solve() calls it as evaluate_functional(w, grad)).
//
// Per-cell contributions are assembled with WorkStream over the locally owned
// cells, summed over all MPI ranks, and the linear term
// sum_i potentials[i] * target_density[i] is subtracted. The result is stored
// in the members `global_functional`, `gradient` and `C_global`, copied into
// `gradient_out`, and the functional value is returned.
// Any std::exception raised during assembly is logged and rethrown.
//
// Side effects: overwrites `current_potential` and `current_epsilon`, and
// refreshes the distance threshold before assembling.
218template <int dim, int spacedim>
220 const Vector<double>& potentials,
221 Vector<double>& gradient_out)
222{
223 // Store current potentials for use in local assembly
224 current_potential = &potentials;
225 current_epsilon = current_params.epsilon;
226
227 // Update distance threshold for target point search
// This runs before the accumulators are reset, so the threshold computation
// still sees `global_functional` and `C_global` from the previous evaluation.
228 compute_distance_threshold();
229
230 // Reset global accumulators
231 global_functional = 0.0;
232 gradient = 0; // Reset class member gradient
233 C_global = 0.0; // Reset C_global
234 double local_process_functional = 0.0;
235 Vector<double> local_process_gradient(target_measure.points.size());
236 double local_process_C_sum = 0.0; // Accumulator for C_sum on this MPI process
237
238 if (current_params.verbose_output) {
239 pcout << "Using distance threshold: " << current_distance_threshold
240 << " (Effective: " << effective_distance_threshold << ")" << std::endl;
241 }
242
243 try {
244 // Determine if we're using simplex elements
245 bool use_simplex = (dynamic_cast<const FE_SimplexP<dim>*>(&*source_measure.fe) != nullptr);
246
247 // Create appropriate quadrature rule
248 std::unique_ptr<Quadrature<dim>> quadrature;
249 if (use_simplex) {
250 quadrature = std::make_unique<QGaussSimplex<dim>>(source_measure.quadrature_order);
251 } else {
252 quadrature = std::make_unique<QGauss<dim>>(source_measure.quadrature_order);
253 }
254
255 // Create scratch and copy data
256 ScratchData scratch_data(*source_measure.fe,
257 *source_measure.mapping,
258 *quadrature);
259 CopyData copy_data(target_measure.points.size());
260
261 // Create filtered iterators for locally owned cells
262 FilteredIterator<typename DoFHandler<dim, spacedim>::active_cell_iterator>
263 begin_filtered(IteratorFilters::LocallyOwnedCell(),
264 source_measure.dof_handler->begin_active()),
265 end_filtered(IteratorFilters::LocallyOwnedCell(),
266 source_measure.dof_handler->end());
267
// Per-quadrature-point callback handed to local_assemble(). It adds the
// point's contribution to the functional, to C (via `scale`) and to the
// gradient entries of the targets listed in `cell_target_indices`.
// `x` and `target_densities` are part of the callback signature but are not
// read by this particular callback.
268 auto function_call = [this](
269 CopyData& copy,
270 const Point<spacedim> &x,
271 const std::vector<std::size_t> &cell_target_indices,
272 const std::vector<double> &exp_terms,
273 const std::vector<double> &target_densities,
274 const double &density_value,
275 const double &JxW,
276 const double &total_sum_exp,
277 const double &max_exponent,
278 const double &current_epsilon)
279 {
280
281 // Calculate functional value based on whether log-sum-exp is used
// With the trick enabled the exponentials were shifted by `max_exponent`, so
// the shift is added back to the logarithm here.
282 if (current_params.use_log_sum_exp_trick) {
283 copy.functional_value += density_value * current_epsilon *
284 (max_exponent + std::log(total_sum_exp)) * JxW;
285 } else {
286 copy.functional_value += density_value * current_epsilon *
287 std::log(total_sum_exp) * JxW;
288 }
// scale = source density * quadrature weight / sum of the exponential terms.
289 const double scale = density_value * JxW / total_sum_exp;
290 copy.local_C_sum += scale; // Add scale to local_C_sum for this cell q-point
291 #pragma omp simd
292 for (size_t i = 0; i < cell_target_indices.size(); ++i) {
293 if (exp_terms[i] > 0.0) {
294 copy.gradient_values[cell_target_indices[i]] += scale * exp_terms[i];
295 }
296 }
297 };
298
299 // Parallel assembly using WorkStream
300 WorkStream::run(
301 begin_filtered,
302 end_filtered,
303 [this, &function_call](const typename DoFHandler<dim, spacedim>::active_cell_iterator& cell,
304 ScratchData& scratch,
305 CopyData& copy) {
306 this->local_assemble(cell, scratch, copy, function_call);
307 },
308 [this, &local_process_functional, &local_process_gradient, &local_process_C_sum](const CopyData& copy) {
309 local_process_functional += copy.functional_value;
310 local_process_gradient += copy.gradient_values;
311 local_process_C_sum += copy.local_C_sum; // Accumulate local_C_sum
312 },
313 scratch_data,
314 copy_data);
315
316 // Synchronize across MPI processes
// NOTE(review): `gradient` was already zeroed at the top of this function;
// the second reset below looks redundant.
317 global_functional = Utilities::MPI::sum(local_process_functional, mpi_communicator);
318 gradient = 0; // Reset gradient
319 Utilities::MPI::sum(local_process_gradient, mpi_communicator, gradient);
320 C_global = Utilities::MPI::sum(local_process_C_sum, mpi_communicator); // Sum C_global across processes
321
322 // Add linear term (only on root process to avoid duplication)
// After this branch only rank 0 holds the complete value; the broadcasts
// below overwrite the other ranks with rank 0's result.
323 if (Utilities::MPI::this_mpi_process(mpi_communicator) == 0) {
324 for (unsigned int i = 0; i < target_measure.points.size(); ++i) {
325 global_functional -= potentials[i] * target_measure.density[i];
326 gradient[i] -= target_measure.density[i];
327 }
328 }
329
330 // Broadcast final results to all processes
331 global_functional = Utilities::MPI::broadcast(mpi_communicator, global_functional, 0);
332 gradient = Utilities::MPI::broadcast(mpi_communicator, gradient, 0);
333
334 // Copy result to output gradient
335 gradient_out = gradient;
336
// Diagnostics only: nothing computed in this verbose block feeds back into
// the returned value or the gradient.
337 if (current_params.verbose_output) {
338 pcout << "Functional evaluation completed:" << std::endl;
339 pcout << " Function value: " << global_functional << std::endl;
340 pcout << " C_global: " << C_global << std::endl;
341
342 // Calculate and print the geometric radius bound for comparison
343 if (current_potential != nullptr && current_potential->size() == target_measure.points.size()) {
344 double geom_radius_bound = compute_geometric_radius_bound(*current_potential, current_epsilon, current_params.tau);
345
346 // Calculate traditional pointwise bound for comparison
// NOTE(review): unlike compute_distance_threshold(), this diagnostic does not
// guard against a non-positive minimum density or tau before dividing and
// taking the logarithm.
347 double max_pot = *std::max_element(current_potential->begin(), current_potential->end());
348 double min_tgt_density = min_target_density > 0.0 ?
349 min_target_density :
350 *std::min_element(target_measure.density.begin(), target_measure.density.end());
351 double sq_threshold = -2.0 * current_epsilon * std::log(current_params.tau/min_tgt_density) + 2.0 * max_pot;
352 double pointwise_bound = std::sqrt(std::max(0.0, sq_threshold));
353
354 // Calculate the new integral radius bound
355 double integral_radius_bound = compute_integral_radius_bound(
356 *current_potential,
357 current_epsilon,
358 current_params.tau,
359 C_global,
360 global_functional
361 );
362
363 pcout << " Current distance threshold: " << current_distance_threshold
364 << " (using: " << current_params.distance_threshold_type << ")" << std::endl
365 << " Pointwise bound (eps_machine=" << current_params.tau << "): " << pointwise_bound << std::endl
366 << " Integral bound (C_global=" << C_global << ", τ=" << current_params.tau << "): " << integral_radius_bound << std::endl
367 << " Geometric bound (τ=" << current_params.tau << "): " << geom_radius_bound
368 << " (ratio to pointwise: " << (pointwise_bound > 1e-9 ? geom_radius_bound/pointwise_bound : 0.0) << ")" << std::endl;
369 }
370 }
371
372 } catch (const std::exception& e) {
373 pcout << "Error in functional evaluation: " << e.what() << std::endl;
374 throw;
375 }
376
377 return global_functional;
378}
379
// Recomputes `current_distance_threshold`, the cut-off radius used when
// collecting target points for a cell (declarator line missing from this
// listing; call sites name it compute_distance_threshold()).
//
// Without a current potential both thresholds are set to the largest double,
// i.e. no cut-off. Otherwise the method named by
// `current_params.distance_threshold_type` is used:
//   "integral"  - compute_integral_radius_bound(), floored at covering_radius
//   "geometric" - compute_geometric_radius_bound()
//   otherwise   - pointwise bound sqrt(max(0, -2*eps*log(tau/nu_min) + 2*max(psi)))
//
// NOTE(review): `new_proposed_effective_threshold` is computed but never read,
// and `effective_distance_threshold` is only written on the early-return
// path, so the "(Effective: ...)" value printed by callers is not refreshed
// here. Looks like leftover code - confirm the intended behaviour.
380template <int dim, int spacedim>
382{
383 if (current_potential == nullptr) {
384 current_distance_threshold = std::numeric_limits<double>::max();
385 effective_distance_threshold = std::numeric_limits<double>::max();
386 return;
387 }
388
389 // Choose distance threshold calculation method based on parameter
390 double computed_threshold = 0.0;
391 std::string used_method_for_log;
392
393 if (current_params.distance_threshold_type == "integral") {
// Uses the C_global and global_functional values currently stored in the
// members, i.e. whatever the last evaluation left behind.
394 computed_threshold = compute_integral_radius_bound(
395 *current_potential,
396 current_epsilon,
397 current_params.tau,
398 C_global,
399 global_functional
400 );
401 computed_threshold = std::max(computed_threshold, covering_radius);
402 used_method_for_log = "integral (C_global based)";
403 } else if (current_params.distance_threshold_type == "geometric") {
404 // Use geometric radius bound (integral approach)
405 computed_threshold = compute_geometric_radius_bound(*current_potential, current_epsilon, current_params.tau);
406 used_method_for_log = "geometric (covering radius based)";
407 } else { // Default to pointwise, or if type is explicitly "pointwise"
408 double max_potential = *std::max_element(current_potential->begin(), current_potential->end());
409 double current_min_target_density = min_target_density > 0.0 ?
410 min_target_density :
411 *std::min_element(target_measure.density.begin(), target_measure.density.end());
412
// A non-positive density or tau would make the logarithm below undefined, so
// the cut-off is disabled instead.
413 if (current_min_target_density <= 0 || current_params.tau <=0) {
414 computed_threshold = std::numeric_limits<double>::max();
415 } else {
416 double squared_threshold = -2.0 * current_epsilon *
417 std::log(current_params.tau/current_min_target_density) + 2.0 * max_potential;
418 computed_threshold = std::sqrt(std::max(0.0, squared_threshold));
419 }
420 used_method_for_log = "pointwise (tau based)";
421 }
422
423 if (current_params.verbose_output) {
424 pcout << "Computed distance threshold using " << used_method_for_log << ": " << computed_threshold << std::endl;
425 }
426
427 double new_proposed_effective_threshold = computed_threshold * 1.1;
428
429 current_distance_threshold = computed_threshold;
430}
431
// Computes the covering radius: the largest distance, over all locally owned
// source cells, from the cell center to its nearest target point, maximized
// over all MPI ranks (declarator line missing from this listing; call sites
// name it compute_covering_radius()).
// Throws std::runtime_error when validate_measures() fails.
// A rank without locally owned cells contributes 0.0 to the global maximum.
//
// NOTE(review): the distance is the Euclidean norm of the difference, not the
// configurable `distance_function`; with the "spherical" metric the two
// presumably differ - confirm this is intended.
432template <int dim, int spacedim>
434{
435 namespace bgi = boost::geometry::index;
436
437 if (!validate_measures()) {
438 throw std::runtime_error("Invalid measures configuration for computing covering radius");
439 }
440
441 double max_min_distance = 0.0;
442
443 // Iterate through all locally owned cells in source domain
444 FilteredIterator<typename DoFHandler<dim, spacedim>::active_cell_iterator>
445 begin_filtered(IteratorFilters::LocallyOwnedCell(),
446 source_measure.dof_handler->begin_active()),
447 end_filtered(IteratorFilters::LocallyOwnedCell(),
448 source_measure.dof_handler->end());
449
450 std::vector<double> local_min_distances;
451
452 // For each cell, find the minimum distance to any target point
453 for (auto cell = begin_filtered; cell != end_filtered; ++cell) {
454 const Point<spacedim>& cell_center = cell->center();
455
456 // Use rtree to find the nearest target point - fixed to match container type
// The query asks for exactly one nearest neighbour; each result is read as a
// (point, index) pair and only the point is used here.
457 std::vector<IndexedPoint> nearest_results;
458 target_measure.rtree.query(bgi::nearest(cell_center, 1), std::back_inserter(nearest_results));
459
460 if (!nearest_results.empty()) {
461 const Point<spacedim>& nearest_point = nearest_results[0].first;
462 const double distance = (nearest_point - cell_center).norm();
463 local_min_distances.push_back(distance);
464 }
465 }
466
467 // Find maximum of all minimum distances
468 if (!local_min_distances.empty()) {
469 max_min_distance = *std::max_element(local_min_distances.begin(), local_min_distances.end());
470 }
471
472 // Synchronize across MPI processes
473 return Utilities::MPI::max(max_min_distance, mpi_communicator);
474}
475
// Integral radius bound (declarator line missing from this listing; call
// sites name it compute_integral_radius_bound()).
// Returns sqrt(max(0, 2*max(psi) + 2*epsilon*log(epsilon*C / (tau*|J|))))
// where psi = `potentials`, C = `C_value`, tau = `tolerance` and
// J = `current_functional_val`, with |J| floored at 1e-10.
//
// NOTE(review): the inline comments below refer to positivity checks on
// epsilon, C_value and tolerance, but no such checks are visible in this
// body. The listing's own numbering skips lines after the opening brace, so
// they may simply be absent from this view - confirm against the real file.
// As shown, a zero or negative log argument is only absorbed by the final
// max(0, .) clamp, and an empty `potentials` would dereference end().
476template <int dim, int spacedim>
478 const Vector<double>& potentials,
479 double epsilon, // Regularization parameter (epsilon in formula)
480 double tolerance, // Tolerance for relative error (tau in formula)
481 double C_value, // Accumulated C_global (C in formula)
482 double current_functional_val) const // J_epsilon(psi)
483{
490 double max_potential = *std::max_element(potentials.begin(), potentials.end());
491
// Floor |J| so the denominator of the log argument cannot be zero.
492 double abs_functional_val = std::abs(current_functional_val);
493 const double safety_value = 1e-10; // Consistent with geometric_radius_bound
494 if (abs_functional_val < safety_value) {
495 abs_functional_val = safety_value;
496 }
497
498 double log_numerator = epsilon * C_value;
499 double log_denominator = tolerance * abs_functional_val;
500
501 // log_numerator must be positive (epsilon > 0, C_value > 0 ensured by checks)
502 // log_denominator must be positive (tolerance > 0, abs_functional_val >= safety_value > 0 ensured by checks)
503 double log_argument = log_numerator / log_denominator;
504
505 double radius_squared = 2.0 * max_potential + 2.0 * epsilon * std::log(log_argument);
506
507 return std::sqrt(std::max(0.0, radius_squared));
508}
509
// Geometric radius bound (declarator line missing from this listing; call
// sites name it compute_geometric_radius_bound()).
// Returns sqrt(max(R0^2, R0^2 + 2*(max(psi) - min(psi))
//                        + 2*epsilon*log(epsilon / (nu_min * tau * |J|))))
// where R0 is the covering radius, nu_min the minimum target density,
// tau = `tolerance` and |J| the absolute value of the member
// `global_functional`, floored at 1e-10.
// Throws std::runtime_error when the measures are invalid or `potentials`
// does not have one entry per target point.
//
// NOTE(review): no guard is visible for a non-positive nu_min, tau or
// epsilon; the log argument would then be non-finite or negative.
510template <int dim, int spacedim>
512 const Vector<double>& potentials,
513 const double epsilon,
514 const double tolerance) const
515{
516 if (!validate_measures() || potentials.size() != target_measure.points.size()) {
517 throw std::runtime_error("Invalid configuration for computing geometric radius bound");
518 }
519
520 // Use cached covering radius value instead of computing it again
521 // (if it's not already computed, use the method)
// A cached value of exactly 0.0 is treated as "not computed yet".
522 double r0 = covering_radius > 0.0 ? covering_radius : compute_covering_radius();
523
524 // Compute potential range Γ(ψ) = M - m
525 double min_potential = *std::min_element(potentials.begin(), potentials.end());
526 double max_potential = *std::max_element(potentials.begin(), potentials.end());
527 double potential_range = max_potential - min_potential;
528
529 // Use cached minimum target density
530 double min_density = min_target_density > 0.0 ?
531 min_target_density :
532 *std::min_element(target_measure.density.begin(), target_measure.density.end());
533
534 // Use current functional value or safety value if close to zero
535 double functional_value = std::abs(global_functional);
536 const double safety_value = 1e-10;
537 if (functional_value < safety_value) {
538 functional_value = safety_value;
539 }
540
541 // Calculate the radius bound R_geom according to the formula:
542 // R_geom^2 ≥ R_0^2 + 2Γ(ψ) + 2ε ln(ε/(ν_min * τ * |J_ε(ψ)|))
543 double log_term = std::log(epsilon / (min_density * tolerance * functional_value));
544 double radius_squared = r0 * r0 + 2.0 * potential_range + 2.0 * epsilon * log_term;
545
546 // Ensure radius is not less than covering radius
547 if (radius_squared < r0 * r0) {
548 radius_squared = r0 * r0;
549 }
550
551 return std::sqrt(radius_squared);
552}
553
// Returns the indices of all target points whose distance to `query_point`,
// measured with `distance_function`, is at most `current_distance_threshold`
// (declarator line missing from this listing; local_assemble() calls it as
// find_nearest_target_points()).
// The result may be empty. Its order is whatever order the r-tree query
// yields.
//
// NOTE(review): the only predicate is a value-level `bgi::satisfies`, so the
// r-tree presumably cannot prune by bounding box and every stored point is
// tested - confirm against the Boost.Geometry documentation.
554template <int dim, int spacedim>
556 const Point<spacedim>& query_point) const
557{
558 namespace bgi = boost::geometry::index;
559 std::vector<std::size_t> indices;
560
561 // Finds indices of target points in `target_measure.rtree` that are within the current distance threshold from the query point through filter provided by `bgi`.
562 for (const auto& indexed_point : target_measure.rtree |
563 bgi::adaptors::queried(bgi::satisfies([&](const IndexedPoint& p) {
564 return distance_function(p.first, query_point) <= current_distance_threshold;
565 })))
566 {
567 indices.push_back(indexed_point.second);
568 }
569
570 return indices;
571}
572
// Returns the step count reported by the most recent solver control, or 0
// when no solver control exists yet (declarator line missing from this
// listing).
573template <int dim, int spacedim>
575{
576 return solver_control ? solver_control->last_step() : 0;
577}
578
// Returns true only when a solver control exists and its last check reported
// SolverControl::success; false otherwise, including before any solve
// (declarator line missing from this listing).
579template <int dim, int spacedim>
581{
582 return solver_control && solver_control->last_check() == SolverControl::success;
583}
584
// Computes the weighted barycenters for the given potentials and writes them
// to `barycenters_out` (declarator and `params` parameter lines are missing
// from this listing).
//
//   "euclidean": one call to compute_weighted_barycenters_euclidean().
//   "spherical": fixed-point loop of at most `params.max_iterations` steps.
//                Each step computes descent directions and either stops when
//                the l2 norm of `barycenters_gradients` drops below
//                `params.tolerance`, or moves every barycenter along its
//                direction with the exponential map.
// Any other `distance_name` does no work and still prints the completion
// message. Throws std::runtime_error when the measures are invalid.
//
// NOTE(review):
//  - The reported iteration count is `n_iter+1`. That is 1 for the Euclidean
//    path and `max_iterations + 1` when the spherical loop runs out without
//    reaching the tolerance; that case is also reported as "completed".
//  - `use_componentwise` is computed but never read.
//  - Nothing in the visible body throws SolverControl::NoConvergence
//    directly, so the catch block only matters if a callee throws it.
//  - `current_potential` is reset only on the non-throwing path.
//  - `barycenters_grads` is not declared in this body; presumably a class
//    member - confirm.
585template <int dim, int spacedim>
587 const Vector<double>& potentials,
588 std::vector<Point<spacedim>>& barycenters_out,
590{
591 if (!validate_measures()) {
592 throw std::runtime_error("Invalid measures configuration");
593 }
594
595 // Barycenter evaluation parameters
596 double solver_tolerance = params.tolerance;
597 unsigned int max_iterations = params.max_iterations;
598
599 bool use_componentwise = (params.solver_control_type == "componentwise");
600
601 try
602 {
603 unsigned int n_iter = 0;
604
605 Timer timer;
606 timer.start();
607 current_potential = &potentials;
608
609 if (distance_name == "euclidean") {
610 compute_weighted_barycenters_euclidean(
611 *current_potential, barycenters_out);
612
613 } else if (distance_name == "spherical")
614 {
615 for (n_iter = 0; n_iter < max_iterations; ++n_iter)
616 {
617 // evaluate barycenters_gradients
618 // Initialize barycenters_grads with the correct size
// Re-created on every iteration: one zero-initialized vector of length
// `spacedim` per target point.
619 barycenters_grads = std::vector<Vector<double>>(target_measure.points.size(), Vector<double>(spacedim));
620
621 // Compute weighted barycenters using non-Euclidean distance
// `barycenters_out` is read as the current iterate by this call, so it must
// already hold one point per target when the loop starts.
622 compute_weighted_barycenters_non_euclidean(
623 *current_potential, barycenters_grads, barycenters_out);
624
625 // evaluated inside `compute_weighted_barycenters_non_euclidean`
626 double l2_norm = barycenters_gradients.l2_norm();
627
628 if (l2_norm < solver_tolerance) {
629 pcout << "Iteration " << CYAN << n_iter + 1 << RESET
630 << " - L-2 gradient norm: " << Color::green << l2_norm << " < " << solver_tolerance << RESET << std::endl;
631 break;
632 } else {
633
634 pcout << "Iteration " << CYAN << n_iter + 1 << RESET
635 << " - L-2 gradient norm: " << Color::yellow << l2_norm << " > " << solver_tolerance << RESET << std::endl;
636
// Move each barycenter along its (already negated) gradient direction using
// the metric's exponential map; no explicit step size is applied here.
637 for (unsigned int i=0; i<target_measure.points.size();++i)
638 {
639 barycenters_out[i] = distance_function_exponential_map(barycenters_out[i], barycenters_grads[i]);
640 }
641 }
642 }
643 }
644
645 timer.stop();
646
647 pcout << Color::green << Color::bold << "Optimization completed:" << std::endl
648 << " Time taken: " << timer.wall_time() << " seconds" << std::endl
649 << " Iterations: " << n_iter+1 << std::endl
650 << " Distance type: " << distance_name << std::endl << Color::reset;
651
652 } catch (SolverControl::NoConvergence& exc) {
653 pcout << "Warning: Barycenters evaluation did not converge" << std::endl
654 << " Iterations: " << exc.last_step << std::endl
655 << " Residual: " << exc.last_residual << std::endl;
656 throw;
657 }
658
659 // Reset solver state
660 current_potential = nullptr;
661}
662
// Assembles, for every target point, the weighted integral of
// distance_function_gradient(barycenter_i, x) over the locally owned source
// cells, sums it over MPI ranks into the member `barycenters_gradients`
// (flattened, `spacedim` entries per target), and writes the negated result
// per target into `barycenters_gradients_out` (declarator line missing from
// this listing; the caller names it
// compute_weighted_barycenters_non_euclidean()).
//
// `barycenters_out` is only read here: it supplies the current barycenter
// positions at which the distance gradient is evaluated.
// Any std::exception raised during assembly is logged and rethrown.
//
// NOTE(review):
//  - resize() below does not size the individual Vector<double> elements;
//    the visible caller passes vectors already sized to `spacedim`. A caller
//    passing an empty container would index into zero-length vectors.
//  - The error message says "functional evaluation" although this is the
//    barycenter routine; looks copied from evaluate_functional.
663template <int dim, int spacedim>
665 const Vector<double>& potentials,
666 std::vector<Vector<double>>& barycenters_gradients_out,
667 std::vector<Point<spacedim>>& barycenters_out
668)
669{
670 // Store current potentials for use in local assembly
671 current_potential = &potentials;
672 current_epsilon = current_params.epsilon;
673
// reinit() also zeroes the member, so every call starts from a clean sum.
674 barycenters_gradients.reinit(spacedim*target_measure.points.size());
675 Vector<double> local_process_barycenters(spacedim*target_measure.points.size());
676
677 // Update distance threshold for target point search
678 compute_distance_threshold();
679 if (current_params.verbose_output) {
680 pcout << "Using distance threshold: " << current_distance_threshold
681 << " (Effective: " << effective_distance_threshold << ")" << std::endl;
682 }
683
684 try {
685 // Determine if we're using simplex elements
686 bool use_simplex = (dynamic_cast<const FE_SimplexP<dim>*>(&*source_measure.fe) != nullptr);
687
688 // Create appropriate quadrature rule
689 std::unique_ptr<Quadrature<dim>> quadrature;
690 if (use_simplex) {
691 quadrature = std::make_unique<QGaussSimplex<dim>>(source_measure.quadrature_order);
692 } else {
693 quadrature = std::make_unique<QGauss<dim>>(source_measure.quadrature_order);
694 }
695
696 // Create scratch and copy data
697 ScratchData scratch_data(*source_measure.fe,
698 *source_measure.mapping,
699 *quadrature);
700 CopyData copy_data(target_measure.points.size());
701
702 // Create filtered iterators for locally owned cells
703 FilteredIterator<typename DoFHandler<dim, spacedim>::active_cell_iterator>
704 begin_filtered(IteratorFilters::LocallyOwnedCell(),
705 source_measure.dof_handler->begin_active()),
706 end_filtered(IteratorFilters::LocallyOwnedCell(),
707 source_measure.dof_handler->end());
708
709 // Function call
// Per-quadrature-point callback handed to local_assemble(): for each nearby
// target with a positive weight it accumulates
// scale * exp_term * grad d(barycenter_i, x) into the flattened buffer.
710 auto function_call = [this, &barycenters_out](
711 CopyData& copy,
712 const Point<spacedim> &x,
713 const std::vector<std::size_t> &cell_target_indices,
714 const std::vector<double> &exp_terms,
715 const std::vector<double> &target_densities,
716 const double &density_value,
717 const double &JxW,
718 const double &total_sum_exp,
719 const double &max_exponent,
720 const double &current_epsilon)
721 {
722 const double scale = density_value * JxW / total_sum_exp;
723
724 #pragma omp simd
725 for (size_t i = 0; i < exp_terms.size(); ++i) {
726 if (exp_terms[i] > 0.0) {
727 auto v = distance_function_gradient(barycenters_out[cell_target_indices[i]], x);
728 for (unsigned int d = 0; d < spacedim; ++d)
729 copy.barycenters_values[spacedim*cell_target_indices[i] + d] += scale * (exp_terms[i]) * v[d];
730 }
731 }
732 };
733
734 // Parallel assembly using WorkStream
735 WorkStream::run(
736 begin_filtered,
737 end_filtered,
738 [this, &function_call](const typename DoFHandler<dim, spacedim>::active_cell_iterator& cell,
739 ScratchData& scratch,
740 CopyData& copy) {
741 this->local_assemble(
742 cell, scratch, copy, function_call);
743 },
744 [this, &local_process_barycenters](const CopyData& copy) {
745 local_process_barycenters += copy.barycenters_values;
746 },
747 scratch_data,
748 copy_data);
749
750 // Synchronize across MPI processes
751 Utilities::MPI::sum(local_process_barycenters, mpi_communicator, barycenters_gradients);
752
753 // Copy result to output barycenters_gradients TODO why do I need this?
754 // Resize output vector and fill with barycenters_gradients data
// The member keeps the raw (un-negated) sum, which the caller uses for its
// norm test; the output parameter receives the negated, per-target view.
755 barycenters_gradients_out.resize(target_measure.points.size());
756 for (unsigned int i = 0; i < target_measure.points.size(); ++i) {
757 for (unsigned int d = 0; d < spacedim; ++d) {
758 // - for gradient descent
759 barycenters_gradients_out[i][d] = -barycenters_gradients[spacedim * i + d];
760 }
761 }
762
763 } catch (const std::exception& e) {
764 pcout << "Error in functional evaluation: " << e.what() << std::endl;
765 throw;
766 }
767}
768
// Per-cell worker used by the WorkStream assemblies (declarator line missing
// from this listing; callers name it local_assemble()).
//
// For one locally owned cell it
//   1. evaluates the source density at the cell's quadrature points,
//   2. resets the accumulators in `copy`,
//   3. collects the target points within the current distance threshold of
//      the cell center,
//   4. for every quadrature point x computes
//      exp_terms[i] = nu_i * exp((psi_i - d(x, y_i)^2 / 2) / epsilon)
//      (shifted by the maximum exponent when the log-sum-exp trick is on),
//   5. hands the point data to `function_call`, which performs the actual
//      accumulation into `copy`.
// Cells that are not locally owned, or that have no target within range,
// return early; in the second case `copy` has already been zeroed.
//
// Notes:
//  - Candidate targets are chosen once per cell from the cell center, not
//    per quadrature point.
//  - Without the log-sum-exp trick `max_exponent` keeps its initial value of
//    -max(double) when passed to the callback.
//  - `function_call` is a std::function taken by value, so each call to this
//    worker copies it.
769template <int dim, int spacedim>
771 const typename DoFHandler<dim, spacedim>::active_cell_iterator& cell,
772 ScratchData& scratch,
773 CopyData& copy,
774 std::function<void(CopyData&,
775 const Point<spacedim>&,
776 const std::vector<std::size_t>&,
777 const std::vector<double>&,
778 const std::vector<double>&,
779 const double&,
780 const double&,
781 const double&,
782 const double&,
783 const double&)> function_call)
784{
785 if (!cell->is_locally_owned())
786 return;
787
788 scratch.fe_values.reinit(cell);
789 const std::vector<Point<spacedim>>& q_points = scratch.fe_values.get_quadrature_points();
790 scratch.fe_values.get_function_values(*source_measure.density, scratch.density_values);
791
// The same CopyData object is handed to this worker for successive cells, so
// every accumulator is cleared before use.
792 copy.barycenters_values = 0;
793 copy.functional_value = 0.0;
794 copy.gradient_values = 0;
795 copy.local_C_sum = 0.0; // Accumulator for C_sum on this cell
796
797 const unsigned int n_q_points = q_points.size();
798 const double epsilon_inv = 1.0 / current_epsilon;
799 const bool use_log_sum_exp = current_params.use_log_sum_exp_trick;
800
801 std::vector<std::size_t> cell_target_indices = find_nearest_target_points(cell->center());
802 if (cell_target_indices.empty()) return;
803
804 const unsigned int n_target_points = cell_target_indices.size();
805
// Gather the per-target data into dense local arrays so the quadrature loop
// below indexes contiguously instead of through the global index list.
806 std::vector<Point<spacedim>> target_positions(n_target_points);
807 std::vector<double> target_densities(n_target_points);
808 std::vector<double> potential_values(n_target_points);
809
810 for (size_t i = 0; i < n_target_points; ++i) {
811 const size_t idx = cell_target_indices[i];
812 target_positions[i] = target_measure.points[idx];
813 target_densities[i] = target_measure.density[idx];
814 potential_values[i] = (*current_potential)[idx];
815 }
816
817 for (unsigned int q = 0; q < n_q_points; ++q) {
818 const Point<spacedim>& x = q_points[q];
819 const double density_value = scratch.density_values[q];
820 const double JxW = scratch.fe_values.JxW(q);
821
822 double total_sum_exp = 0.0;
823 double max_exponent = -std::numeric_limits<double>::max();
824 std::vector<double> exp_terms(n_target_points);
825
826 if (use_log_sum_exp) {
827 // First pass: find maximum exponent
828 #pragma omp simd reduction(max:max_exponent)
829 for (size_t i = 0; i < n_target_points; ++i) {
830 const double local_dist2 = std::pow(distance_function(x, target_positions[i]), 2);
831 const double exponent = (potential_values[i] - 0.5 * local_dist2) * epsilon_inv;
832 max_exponent = std::max(max_exponent, exponent);
833 }
834
835 // Second pass: compute shifted exponentials
// Subtracting the maximum keeps every exponent at or below zero, so exp()
// cannot overflow; the distances are recomputed rather than cached.
836 #pragma omp simd reduction(+:total_sum_exp)
837 for (size_t i = 0; i < n_target_points; ++i) {
838 const double local_dist2 = std::pow(distance_function(x, target_positions[i]), 2);
839 const double shifted_exp = std::exp((potential_values[i] - 0.5 * local_dist2) * epsilon_inv - max_exponent);
840 exp_terms[i] = target_densities[i] * shifted_exp;
841 total_sum_exp += exp_terms[i];
842 }
843 } else {
844 // Original computation method
845 #pragma omp simd reduction(+:total_sum_exp)
846 for (size_t i = 0; i < n_target_points; ++i) {
847 const double local_dist2 = std::pow(distance_function(x, target_positions[i]), 2);
848 exp_terms[i] = target_densities[i] *
849 std::exp((potential_values[i] - 0.5 * local_dist2) * epsilon_inv);
850 total_sum_exp += exp_terms[i];
851 }
852 }
853
// A non-positive sum would make the callback divide by zero or take the
// logarithm of a non-positive number, so the quadrature point is skipped.
854 if (total_sum_exp <= 0.0) continue;
855
856 function_call(
857 copy,
858 x,
859 cell_target_indices,
860 exp_terms,
861 target_densities,
862 density_value,
863 JxW,
864 total_sum_exp,
865 max_exponent,
866 current_epsilon
867 );
868 }
869}
870
871template <int dim, int spacedim>
873 const Vector<double>& potentials,
874 std::vector<Point<spacedim>>& barycenters_out)
875{
876 // Store current potentials for use in local assembly
877 current_potential = &potentials;
878 current_epsilon = current_params.epsilon;
879
880 barycenters.reinit(spacedim*target_measure.points.size());
881 Vector<double> local_process_barycenters(spacedim*target_measure.points.size());
882
883 // Update distance threshold for target point search
884 compute_distance_threshold();
885 if (current_params.verbose_output) {
886 pcout << "Using distance threshold: " << current_distance_threshold
887 << " (Effective: " << effective_distance_threshold << ")" << std::endl;
888 }
889
890 try {
891 // Determine if we're using simplex elements
892 bool use_simplex = (dynamic_cast<const FE_SimplexP<dim>*>(&*source_measure.fe) != nullptr);
893
894 // Create appropriate quadrature rule
895 std::unique_ptr<Quadrature<dim>> quadrature;
896 if (use_simplex) {
897 quadrature = std::make_unique<QGaussSimplex<dim>>(source_measure.quadrature_order);
898 } else {
899 quadrature = std::make_unique<QGauss<dim>>(source_measure.quadrature_order);
900 }
901
902 // Create scratch and copy data
903 ScratchData scratch_data(*source_measure.fe,
904 *source_measure.mapping,
905 *quadrature);
906 CopyData copy_data(target_measure.points.size());
907
908 // Create filtered iterators for locally owned cells
909 FilteredIterator<typename DoFHandler<dim, spacedim>::active_cell_iterator>
910 begin_filtered(IteratorFilters::LocallyOwnedCell(),
911 source_measure.dof_handler->begin_active()),
912 end_filtered(IteratorFilters::LocallyOwnedCell(),
913 source_measure.dof_handler->end());
914
915 // Function call
916 auto function_call = [this](
917 CopyData& copy,
918 const Point<spacedim> &x,
919 const std::vector<std::size_t> &cell_target_indices,
920 const std::vector<double> &exp_terms,
921 const std::vector<double> &target_densities,
922 const double &density_value,
923 const double &JxW,
924 const double &total_sum_exp,
925 const double &max_exponent,
926 const double &current_epsilon)
927 {
928 const double scale = density_value * JxW / total_sum_exp;
929
930 #pragma omp simd
931 for (size_t i = 0; i < exp_terms.size(); ++i) {
932 if (exp_terms[i] > 0.0) {
933 for (unsigned int d = 0; d < spacedim; ++d) {
934 copy.barycenters_values[spacedim*cell_target_indices[i] + d] += scale * (exp_terms[i]) * x[d];
935 }
936 }
937 }
938 };
939
940 // Parallel assembly using WorkStream
941 WorkStream::run(
942 begin_filtered,
943 end_filtered,
944 [this, &function_call](const typename DoFHandler<dim, spacedim>::active_cell_iterator& cell,
945 ScratchData& scratch,
946 CopyData& copy) {
947 this->local_assemble(cell, scratch, copy, function_call);
948 },
949 [this, &local_process_barycenters](const CopyData& copy) {
950 local_process_barycenters += copy.barycenters_values;
951 },
952 scratch_data,
953 copy_data);
954
955 // Synchronize across MPI processes
956 Utilities::MPI::sum(local_process_barycenters, mpi_communicator, barycenters);
957
958 // Copy result to output barycenters TODO why do I need this?
959 // Resize output vector and fill with barycenters data
960 barycenters_out.resize(target_measure.points.size());
961 for (unsigned int i = 0; i < target_measure.points.size(); ++i) {
962 for (unsigned int d = 0; d < spacedim; ++d) {
963 barycenters_out[i][d] = barycenters[spacedim * i + d]/target_measure.density[i];
964 }
965 }
966
967 } catch (const std::exception& e) {
968 pcout << "Error in functional evaluation: " << e.what() << std::endl;
969 throw;
970 }
971}
972
973template <int dim, int spacedim>
975 const DoFHandler<dim, spacedim> &dof_handler,
976 const Mapping<dim, spacedim> &mapping,
977 const Vector<double> &potential,
978 const std::vector<unsigned int> &potential_indices,
979 std::vector<LinearAlgebra::distributed::Vector<double, MemorySpace::Host>> &conditioned_densities)
980{
981 std::cout << "Current epsilon: " << current_epsilon << std::endl;
982 auto locally_owned_dofs = dof_handler.locally_owned_dofs();
983 conditioned_densities.resize(potential_indices.size());
984
985 std::map<types::global_dof_index, Point<spacedim>> sp;
986 DoFTools::map_dofs_to_support_points(
987 mapping, dof_handler, sp);
988
989 for (unsigned int idensity = 0; idensity < conditioned_densities.size(); ++idensity)
990 conditioned_densities[idensity].reinit(locally_owned_dofs, mpi_communicator);
991
992 double epsilon_inv = 1.0 / current_epsilon;
993
994 for (auto idx: locally_owned_dofs)
995 {
996 std::vector<std::size_t> cell_target_indices = find_nearest_target_points(sp[idx]);
997
998 std::vector<double> exp(potential.size(), 0.0);
999 double total_sum_exp = 0;
1000 double max_exponent = -std::numeric_limits<double>::max();
1001
1002 #pragma omp simd reduction(max:max_exponent)
1003 for (unsigned int i = 0; i < cell_target_indices.size(); ++i)
1004 {
1005 const size_t tidx = cell_target_indices[i];
1006 const double exponent = (potential[tidx]-0.5*
1007 std::pow(
1008 distance_function(
1009 sp[idx],
1010 target_measure.points[tidx]), 2))*epsilon_inv;
1011 max_exponent = std::max(max_exponent, exponent);
1012 }
1013
1014 #pragma omp simd reduction(+:total_sum_exp)
1015 for (unsigned int i = 0; i < cell_target_indices.size(); ++i)
1016 {
1017 const size_t tidx = cell_target_indices[i];
1018 exp[tidx] = std::exp((potential[tidx]-0.5*
1019 std::pow(
1020 distance_function(
1021 sp[idx],
1022 target_measure.points[tidx]), 2))*epsilon_inv-max_exponent);
1023 total_sum_exp += target_measure.density[tidx] * exp[tidx];
1024 }
1025
1026 if (total_sum_exp > 0.0)
1027 {
1028 for (unsigned int idensity = 0; idensity < potential_indices.size(); ++idensity)
1029 {
1030 bool index_found = std::find(cell_target_indices.begin(),
1031 cell_target_indices.end(),
1032 potential_indices[idensity]) != cell_target_indices.end();
1033
1034 if (index_found)
1035 conditioned_densities[idensity][idx] = (*source_measure.density)[idx] * (exp[potential_indices[idensity]]/total_sum_exp);
1036
1037 }
1038 }
1039 }
1040
1041 for (unsigned int idensity = 0; idensity < conditioned_densities.size(); ++idensity)
1042 {
1043 conditioned_densities[idensity].compress(VectorOperation::insert);
1044 }
1045}
1046
1047
1048// Explicit instantiations: generate the SotSolver code in this translation unit for the template arguments <2>, <3> and <2, 3>.
1049template class SotSolver<2>;
1050template class SotSolver<3>;
1051template class SotSolver<2, 3>;
#define RESET
#define CYAN
A verbose solver control class that prints the progress of the solver.
Definition SotSolver.h:355
A solver for semi-discrete optimal transport problems.
Definition SotSolver.h:55
unsigned int get_last_iteration_count() const
Returns the number of iterations of the last solve.
Definition SotSolver.cc:574
double compute_geometric_radius_bound(const Vector< double > &potentials, const double epsilon, const double tolerance) const
Computes the geometric radius bound for truncating quadrature rules.
Definition SotSolver.cc:511
bool validate_measures() const
Definition SotSolver.cc:73
void local_assemble(const typename DoFHandler< dim, spacedim >::active_cell_iterator &cell, ScratchData &scratch, CopyData &copy, std::function< void(CopyData &, const Point< spacedim > &, const std::vector< std::size_t > &, const std::vector< double > &, const std::vector< double > &, const double &, const double &, const double &, const double &, const double &)> function_call)
Definition SotSolver.cc:770
std::function< double(const Point< spacedim > &, const Point< spacedim > &)> distance_function
The distance function.
Definition SotSolver.h:324
std::pair< Point< spacedim >, std::size_t > IndexedPoint
Definition SotSolver.h:59
void get_potential_conditioned_density(const DoFHandler< dim, spacedim > &dof_handler, const Mapping< dim, spacedim > &mapping, const Vector< double > &potential, const std::vector< unsigned int > &potential_indices, std::vector< LinearAlgebra::distributed::Vector< double, MemorySpace::Host > > &conditioned_densities)
Computes the conditional density of the source measure given a potential.
Definition SotSolver.cc:974
void set_distance_function(const std::string &distance_name)
Sets the distance function to be used by the solver.
Definition SotSolver.cc:23
void setup_target(const std::vector< Point< spacedim > > &target_points, const Vector< double > &target_density)
Sets up the target measure for the solver.
Definition SotSolver.cc:56
void compute_weighted_barycenters_non_euclidean(const Vector< double > &potentials, std::vector< Vector< double > > &barycenters_gradients_out, std::vector< Point< spacedim > > &barycenters_out)
Definition SotSolver.cc:664
void solve(Vector< double > &potential, const SotParameterManager::SolverParameters &params)
Solves the optimal transport problem.
Definition SotSolver.cc:106
double compute_covering_radius() const
Computes the covering radius of the target measure with respect to the source domain.
Definition SotSolver.cc:433
void compute_weighted_barycenters_euclidean(const Vector< double > &potentials, std::vector< Point< spacedim > > &barycenters_out)
Definition SotSolver.cc:872
std::vector< std::size_t > find_nearest_target_points(const Point< spacedim > &query_point) const
Definition SotSolver.cc:555
SotSolver(const MPI_Comm &comm)
Constructor for the SotSolver.
Definition SotSolver.cc:4
void setup_source(const DoFHandler< dim, spacedim > &dof_handler, const Mapping< dim, spacedim > &mapping, const FiniteElement< dim, spacedim > &fe, const LinearAlgebra::distributed::Vector< double, MemorySpace::Host > &source_density, const unsigned int quadrature_order)
Sets up the source measure for the solver.
Definition SotSolver.cc:45
void evaluate_weighted_barycenters(const Vector< double > &potentials, std::vector< Point< spacedim > > &barycenters_out, const SotParameterManager::SolverParameters &params)
Evaluates the weighted barycenters of the power cells.
Definition SotSolver.cc:586
void compute_distance_threshold() const
Definition SotSolver.cc:381
bool get_convergence_status() const
Returns the convergence status of the last solve.
Definition SotSolver.cc:580
double evaluate_functional(const Vector< double > &potential, Vector< double > &gradient_out)
Evaluates the dual functional and its gradient.
Definition SotSolver.cc:219
void configure(const SotParameterManager::SolverParameters &params)
Configures the solver with the given parameters.
Definition SotSolver.cc:64
double compute_integral_radius_bound(const Vector< double > &potentials, double epsilon, double tolerance, double C_value, double current_functional_val) const
Definition SotSolver.cc:477
const std::string reset
const std::string yellow
const std::string bold
const std::string green
double tolerance
Convergence tolerance.
bool use_log_sum_exp_trick
Enable log-sum-exp trick for numerical stability with small entropy.
unsigned int n_threads
Number of threads (0 = auto)
std::string solver_control_type
Type of solver control to use (l1norm/componentwise)
bool verbose_output
Enable detailed solver output.
unsigned int max_iterations
Maximum number of solver iterations.
double epsilon
Entropy regularization parameter.
A struct to hold copy data for parallel assembly.
Definition SotSolver.h:154
double functional_value
The value of the functional on the current cell.
Definition SotSolver.h:155
Vector< double > barycenters_values
The barycenter values for the current cell.
Definition SotSolver.h:160
Vector< double > gradient_values
The local contribution to the gradient.
Definition SotSolver.h:156
double local_C_sum
The sum of the scale terms for this cell.
Definition SotSolver.h:158
A struct to hold scratch data for parallel assembly.
Definition SotSolver.h:132
FEValues< dim, spacedim > fe_values
FEValues object for the current cell.
Definition SotSolver.h:147
std::vector< double > density_values
The density values at the quadrature points of the current cell.
Definition SotSolver.h:148
A struct to hold all the necessary information about the source measure.
Definition SotSolver.h:66
A struct to hold all the necessary information about the target measure.
Definition SotSolver.h:95