SemiDiscreteOT 1.0
Semi-Discrete Optimal Transport Library
Loading...
Searching...
No Matches
Classes | Public Types | Public Member Functions | Public Attributes | Private Member Functions | Private Attributes | List of all members
SotSolver< dim, spacedim > Class Template Reference

A solver for semi-discrete optimal transport problems. More...

#include <SotSolver.h>

Collaboration diagram for SotSolver< dim, spacedim >:
Collaboration graph
[legend]

Classes

struct  CopyData
 A struct to hold copy data for parallel assembly. More...
 
struct  ScratchData
 A struct to hold scratch data for parallel assembly. More...
 
struct  SourceMeasure
 A struct to hold all the necessary information about the source measure. More...
 
struct  TargetMeasure
 A struct to hold all the necessary information about the target measure. More...
 
class  VerboseSolverControl
 A verbose solver control class that prints the progress of the solver. More...
 

Public Types

using IndexedPoint = std::pair< Point< spacedim >, std::size_t >
 
using RTreeParams = boost::geometry::index::rstar< 8 >
 
using RTree = boost::geometry::index::rtree< IndexedPoint, RTreeParams >
 

Public Member Functions

 SotSolver (const MPI_Comm &comm)
 Constructor for the SotSolver.
 
void setup_source (const DoFHandler< dim, spacedim > &dof_handler, const Mapping< dim, spacedim > &mapping, const FiniteElement< dim, spacedim > &fe, const LinearAlgebra::distributed::Vector< double, MemorySpace::Host > &source_density, const unsigned int quadrature_order)
 Sets up the source measure for the solver.
 
void setup_target (const std::vector< Point< spacedim > > &target_points, const Vector< double > &target_density)
 Sets up the target measure for the solver.
 
void configure (const SotParameterManager::SolverParameters &params)
 Configures the solver with the given parameters.
 
void solve (Vector< double > &potential, const SotParameterManager::SolverParameters &params)
 Solves the optimal transport problem.
 
void solve (Vector< double > &potential, const SourceMeasure &source, const TargetMeasure &target, const SotParameterManager::SolverParameters &params)
 Alternative solve interface for use when the measures have not been set up beforehand.
 
void evaluate_weighted_barycenters (const Vector< double > &potentials, std::vector< Point< spacedim > > &barycenters_out, const SotParameterManager::SolverParameters &params)
 Evaluates the weighted barycenters of the power cells.
 
double get_last_functional_value () const
 Returns the value of the functional at the last iteration.
 
unsigned int get_last_iteration_count () const
 Returns the number of iterations of the last solve.
 
bool get_convergence_status () const
 Returns the convergence status of the last solve.
 
double get_last_distance_threshold () const
 Returns the distance threshold used in the last solve.
 
double get_C_global () const
 Returns the global C value.
 
void set_distance_threshold (double threshold)
 Sets the distance threshold for the solver.
 
double compute_covering_radius () const
 Computes the covering radius of the target measure with respect to the source domain.
 
double compute_geometric_radius_bound (const Vector< double > &potentials, const double epsilon, const double tolerance) const
 Computes the geometric radius bound for truncating quadrature rules.
 
void set_distance_function (const std::string &distance_name)
 Sets the distance function to be used by the solver.
 
void get_potential_conditioned_density (const DoFHandler< dim, spacedim > &dof_handler, const Mapping< dim, spacedim > &mapping, const Vector< double > &potential, const std::vector< unsigned int > &potential_indices, std::vector< LinearAlgebra::distributed::Vector< double, MemorySpace::Host > > &conditioned_densities)
 Computes the conditional density of the source measure given a potential.
 
double evaluate_functional (const Vector< double > &potential, Vector< double > &gradient_out)
 Evaluates the dual functional and its gradient.
 
void compute_hessian (const Vector< double > &potential, LAPACKFullMatrix< double > &hessian_out)
 Computes the Hessian matrix of the dual functional.
 

Public Attributes

std::string distance_name
 The name of the distance function.
 
std::function< double(const Point< spacedim > &, const Point< spacedim > &)> distance_function
 The distance function.
 
std::function< Vector< double >(const Point< spacedim > &, const Point< spacedim > &)> distance_function_gradient
 The gradient of the distance function.
 
std::function< Point< spacedim >(const Point< spacedim > &, const Vector< double > &)> distance_function_exponential_map
 The exponential map of the distance function.
 
SourceMeasure source_measure
 The source measure.
 
TargetMeasure target_measure
 The target measure.
 

Private Member Functions

void local_assemble (const typename DoFHandler< dim, spacedim >::active_cell_iterator &cell, ScratchData &scratch, CopyData &copy, std::function< void(CopyData &, const Point< spacedim > &, const std::vector< std::size_t > &, const std::vector< double > &, const std::vector< double > &, const double &, const double &, const double &, const double &, const double &)> function_call)
 
void compute_distance_threshold () const
 
std::vector< std::size_t > find_nearest_target_points (const Point< spacedim > &query_point) const
 
double compute_integral_radius_bound (const Vector< double > &potentials, double epsilon, double tolerance, double C_value, double current_functional_val) const
 
bool validate_measures () const
 
void compute_weighted_barycenters_non_euclidean (const Vector< double > &potentials, std::vector< Vector< double > > &barycenters_gradients_out, std::vector< Point< spacedim > > &barycenters_out)
 
void compute_weighted_barycenters_euclidean (const Vector< double > &potentials, std::vector< Point< spacedim > > &barycenters_out)
 

Private Attributes

MPI_Comm mpi_communicator
 
const unsigned int n_mpi_processes
 
const unsigned int this_mpi_process
 
ConditionalOStream pcout
 
std::unique_ptr< SolverControl > solver_control
 
double current_distance_threshold
 
double effective_distance_threshold
 
const Vector< double > * current_potential
 
double current_epsilon
 
double global_functional
 
Vector< double > gradient
 
double covering_radius
 
double min_target_density
 
double C_global = 0.0
 
SotParameterManager::SolverParameters current_params
 
Vector< double > barycenters
 
Vector< double > barycenters_gradients
 
std::vector< Point< spacedim > > barycenters_points
 
std::vector< Vector< double > > barycenters_grads
 

Detailed Description

template<int dim, int spacedim = dim>
class SotSolver< dim, spacedim >

A solver for semi-discrete optimal transport problems.

This class implements a solver for the dual formulation of the regularized semi-discrete optimal transport problem.

Template Parameters
dim: The dimension of the source mesh.
spacedim: The dimension of the space the mesh is embedded in.

Definition at line 55 of file SotSolver.h.

Member Typedef Documentation

◆ IndexedPoint

template<int dim, int spacedim = dim>
using SotSolver< dim, spacedim >::IndexedPoint = std::pair<Point<spacedim>, std::size_t>

Definition at line 59 of file SotSolver.h.

◆ RTreeParams

template<int dim, int spacedim = dim>
using SotSolver< dim, spacedim >::RTreeParams = boost::geometry::index::rstar<8>

Definition at line 60 of file SotSolver.h.

◆ RTree

template<int dim, int spacedim = dim>
using SotSolver< dim, spacedim >::RTree = boost::geometry::index::rtree<IndexedPoint, RTreeParams>

Definition at line 61 of file SotSolver.h.

Constructor & Destructor Documentation

◆ SotSolver()

template<int dim, int spacedim>
SotSolver< dim, spacedim >::SotSolver ( const MPI_Comm &  comm)

Constructor for the SotSolver.

Parameters
comm: The MPI communicator.

Definition at line 4 of file SotSolver.cc.

Member Function Documentation

◆ setup_source()

template<int dim, int spacedim>
void SotSolver< dim, spacedim >::setup_source ( const DoFHandler< dim, spacedim > &  dof_handler,
const Mapping< dim, spacedim > &  mapping,
const FiniteElement< dim, spacedim > &  fe,
const LinearAlgebra::distributed::Vector< double, MemorySpace::Host > &  source_density,
const unsigned int  quadrature_order 
)

Sets up the source measure for the solver.

Parameters
dof_handler: The DoF handler for the source mesh.
mapping: The mapping for the source mesh.
fe: The finite element for the source mesh.
source_density: The density of the source measure.
quadrature_order: The order of the quadrature rule to use for integration.

Definition at line 45 of file SotSolver.cc.

◆ setup_target()

template<int dim, int spacedim>
void SotSolver< dim, spacedim >::setup_target ( const std::vector< Point< spacedim > > &  target_points,
const Vector< double > &  target_density 
)

Sets up the target measure for the solver.

Parameters
target_points: The points of the discrete target measure.
target_density: The weights of the discrete target measure.

Definition at line 56 of file SotSolver.cc.

◆ configure()

template<int dim, int spacedim>
void SotSolver< dim, spacedim >::configure ( const SotParameterManager::SolverParameters &  params)

Configures the solver with the given parameters.

Parameters
params: The solver parameters.

Definition at line 64 of file SotSolver.cc.

◆ solve() [1/2]

template<int dim, int spacedim>
void SotSolver< dim, spacedim >::solve ( Vector< double > &  potential,
const SotParameterManager::SolverParameters &  params 
)

Solves the optimal transport problem.

Parameters
potential: The vector to store the computed optimal transport potential.
params: The solver parameters.

Definition at line 106 of file SotSolver.cc.

◆ solve() [2/2]

template<int dim, int spacedim>
void SotSolver< dim, spacedim >::solve ( Vector< double > &  potential,
const SourceMeasure &  source,
const TargetMeasure &  target,
const SotParameterManager::SolverParameters &  params 
)

Alternative solve interface for use when the measures have not been set up beforehand.

Parameters
potential: The vector to store the computed optimal transport potential.
source: The source measure.
target: The target measure.
params: The solver parameters.

Definition at line 91 of file SotSolver.cc.

◆ evaluate_weighted_barycenters()

template<int dim, int spacedim>
void SotSolver< dim, spacedim >::evaluate_weighted_barycenters ( const Vector< double > &  potentials,
std::vector< Point< spacedim > > &  barycenters_out,
const SotParameterManager::SolverParameters &  params 
)

Evaluates the weighted barycenters of the power cells.

Parameters
potentials: The optimal transport potentials.
barycenters_out: The vector to store the computed barycenters.
params: The solver parameters.

Definition at line 586 of file SotSolver.cc.

◆ get_last_functional_value()

template<int dim, int spacedim = dim>
double SotSolver< dim, spacedim >::get_last_functional_value ( ) const
inline

Returns the value of the functional at the last iteration.

Definition at line 240 of file SotSolver.h.

◆ get_last_iteration_count()

template<int dim, int spacedim>
unsigned int SotSolver< dim, spacedim >::get_last_iteration_count ( ) const

Returns the number of iterations of the last solve.

Definition at line 574 of file SotSolver.cc.

◆ get_convergence_status()

template<int dim, int spacedim>
bool SotSolver< dim, spacedim >::get_convergence_status ( ) const

Returns the convergence status of the last solve.

Definition at line 580 of file SotSolver.cc.

◆ get_last_distance_threshold()

template<int dim, int spacedim = dim>
double SotSolver< dim, spacedim >::get_last_distance_threshold ( ) const
inline

Returns the distance threshold used in the last solve.

Definition at line 252 of file SotSolver.h.

◆ get_C_global()

template<int dim, int spacedim = dim>
double SotSolver< dim, spacedim >::get_C_global ( ) const
inline

Returns the global C value.

Definition at line 256 of file SotSolver.h.

◆ set_distance_threshold()

template<int dim, int spacedim = dim>
void SotSolver< dim, spacedim >::set_distance_threshold ( double  threshold)
inline

Sets the distance threshold for the solver.

Parameters
threshold: The distance threshold.

Definition at line 262 of file SotSolver.h.

◆ compute_covering_radius()

template<int dim, int spacedim>
double SotSolver< dim, spacedim >::compute_covering_radius ( ) const

Computes the covering radius of the target measure with respect to the source domain.

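The covering radius $R_0$ is defined as $$R_0 = \max_{x \in \Omega} \min_{1 \le j \le N} \| x - y_j \|,$$ where $\Omega$ is the source domain and $y_1, \dots, y_N$ are the target points.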

i.e., the largest distance from any point of the source domain to its nearest target point.

Returns
The covering radius value

Definition at line 433 of file SotSolver.cc.

◆ compute_geometric_radius_bound()

template<int dim, int spacedim>
double SotSolver< dim, spacedim >::compute_geometric_radius_bound ( const Vector< double > &  potentials,
const double  epsilon,
const double  tolerance 
) const

Computes the geometric radius bound for truncating quadrature rules.

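The geometric radius bound $R_{\mathrm{geom}}$ is chosen to satisfy $$R_{\mathrm{geom}}^2 \ge R_0^2 + 2\,\Gamma(\psi) + 2\varepsilon \ln\!\left( \frac{\varepsilon}{\nu_{\min}\, \tau\, |J_\varepsilon(\psi)|} \right)$$

where:

  • $R_0$ is the covering radius
  • $\Gamma(\psi) = M - m$ is the range of the potential, i.e. its maximum $M$ minus its minimum $m$
  • $\varepsilon$ is the regularization parameter
  • $\tau$ is the tolerance parameter
  • $\nu_{\min}$ is the minimum target density
  • $J_\varepsilon(\psi)$ is the value of the functional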
Parameters
potentials: Current potential values.
epsilon: Regularization parameter.
tolerance: Desired tolerance for the truncation error.
Returns
The geometric radius bound

Definition at line 511 of file SotSolver.cc.

◆ set_distance_function()

template<int dim, int spacedim>
void SotSolver< dim, spacedim >::set_distance_function ( const std::string &  distance_name)

Sets the distance function to be used by the solver.

Parameters
distance_name: The name of the distance function.

Definition at line 23 of file SotSolver.cc.

◆ get_potential_conditioned_density()

template<int dim, int spacedim>
void SotSolver< dim, spacedim >::get_potential_conditioned_density ( const DoFHandler< dim, spacedim > &  dof_handler,
const Mapping< dim, spacedim > &  mapping,
const Vector< double > &  potential,
const std::vector< unsigned int > &  potential_indices,
std::vector< LinearAlgebra::distributed::Vector< double, MemorySpace::Host > > &  conditioned_densities 
)

Computes the conditional density of the source measure given a potential.

Parameters
dof_handler: The DoF handler for the source mesh.
mapping: The mapping for the source mesh.
potential: The optimal transport potential.
potential_indices: The indices of the potential to use.
conditioned_densities: The vector to store the computed conditional densities.

Definition at line 974 of file SotSolver.cc.

◆ evaluate_functional()

template<int dim, int spacedim>
double SotSolver< dim, spacedim >::evaluate_functional ( const Vector< double > &  potential,
Vector< double > &  gradient_out 
)

Evaluates the dual functional and its gradient.

Parameters
potential: The potential at which to evaluate the functional.
gradient_out: The vector to store the computed gradient.
Returns
The value of the functional.

Definition at line 219 of file SotSolver.cc.

◆ compute_hessian()

template<int dim, int spacedim = dim>
void SotSolver< dim, spacedim >::compute_hessian ( const Vector< double > &  potential,
LAPACKFullMatrix< double > &  hessian_out 
)

Computes the Hessian matrix of the dual functional.

Parameters
potential: The potential at which to compute the Hessian.
hessian_out: The matrix to store the computed Hessian.

◆ local_assemble()

template<int dim, int spacedim>
void SotSolver< dim, spacedim >::local_assemble ( const typename DoFHandler< dim, spacedim >::active_cell_iterator &  cell,
ScratchData &  scratch,
CopyData &  copy,
std::function< void(CopyData &, const Point< spacedim > &, const std::vector< std::size_t > &, const std::vector< double > &, const std::vector< double > &, const double &, const double &, const double &, const double &, const double &)>  function_call 
)
private

Definition at line 770 of file SotSolver.cc.

◆ compute_distance_threshold()

template<int dim, int spacedim>
void SotSolver< dim, spacedim >::compute_distance_threshold ( ) const
private

Definition at line 381 of file SotSolver.cc.

◆ find_nearest_target_points()

template<int dim, int spacedim>
std::vector< std::size_t > SotSolver< dim, spacedim >::find_nearest_target_points ( const Point< spacedim > &  query_point) const
private

Definition at line 555 of file SotSolver.cc.

◆ compute_integral_radius_bound()

template<int dim, int spacedim>
double SotSolver< dim, spacedim >::compute_integral_radius_bound ( const Vector< double > &  potentials,
double  epsilon,
double  tolerance,
double  C_value,
double  current_functional_val 
) const
private

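Computes the integral radius bound $R_{\mathrm{int}}$ from the condition $$R_{\mathrm{int}}^2 \ge 2M + 2\varepsilon \ln\!\left( \frac{\varepsilon\, C}{\tau\, |J_\varepsilon(\psi)|} \right),$$ where $M$ is the maximum potential value, $\varepsilon$ is `epsilon`, $C$ is `C_value`, $\tau$ is `tolerance`, and $J_\varepsilon(\psi)$ is the current functional value (`current_functional_val`).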

Definition at line 477 of file SotSolver.cc.

◆ validate_measures()

template<int dim, int spacedim>
bool SotSolver< dim, spacedim >::validate_measures ( ) const
private

Definition at line 73 of file SotSolver.cc.

◆ compute_weighted_barycenters_non_euclidean()

template<int dim, int spacedim>
void SotSolver< dim, spacedim >::compute_weighted_barycenters_non_euclidean ( const Vector< double > &  potentials,
std::vector< Vector< double > > &  barycenters_gradients_out,
std::vector< Point< spacedim > > &  barycenters_out 
)
private

Definition at line 664 of file SotSolver.cc.

◆ compute_weighted_barycenters_euclidean()

template<int dim, int spacedim>
void SotSolver< dim, spacedim >::compute_weighted_barycenters_euclidean ( const Vector< double > &  potentials,
std::vector< Point< spacedim > > &  barycenters_out 
)
private

Definition at line 872 of file SotSolver.cc.

Member Data Documentation

◆ distance_name

template<int dim, int spacedim = dim>
std::string SotSolver< dim, spacedim >::distance_name

The name of the distance function.

Definition at line 323 of file SotSolver.h.

◆ distance_function

template<int dim, int spacedim = dim>
std::function<double(const Point<spacedim>&, const Point<spacedim>&)> SotSolver< dim, spacedim >::distance_function

The distance function.

Definition at line 324 of file SotSolver.h.

◆ distance_function_gradient

template<int dim, int spacedim = dim>
std::function<Vector<double>(const Point<spacedim>&, const Point<spacedim>&)> SotSolver< dim, spacedim >::distance_function_gradient

The gradient of the distance function.

Definition at line 325 of file SotSolver.h.

◆ distance_function_exponential_map

template<int dim, int spacedim = dim>
std::function<Point<spacedim>(const Point<spacedim>&, const Vector<double>&)> SotSolver< dim, spacedim >::distance_function_exponential_map

The exponential map of the distance function.

Definition at line 326 of file SotSolver.h.

◆ source_measure

template<int dim, int spacedim = dim>
SourceMeasure SotSolver< dim, spacedim >::source_measure

The source measure.

Definition at line 346 of file SotSolver.h.

◆ target_measure

template<int dim, int spacedim = dim>
TargetMeasure SotSolver< dim, spacedim >::target_measure

The target measure.

Definition at line 347 of file SotSolver.h.

◆ mpi_communicator

template<int dim, int spacedim = dim>
MPI_Comm SotSolver< dim, spacedim >::mpi_communicator
private

Definition at line 495 of file SotSolver.h.

◆ n_mpi_processes

template<int dim, int spacedim = dim>
const unsigned int SotSolver< dim, spacedim >::n_mpi_processes
private

Definition at line 496 of file SotSolver.h.

◆ this_mpi_process

template<int dim, int spacedim = dim>
const unsigned int SotSolver< dim, spacedim >::this_mpi_process
private

Definition at line 497 of file SotSolver.h.

◆ pcout

template<int dim, int spacedim = dim>
ConditionalOStream SotSolver< dim, spacedim >::pcout
private

Definition at line 498 of file SotSolver.h.

◆ solver_control

template<int dim, int spacedim = dim>
std::unique_ptr<SolverControl> SotSolver< dim, spacedim >::solver_control
private

Definition at line 501 of file SotSolver.h.

◆ current_distance_threshold

template<int dim, int spacedim = dim>
double SotSolver< dim, spacedim >::current_distance_threshold
mutable private

Definition at line 502 of file SotSolver.h.

◆ effective_distance_threshold

template<int dim, int spacedim = dim>
double SotSolver< dim, spacedim >::effective_distance_threshold
mutable private

Definition at line 503 of file SotSolver.h.

◆ current_potential

template<int dim, int spacedim = dim>
const Vector<double>* SotSolver< dim, spacedim >::current_potential
private

Definition at line 504 of file SotSolver.h.

◆ current_epsilon

template<int dim, int spacedim = dim>
double SotSolver< dim, spacedim >::current_epsilon
private

Definition at line 505 of file SotSolver.h.

◆ global_functional

template<int dim, int spacedim = dim>
double SotSolver< dim, spacedim >::global_functional
mutable private

Definition at line 506 of file SotSolver.h.

◆ gradient

template<int dim, int spacedim = dim>
Vector<double> SotSolver< dim, spacedim >::gradient
private

Definition at line 507 of file SotSolver.h.

◆ covering_radius

template<int dim, int spacedim = dim>
double SotSolver< dim, spacedim >::covering_radius
private

Definition at line 508 of file SotSolver.h.

◆ min_target_density

template<int dim, int spacedim = dim>
double SotSolver< dim, spacedim >::min_target_density
private

Definition at line 509 of file SotSolver.h.

◆ C_global

template<int dim, int spacedim = dim>
double SotSolver< dim, spacedim >::C_global = 0.0
private

Definition at line 510 of file SotSolver.h.

◆ current_params

template<int dim, int spacedim = dim>
SotParameterManager::SolverParameters SotSolver< dim, spacedim >::current_params
private

Definition at line 513 of file SotSolver.h.

◆ barycenters

template<int dim, int spacedim = dim>
Vector<double> SotSolver< dim, spacedim >::barycenters
private

Definition at line 518 of file SotSolver.h.

◆ barycenters_gradients

template<int dim, int spacedim = dim>
Vector<double> SotSolver< dim, spacedim >::barycenters_gradients
private

Definition at line 519 of file SotSolver.h.

◆ barycenters_points

template<int dim, int spacedim = dim>
std::vector<Point<spacedim> > SotSolver< dim, spacedim >::barycenters_points
private

Definition at line 522 of file SotSolver.h.

◆ barycenters_grads

template<int dim, int spacedim = dim>
std::vector<Vector<double> > SotSolver< dim, spacedim >::barycenters_grads
private

Definition at line 523 of file SotSolver.h.


The documentation for this class was generated from the following files: