SemiDiscreteOT 1.0
Semi-Discrete Optimal Transport Library
Loading...
Searching...
No Matches
SotSolver.h
Go to the documentation of this file.
1#ifndef SOT_SOLVER_H
2#define SOT_SOLVER_H
3
4#include <boost/geometry.hpp>
5#include <boost/geometry/index/rtree.hpp>
6#include <memory>
7#include <map>
8#include <mutex>
9#include <atomic>
10
11#include <deal.II/base/conditional_ostream.h>
12#include <deal.II/lac/vector.h>
13#include <deal.II/lac/la_parallel_vector.h>
14#include <deal.II/lac/solver_control.h>
15#include <deal.II/lac/lapack_full_matrix.h>
16#include <deal.II/base/point.h>
17#include <deal.II/fe/fe_values.h>
18#include <deal.II/dofs/dof_handler.h>
19#include <deal.II/base/work_stream.h>
20#include <deal.II/base/exceptions.h>
21#include <deal.II/dofs/dof_tools.h>
22#include <deal.II/grid/filtered_iterator.h>
23#include <deal.II/numerics/rtree.h>
24#include <deal.II/base/mpi.h>
25#include <deal.II/base/utilities.h>
26#include <deal.II/base/multithread_info.h>
27#include <deal.II/optimization/solver_bfgs.h>
28#include <deal.II/base/utilities.h>
29#include <deal.II/base/multithread_info.h>
30#include <deal.II/base/timer.h>
31#include <deal.II/fe/fe_values.h>
32#include <deal.II/base/work_stream.h>
33#include <deal.II/dofs/dof_tools.h>
34#include <deal.II/fe/fe_simplex_p.h>
35#include <deal.II/fe/fe_q.h>
36#include <deal.II/fe/mapping_fe.h>
37#include <deal.II/fe/mapping_q1.h>
38#include <deal.II/lac/generic_linear_algebra.h>
39
43
44using namespace dealii;
45
54template <int dim, int spacedim=dim>
55class SotSolver {
56public:
57
58 // Type definitions for RTree
59 using IndexedPoint = std::pair<Point<spacedim>, std::size_t>;
60 using RTreeParams = boost::geometry::index::rstar<8>;
61 using RTree = boost::geometry::index::rtree<IndexedPoint, RTreeParams>;
62
67 bool initialized = false;
68 SmartPointer<const DoFHandler<dim, spacedim>> dof_handler;
69 SmartPointer<const Mapping<dim, spacedim>> mapping;
70 SmartPointer<const FiniteElement<dim, spacedim>> fe;
71 SmartPointer<const LinearAlgebra::distributed::Vector<double, MemorySpace::Host>> density;
72 unsigned int quadrature_order;
73
74 SourceMeasure() = default;
78 SourceMeasure(const DoFHandler<dim, spacedim>& dof_handler_,
79 const Mapping<dim, spacedim>& mapping_,
80 const FiniteElement<dim, spacedim>& fe_,
81 const LinearAlgebra::distributed::Vector<double, MemorySpace::Host>& density_,
82 const unsigned int quadrature_order_)
83 : initialized(true)
84 , dof_handler(&dof_handler_)
85 , mapping(&mapping_)
86 , fe(&fe_)
87 , density(&density_)
88 , quadrature_order(quadrature_order_)
89 {}
90 };
91
96 bool initialized = false;
97 std::vector<Point<spacedim>> points;
100
101 TargetMeasure() = default;
108 , points(points_)
109 , density(density_)
110 {
111 AssertThrow(points.size() == density.size(),
112 ExcDimensionMismatch(points.size(), density.size()));
114 }
115
120 std::vector<IndexedPoint> indexed_points;
121 indexed_points.reserve(points.size());
122 for (std::size_t i = 0; i < points.size(); ++i) {
123 indexed_points.emplace_back(points[i], i);
124 }
125 rtree = RTree(indexed_points.begin(), indexed_points.end());
126 }
127 };
128
/// A struct to hold scratch data for parallel assembly.
///
/// NOTE(review): this listing is extracted from rendered documentation; listing
/// lines 140 and 145 are absent. Per the member index, line 140 is the
/// signature `ScratchData(const ScratchData &other)`; the content of line 145
/// (the tail of that copy constructor) is not visible here.
132 struct ScratchData {
// Builds the FEValues object from the given mapping, finite element and
// quadrature, requesting shape values, quadrature points and JxW values, and
// sizes density_values to one entry per quadrature point.
133 ScratchData(const FiniteElement<dim, spacedim>& fe,
134 const Mapping<dim, spacedim>& mapping,
135 const Quadrature<dim>& quadrature)
136 : fe_values(mapping, fe, quadrature,
137 update_values | update_quadrature_points | update_JxW_values)
138 , density_values(quadrature.size()) {}
139
// Copy constructor (signature missing from this extract): constructs a fresh
// FEValues from the other object's mapping, finite element and quadrature with
// the same update flags as the primary constructor.
141 : fe_values(other.fe_values.get_mapping(),
142 other.fe_values.get_fe(),
143 other.fe_values.get_quadrature(),
144 update_values | update_quadrature_points | update_JxW_values)
146
147 FEValues<dim, spacedim> fe_values; ///< FEValues object for the current cell.
148 std::vector<double> density_values; ///< The density values at the quadrature points of the current cell.
149 };
150
/// A struct to hold copy data for parallel assembly.
154 struct CopyData {
155 double functional_value{0.0}; ///< The value of the functional on the current cell.
156 Vector<double> gradient_values; ///< The local contribution to the gradient.
157 Vector<double> potential_values; ///< The potential values at the target points.
158 double local_C_sum = 0.0; ///< The sum of the scale terms for this cell.
159
160 Vector<double> barycenters_values; ///< The barycenter values for the current cell.
161
/// Sizes gradient_values and potential_values to n_target_points entries and
/// barycenters_values to spacedim * n_target_points entries, then explicitly
/// zeroes the gradient and barycenter accumulators.
/// @param n_target_points Number of target points used to size the vectors.
162 CopyData(const unsigned int n_target_points)
163 : gradient_values(n_target_points),
164 potential_values(n_target_points),
165 barycenters_values(spacedim*n_target_points)
166 {
167 gradient_values = 0; // Initialize local gradient to zero
168 barycenters_values = 0; // Initialize local barycenters to zero
169 }
170 };
171
176 SotSolver(const MPI_Comm& comm);
177
186 void setup_source(const DoFHandler<dim, spacedim>& dof_handler,
187 const Mapping<dim, spacedim>& mapping,
188 const FiniteElement<dim, spacedim>& fe,
189 const LinearAlgebra::distributed::Vector<double, MemorySpace::Host>& source_density,
190 const unsigned int quadrature_order);
191
197 void setup_target(const std::vector<Point<spacedim>>& target_points,
198 const Vector<double>& target_density);
199
205
211 void solve(Vector<double>& potential,
213
221 void solve(Vector<double>& potential,
222 const SourceMeasure& source,
223 const TargetMeasure& target,
225
233 const Vector<double>& potentials,
234 std::vector<Point<spacedim>>& barycenters_out,
236
244 unsigned int get_last_iteration_count() const;
248 bool get_convergence_status() const;
256 double get_C_global() const { return C_global; }
257
262 void set_distance_threshold(double threshold) { current_distance_threshold = threshold; }
263
275 double compute_covering_radius() const;
276
297 const Vector<double>& potentials,
298 const double epsilon,
299 const double tolerance) const;
300
305 void set_distance_function(const std::string &distance_name);
306
316 const DoFHandler<dim, spacedim> &dof_handler,
317 const Mapping<dim, spacedim> &mapping,
318 const Vector<double> &potential,
319 const std::vector<unsigned int> &potential_indices,
320 std::vector<LinearAlgebra::distributed::Vector<double, MemorySpace::Host>> &conditioned_densities);
321
322 // Distance function
323 std::string distance_name;
324 std::function<double(const Point<spacedim>&, const Point<spacedim>&)> distance_function;
325 std::function<Vector<double>(const Point<spacedim>&, const Point<spacedim>&)> distance_function_gradient;
326 std::function<Point<spacedim>(const Point<spacedim>&, const Vector<double>&)> distance_function_exponential_map;
327
334 double evaluate_functional(const Vector<double>& potential,
335 Vector<double>& gradient_out);
336
342 void compute_hessian(const Vector<double>& potential,
343 LAPACKFullMatrix<double>& hessian_out);
344
345 // Source and target measures
348
349private:
350
/// A verbose solver control class that prints the progress of the solver.
///
/// Derives from SolverControl and overrides check() to print a coloured
/// per-iteration report and to forward its own residual measure (rather than
/// the value handed in by the solver) to SolverControl::check().
///
/// NOTE(review): this listing is extracted from rendered documentation; listing
/// lines 363, 386, 447 and 451 are absent. The members use_componentwise_check
/// and user_tolerance_for_componentwise are used below but their declarations
/// are not visible in this extract.
354 class VerboseSolverControl : public SolverControl
355 {
356 public:
/// @param n                 Forwarded to SolverControl as its first argument.
/// @param tol               Forwarded to SolverControl as its second argument.
/// @param use_componentwise Stored in use_componentwise_check.
/// @param pcout_            Output stream used for the progress messages; held by reference.
357 VerboseSolverControl(unsigned int n, double tol, bool use_componentwise, ConditionalOStream& pcout_)
358 : SolverControl(n, tol)
359 , pcout(pcout_)
360 , use_componentwise_check(use_componentwise)
361 , gradient(nullptr)
362 , target_measure(nullptr)
364 {}
365
/// Registers the gradient vector inspected by check(). Only the address is
/// stored (no copy), so the vector has to stay alive while check() is in use.
366 void set_gradient(const Vector<double>& grad) {
367 gradient = &grad;
368 }
369
/// Registers the target density and the user tolerance used by the
/// component-wise check. Throws (AssertThrow) if use_componentwise_check is
/// false. Only the address of target_density is stored (no copy).
370 void set_target_measure(const Vector<double>& target_density, double user_tolerance) {
371 AssertThrow(use_componentwise_check, ExcMessage("Target measure only needed for component-wise check"));
372 target_measure = &target_density;
373 user_tolerance_for_componentwise = user_tolerance;
374 }
375
376
/// Prints a progress line for the current iteration and delegates the
/// convergence decision to SolverControl::check() using a locally computed
/// residual measure instead of the incoming value.
/// @param step  Iteration index; step 0 initialises the reference L1 norm.
/// @param value Value supplied by the calling solver; only printed here.
/// @return The State returned by SolverControl::check(step, check_value).
/// Throws (AssertThrow) if no gradient has been registered via set_gradient().
377 virtual State check(unsigned int step, double value) override
378 {
379 AssertThrow(gradient != nullptr,
380 ExcMessage("Gradient vector not set in VerboseSolverControl"));
381
382 double check_value = 0.0;
383 std::string check_description;
384 std::string color;
385
// NOTE(review): listing line 386, which opens this branch, is missing from the
// extract; given the `else // Use L1-norm check` below and the assertions
// inside, it is presumably `if (use_componentwise_check)` - confirm against
// the real header.
387 {
388 AssertThrow(target_measure != nullptr,
389 ExcMessage("Target measure not set for component-wise check"));
390 AssertThrow(gradient->size() == target_measure->size(),
391 ExcDimensionMismatch(gradient->size(), target_measure->size()));
392
// Largest value over all components of |g_j| - T_j * user tolerance.
393 double max_scaled_residual = -std::numeric_limits<double>::infinity();
394 for (unsigned int j = 0; j < gradient->size(); ++j) {
395 double scaled_residual = std::abs((*gradient)[j]) - (*target_measure)[j] * user_tolerance_for_componentwise;
396 max_scaled_residual = std::max(max_scaled_residual, scaled_residual);
397 }
398 check_value = max_scaled_residual;
399
// NOTE(review): if tolerance() is non-negative, any check_value < 0 already
// satisfies the first condition, so the yellow branch looks unreachable -
// confirm the intended thresholds.
400 if (check_value < tolerance()) {
401 color = Color::green;
402 } else if (check_value < 0) {
403 color = Color::yellow;
404 } else {
405 color = Color::red;
406 }
407
408 check_description = "Max Scaled Residual (max |g_i| - T_i*tol): ";
409 pcout << "Iteration " << CYAN << step << RESET
410 << " - L-2 gradient norm: " << color << value << RESET // value is L2 norm from BFGS
411 << " - " << check_description << color << check_value << RESET << std::endl;
412
413 }
414 else // Use L1-norm check
415 {
416 check_value = gradient->l1_norm();
417
// At step 0 (or when the stored reference norm is zero) the relative residual
// is reported as the absolute L1 norm; otherwise it is scaled by the step-0 norm.
418 double rel_residual = (step == 0 || initial_l1_norm == 0.0) ?
419 check_value : check_value / initial_l1_norm;
420
// The reference norm is recorded after rel_residual has been computed.
421 if (step == 0)
422 initial_l1_norm = check_value;
423
// Green once below the absolute tolerance, yellow once the relative residual
// has dropped below 0.5, red otherwise.
424 if (check_value < tolerance()) {
425 color = Color::green;
426 } else if (rel_residual < 0.5) {
427 color = Color::yellow;
428 } else {
429 color = Color::red;
430 }
431
432 check_description = "L-1 gradient norm: ";
433 pcout << "Iteration " << CYAN << step << RESET
434 << " - L-2 gradient norm: " << color << value << RESET
435 << " - " << check_description << color << check_value << RESET
436 << " - Relative L-1 residual: " << color << rel_residual << RESET << std::endl;
437 }
438
// Remember the measure that was used and let the base class decide on
// convergence from it (not from the incoming `value`).
439 last_check_value = check_value;
440 return SolverControl::check(step, check_value);
441 }
442
/// Returns the residual measure computed by the most recent check() call
/// (0.0 if check() has not been called yet).
443 double get_last_check_value() const { return last_check_value; }
444
445 private:
446 ConditionalOStream& pcout; // Stream for progress output (held by reference).
448 double initial_l1_norm = 1.0; // Reference L1 norm; overwritten at step 0 of the L1-norm check.
449 const Vector<double>* gradient; // Set by set_gradient(); nullptr until then.
450 const Vector<double>* target_measure; // Set by set_target_measure(); nullptr until then.
452 double last_check_value = 0.0; // Measure passed to SolverControl::check() in the last check().
453 };
454
455 // Local assembly methods
456 void local_assemble(
457 const typename DoFHandler<dim, spacedim>::active_cell_iterator& cell,
458 ScratchData& scratch,
459 CopyData& copy,
460 std::function<void(CopyData&,
461 const Point<spacedim>&,
462 const std::vector<std::size_t>&,
463 const std::vector<double>&,
464 const std::vector<double>&,
465 const double&,
466 const double&,
467 const double&,
468 const double&,
469 const double&)> function_call);
470
471 // Distance threshold and caching methods
472 void compute_distance_threshold() const;
473 std::vector<std::size_t> find_nearest_target_points(const Point<spacedim>& query_point) const;
475 const Vector<double>& potentials,
476 double epsilon,
477 double tolerance,
478 double C_value,
479 double current_functional_val) const;
480
481 // Validation methods
482 bool validate_measures() const;
483
484 // Barycenters computation methods
486 const Vector<double>& potentials,
487 std::vector<Vector<double>>& barycenters_gradients_out,
488 std::vector<Point<spacedim>>& barycenters_out
489 );
491 const Vector<double>& potentials,
492 std::vector<Point<spacedim>>& barycenters_out);
493
494 // MPI and parallel related members
496 const unsigned int n_mpi_processes;
497 const unsigned int this_mpi_process;
498 ConditionalOStream pcout;
499
500 // Solver state
501 std::unique_ptr<SolverControl> solver_control;
504 const Vector<double>* current_potential;
506 mutable double global_functional;
507 Vector<double> gradient;
510 double C_global = 0.0; // Sum of all scale terms
511
512 // Current solver parameters
514
515 // weighted truncated barycenters evaluation
516
517 // Barycenters evaluation data
518 Vector<double> barycenters;
519 Vector<double> barycenters_gradients;
520
521 // Barycenters points and gradients
522 std::vector<Point<spacedim>> barycenters_points;
523 std::vector<Vector<double>> barycenters_grads;
524};
525
526#endif // SOT_SOLVER_H
A collection of distance functions, their gradients, and exponential maps.
#define RESET
#define CYAN
A verbose solver control class that prints the progress of the solver.
Definition SotSolver.h:355
void set_target_measure(const Vector< double > &target_density, double user_tolerance)
Definition SotSolver.h:370
const Vector< double > * target_measure
Definition SotSolver.h:450
ConditionalOStream & pcout
Definition SotSolver.h:446
VerboseSolverControl(unsigned int n, double tol, bool use_componentwise, ConditionalOStream &pcout_)
Definition SotSolver.h:357
void set_gradient(const Vector< double > &grad)
Definition SotSolver.h:366
const Vector< double > * gradient
Definition SotSolver.h:449
double get_last_check_value() const
Definition SotSolver.h:443
virtual State check(unsigned int step, double value) override
Definition SotSolver.h:377
A solver for semi-discrete optimal transport problems.
Definition SotSolver.h:55
Vector< double > gradient
Definition SotSolver.h:507
Vector< double > barycenters
Definition SotSolver.h:518
unsigned int get_last_iteration_count() const
Returns the number of iterations of the last solve.
Definition SotSolver.cc:574
const unsigned int this_mpi_process
Definition SotSolver.h:497
double get_last_functional_value() const
Returns the value of the functional at the last iteration.
Definition SotSolver.h:240
std::function< Vector< double >(const Point< spacedim > &, const Point< spacedim > &)> distance_function_gradient
The gradient of the distance function.
Definition SotSolver.h:325
Vector< double > barycenters_gradients
Definition SotSolver.h:519
boost::geometry::index::rstar< 8 > RTreeParams
Definition SotSolver.h:60
ConditionalOStream pcout
Definition SotSolver.h:498
std::vector< Point< spacedim > > barycenters_points
Definition SotSolver.h:522
SourceMeasure source_measure
The source measure.
Definition SotSolver.h:346
double compute_geometric_radius_bound(const Vector< double > &potentials, const double epsilon, const double tolerance) const
Computes the geometric radius bound for truncating quadrature rules.
Definition SotSolver.cc:511
void set_distance_threshold(double threshold)
Sets the distance threshold for the solver.
Definition SotSolver.h:262
double covering_radius
Definition SotSolver.h:508
bool validate_measures() const
Definition SotSolver.cc:73
const Vector< double > * current_potential
Definition SotSolver.h:504
void local_assemble(const typename DoFHandler< dim, spacedim >::active_cell_iterator &cell, ScratchData &scratch, CopyData &copy, std::function< void(CopyData &, const Point< spacedim > &, const std::vector< std::size_t > &, const std::vector< double > &, const std::vector< double > &, const double &, const double &, const double &, const double &, const double &)> function_call)
Definition SotSolver.cc:770
std::function< double(const Point< spacedim > &, const Point< spacedim > &)> distance_function
The distance function.
Definition SotSolver.h:324
void compute_hessian(const Vector< double > &potential, LAPACKFullMatrix< double > &hessian_out)
Computes the Hessian matrix of the dual functional.
std::pair< Point< spacedim >, std::size_t > IndexedPoint
Definition SotSolver.h:59
std::unique_ptr< SolverControl > solver_control
Definition SotSolver.h:501
double get_last_distance_threshold() const
Returns the distance threshold used in the last solve.
Definition SotSolver.h:252
void get_potential_conditioned_density(const DoFHandler< dim, spacedim > &dof_handler, const Mapping< dim, spacedim > &mapping, const Vector< double > &potential, const std::vector< unsigned int > &potential_indices, std::vector< LinearAlgebra::distributed::Vector< double, MemorySpace::Host > > &conditioned_densities)
Computes the conditional density of the source measure given a potential.
Definition SotSolver.cc:974
void set_distance_function(const std::string &distance_name)
Sets the distance function to be used by the solver.
Definition SotSolver.cc:23
double current_distance_threshold
Definition SotSolver.h:502
double current_epsilon
Definition SotSolver.h:505
double global_functional
Definition SotSolver.h:506
void setup_target(const std::vector< Point< spacedim > > &target_points, const Vector< double > &target_density)
Sets up the target measure for the solver.
Definition SotSolver.cc:56
void compute_weighted_barycenters_non_euclidean(const Vector< double > &potentials, std::vector< Vector< double > > &barycenters_gradients_out, std::vector< Point< spacedim > > &barycenters_out)
Definition SotSolver.cc:664
void solve(Vector< double > &potential, const SotParameterManager::SolverParameters &params)
Solves the optimal transport problem.
Definition SotSolver.cc:106
double compute_covering_radius() const
Computes the covering radius of the target measure with respect to the source domain.
Definition SotSolver.cc:433
boost::geometry::index::rtree< IndexedPoint, RTreeParams > RTree
Definition SotSolver.h:61
void compute_weighted_barycenters_euclidean(const Vector< double > &potentials, std::vector< Point< spacedim > > &barycenters_out)
Definition SotSolver.cc:872
std::function< Point< spacedim >(const Point< spacedim > &, const Vector< double > &)> distance_function_exponential_map
The exponential map of the distance function.
Definition SotSolver.h:326
const unsigned int n_mpi_processes
Definition SotSolver.h:496
double effective_distance_threshold
Definition SotSolver.h:503
std::vector< std::size_t > find_nearest_target_points(const Point< spacedim > &query_point) const
Definition SotSolver.cc:555
double min_target_density
Definition SotSolver.h:509
std::string distance_name
The name of the distance function.
Definition SotSolver.h:323
void setup_source(const DoFHandler< dim, spacedim > &dof_handler, const Mapping< dim, spacedim > &mapping, const FiniteElement< dim, spacedim > &fe, const LinearAlgebra::distributed::Vector< double, MemorySpace::Host > &source_density, const unsigned int quadrature_order)
Sets up the source measure for the solver.
Definition SotSolver.cc:45
void evaluate_weighted_barycenters(const Vector< double > &potentials, std::vector< Point< spacedim > > &barycenters_out, const SotParameterManager::SolverParameters &params)
Evaluates the weighted barycenters of the power cells.
Definition SotSolver.cc:586
SotParameterManager::SolverParameters current_params
Definition SotSolver.h:513
double get_C_global() const
Returns the global C value.
Definition SotSolver.h:256
void compute_distance_threshold() const
Definition SotSolver.cc:381
std::vector< Vector< double > > barycenters_grads
Definition SotSolver.h:523
TargetMeasure target_measure
The target measure.
Definition SotSolver.h:347
bool get_convergence_status() const
Returns the convergence status of the last solve.
Definition SotSolver.cc:580
MPI_Comm mpi_communicator
Definition SotSolver.h:495
double C_global
Definition SotSolver.h:510
double evaluate_functional(const Vector< double > &potential, Vector< double > &gradient_out)
Evaluates the dual functional and its gradient.
Definition SotSolver.cc:219
void configure(const SotParameterManager::SolverParameters &params)
Configures the solver with the given parameters.
Definition SotSolver.cc:64
double compute_integral_radius_bound(const Vector< double > &potentials, double epsilon, double tolerance, double C_value, double current_functional_val) const
Definition SotSolver.cc:477
const std::string yellow
const std::string red
const std::string green
A struct to hold copy data for parallel assembly.
Definition SotSolver.h:154
double functional_value
The value of the functional on the current cell.
Definition SotSolver.h:155
Vector< double > barycenters_values
The barycenter values for the current cell.
Definition SotSolver.h:160
CopyData(const unsigned int n_target_points)
Definition SotSolver.h:162
Vector< double > gradient_values
The local contribution to the gradient.
Definition SotSolver.h:156
double local_C_sum
The sum of the scale terms for this cell.
Definition SotSolver.h:158
Vector< double > potential_values
The potential values at the target points.
Definition SotSolver.h:157
A struct to hold scratch data for parallel assembly.
Definition SotSolver.h:132
ScratchData(const ScratchData &other)
Definition SotSolver.h:140
FEValues< dim, spacedim > fe_values
FEValues object for the current cell.
Definition SotSolver.h:147
ScratchData(const FiniteElement< dim, spacedim > &fe, const Mapping< dim, spacedim > &mapping, const Quadrature< dim > &quadrature)
Definition SotSolver.h:133
std::vector< double > density_values
The density values at the quadrature points of the current cell.
Definition SotSolver.h:148
A struct to hold all the necessary information about the source measure.
Definition SotSolver.h:66
SmartPointer< const Mapping< dim, spacedim > > mapping
Pointer to the mapping for the source mesh.
Definition SotSolver.h:69
SmartPointer< const DoFHandler< dim, spacedim > > dof_handler
Pointer to the DoF handler for the source mesh.
Definition SotSolver.h:68
unsigned int quadrature_order
The order of the quadrature rule to use for integration.
Definition SotSolver.h:72
SourceMeasure(const DoFHandler< dim, spacedim > &dof_handler_, const Mapping< dim, spacedim > &mapping_, const FiniteElement< dim, spacedim > &fe_, const LinearAlgebra::distributed::Vector< double, MemorySpace::Host > &density_, const unsigned int quadrature_order_)
Constructor for the SourceMeasure struct.
Definition SotSolver.h:78
SmartPointer< const LinearAlgebra::distributed::Vector< double, MemorySpace::Host > > density
Pointer to the density vector of the source measure.
Definition SotSolver.h:71
SmartPointer< const FiniteElement< dim, spacedim > > fe
Pointer to the finite element for the source mesh.
Definition SotSolver.h:70
bool initialized
Flag to check if source measure is set up.
Definition SotSolver.h:67
A struct to hold all the necessary information about the target measure.
Definition SotSolver.h:95
bool initialized
Flag to check if target measure is set up.
Definition SotSolver.h:96
std::vector< Point< spacedim > > points
The points of the discrete target measure.
Definition SotSolver.h:97
void initialize_rtree()
Initializes the R-tree with the target points.
Definition SotSolver.h:119
Vector< double > density
The weights of the discrete target measure.
Definition SotSolver.h:98
TargetMeasure(const std::vector< Point< spacedim > > &points_, const Vector< double > &density_)
Constructor for the TargetMeasure struct.
Definition SotSolver.h:105
RTree rtree
An R-tree for fast spatial queries on the target points.
Definition SotSolver.h:99