SemiDiscreteOT 1.0
Semi-Discrete Optimal Transport Library
Loading...
Searching...
No Matches
SemiDiscreteOT.h
Go to the documentation of this file.
1#ifndef RSOT_H
2#define RSOT_H
3
4#include <filesystem>
5#include <memory>
6#include <mutex>
7#include <atomic>
8#include <boost/geometry/strategies/disjoint.hpp>
9#include <boost/geometry/index/rtree.hpp>
10
11#include <deal.II/base/mpi.h>
12#include <deal.II/base/utilities.h>
13#include <deal.II/base/conditional_ostream.h>
14#include <deal.II/base/parameter_acceptor.h>
15#include <deal.II/base/timer.h>
16#include <deal.II/base/vectorization.h>
17#include <deal.II/base/function.h>
18#include <deal.II/grid/tria.h>
19#include <deal.II/distributed/tria.h>
20#include <deal.II/distributed/fully_distributed_tria.h>
21#include <deal.II/distributed/tria_base.h>
22#include <deal.II/grid/grid_generator.h>
23#include <deal.II/grid/grid_out.h>
24#include <deal.II/grid/grid_in.h>
25#include <deal.II/grid/grid_tools.h>
26#include <deal.II/dofs/dof_handler.h>
27#include <deal.II/dofs/dof_tools.h>
28#include <deal.II/fe/fe_q.h>
29#include <deal.II/fe/fe_simplex_p.h>
30#include <deal.II/fe/fe_system.h>
31#include <deal.II/fe/fe_values.h>
32#include <deal.II/fe/mapping_q1.h>
33#include <deal.II/fe/mapping_fe.h>
34#include <deal.II/base/quadrature_lib.h>
35#include <deal.II/base/quadrature.h>
36#include <deal.II/lac/vector.h>
37#include <deal.II/lac/la_parallel_vector.h>
38#include <deal.II/lac/vector_operations_internal.h>
39#include <deal.II/numerics/vector_tools.h>
40#include <deal.II/numerics/rtree.h>
41#include <deal.II/numerics/fe_field_function.h>
42#include <deal.II/optimization/solver_bfgs.h>
43#include <deal.II/base/work_stream.h>
44#include <deal.II/base/multithread_info.h>
45#include <deal.II/base/data_out_base.h>
46#include <deal.II/numerics/data_out.h>
47#include <deal.II/base/logstream.h>
48
62
63
64using namespace dealii;
65
76template <int dim, int spacedim = dim>
78public:
83 SemiDiscreteOT(const MPI_Comm &mpi_communicator);
87 void run();
88
94 void configure(std::function<void(SotParameterManager&)> config_func);
95
96
107 Triangulation<dim, spacedim>& tria,
108 const DoFHandler<dim, spacedim>& dh,
109 const Vector<double>& density,
110 const std::string& name = "source");
111
118 const std::vector<Point<spacedim>>& points,
119 const Vector<double>& weights);
120
127
133
139
145
151
156 const Vector<double> &get_coarsest_potential() const { return coarsest_potential; }
157
158
159
166 Vector<double> solve(const Vector<double>& initial_potential = Vector<double>());
167
172
178 const std::function<double(const Point<spacedim>&, const Point<spacedim>&)>& dist)
179 {
180 sot_solver->distance_function = dist;
181 }
182
183 ConditionalOStream pcout;
184protected:
185 // MPI-related members
187 const unsigned int n_mpi_processes;
188 const unsigned int this_mpi_process;
189
190 // Solver member
191 std::unique_ptr<SotSolver<dim, spacedim>> sot_solver;
192
193 // Parameter manager and references
201 std::string& selected_task;
202 std::string& io_coding;
203
204 std::unique_ptr<DoFHandler<dim, spacedim>> initial_fine_dof_handler;
205 std::unique_ptr<Vector<double>> initial_fine_density;
207
208 // Source mesh name for saving and hierarchy generation
209 std::string source_mesh_name = "source";
210
211 // Mesh and DoF handler members
212 parallel::fullydistributed::Triangulation<dim, spacedim> source_mesh;
213 Triangulation<dim, spacedim> target_mesh;
214 DoFHandler<dim, spacedim> dof_handler_source;
215 DoFHandler<dim, spacedim> dof_handler_target;
216
217 std::unique_ptr<VTKHandler<dim,spacedim>> source_vtk_handler;
218 DoFHandler<dim,spacedim> vtk_dof_handler_source;
219 Vector<double> vtk_field_source;
220 Triangulation<dim,spacedim> vtk_tria_source;
221 // Finite element and mapping members
222 std::unique_ptr<FiniteElement<dim, spacedim>> fe_system;
223 std::unique_ptr<Mapping<dim, spacedim>> mapping;
224 std::unique_ptr<FiniteElement<dim, spacedim>> fe_system_target;
225 std::unique_ptr<Mapping<dim, spacedim>> mapping_target;
226 LinearAlgebra::distributed::Vector<double> source_density;
227 Vector<double> target_density;
228 std::vector<Point<spacedim>> target_points;
229 std::vector<Point<spacedim>> source_points;
230
231 // Mesh manager
232 std::unique_ptr<MeshManager<dim, spacedim>> mesh_manager;
233
234 // Epsilon scaling handler
235 std::unique_ptr<EpsilonScalingHandler> epsilon_scaling_handler;
236
243 void save_results(const Vector<double>& potentials, const std::string& filename, bool add_epsilon_prefix = true);
244
249 void normalize_density(LinearAlgebra::distributed::Vector<double>& density);
250private:
251
252 // Core functionality methods
256 void mesh_generation();
260 void load_meshes();
261
267 Vector<double> run_sot(const Vector<double>& initial_potential = Vector<double>());
268
281
282
283
289 Vector<double> run_multilevel(const Vector<double>& initial_potential = Vector<double>());
290
296 Vector<double> run_combined_multilevel(const Vector<double>& initial_potential = Vector<double>());
297
303 Vector<double> run_source_multilevel(const Vector<double>& initial_potential = Vector<double>());
304
310 Vector<double> run_target_multilevel(const Vector<double>& initial_potential = Vector<double>());
311
312 // Setup methods
317 void setup_source_finite_elements(bool is_multilevel = false);
329 void setup_target_points();
334
335 // Exact SOT method (3D only)
339 template <int d = dim, int s = spacedim>
340 typename std::enable_if<d == 3 && s == 3>::type run_exact_sot();
341
342 // Hierarchy-related members
343 std::vector<std::vector<std::vector<size_t>>> child_indices_;
350 void load_hierarchy_data(const std::string& hierarchy_dir, int specific_level = -1);
351
352 // Multilevel computation state
353 std::vector<Point<spacedim>> target_points_coarse;
354 Vector<double> target_density_coarse;
355 mutable double current_distance_threshold{0.0};
356 Vector<double> coarsest_potential;
357
358 // Potential transfer between hierarchy levels
366 void assign_potentials_by_hierarchy(Vector<double>& potentials,
367 int coarse_level,
368 int fine_level,
369 const Vector<double>& prev_potentials);
370
371 // Helper methods
376 std::vector<std::pair<std::string, std::string>> get_target_hierarchy_files() const;
381 std::vector<std::string> get_mesh_hierarchy_files() const;
387 void load_target_points_at_level(const std::string& points_file,
388 const std::string& density_file);
389
396};
397
398#endif
399
Main class for the semi-discrete optimal transport solver.
void save_discrete_measures()
Saves the discrete source and target measures to files.
std::string & io_coding
A reference to the I/O coding.
LinearAlgebra::distributed::Vector< double > source_density
The source density.
void mesh_generation()
Generates the source and target meshes.
bool has_hierarchy_data_
A flag indicating whether the hierarchy data has been loaded.
void prepare_target_multilevel()
Pre-computes the multilevel hierarchy for the target.
std::vector< std::pair< std::string, std::string > > get_target_hierarchy_files() const
Gets the target hierarchy files.
parallel::fullydistributed::Triangulation< dim, spacedim > source_mesh
The source mesh.
std::enable_if< d==3 &&s==3 >::type run_exact_sot()
Runs the exact semi-discrete optimal transport solver.
void setup_multilevel_finite_elements()
Sets up the finite elements for a multilevel computation.
SotSolver< dim, spacedim > * get_solver()
Get a pointer to the solver object.
void save_interpolated_fields()
const unsigned int this_mpi_process
The rank of the current MPI process.
double current_distance_threshold
The current distance threshold for computations.
DoFHandler< dim, spacedim > vtk_dof_handler_source
The DoF handler for the source VTK mesh.
void load_meshes()
Loads the source and target meshes from files.
std::unique_ptr< Mapping< dim, spacedim > > mapping_target
The target mapping.
DoFHandler< dim, spacedim > dof_handler_target
The DoF handler for the target mesh.
std::vector< std::string > get_mesh_hierarchy_files() const
Gets the mesh hierarchy files.
bool is_setup_programmatically_
A flag indicating whether the setup is done programmatically.
void normalize_density(LinearAlgebra::distributed::Vector< double > &density)
Normalizes the density vector.
void configure(std::function< void(SotParameterManager &)> config_func)
Configure the solver parameters programmatically.
SotParameterManager::MeshParameters & target_params
A reference to the target mesh parameters.
void setup_source_measure(Triangulation< dim, spacedim > &tria, const DoFHandler< dim, spacedim > &dh, const Vector< double > &density, const std::string &name="source")
Set up the source measure from standard deal.II objects (simplified API for tutorials).
Vector< double > target_density
The target density.
std::vector< Point< spacedim > > target_points
The target points.
void assign_potentials_by_hierarchy(Vector< double > &potentials, int coarse_level, int fine_level, const Vector< double > &prev_potentials)
Assigns potentials by hierarchy.
Triangulation< dim, spacedim > vtk_tria_source
The triangulation from the source VTK file.
Vector< double > run_source_multilevel(const Vector< double > &initial_potential=Vector< double >())
Run source-only multilevel SOT computation.
std::unique_ptr< Vector< double > > initial_fine_density
The initial fine density.
Vector< double > run_multilevel(const Vector< double > &initial_potential=Vector< double >())
Run multilevel SOT computation (dispatcher method).
void setup_target_finite_elements()
Sets up the finite elements for the target mesh.
void run()
Runs the solver with the current configuration.
Triangulation< dim, spacedim > target_mesh
The target mesh.
ConditionalOStream pcout
A conditional output stream for parallel printing.
void compute_power_diagram()
Computes the power diagram of the target points.
std::unique_ptr< EpsilonScalingHandler > epsilon_scaling_handler
The epsilon scaling handler.
void setup_finite_elements()
Sets up the finite elements for both the source and target meshes.
void compute_conditional_density()
Computes the conditional density of the source measure.
Vector< double > vtk_field_source
The source field from the VTK file.
void set_distance_function(const std::function< double(const Point< spacedim > &, const Point< spacedim > &)> &dist)
Sets the distance function to be used by the solver.
void compute_transport_map()
Computes the transport map from the source to the target measure.
std::unique_ptr< VTKHandler< dim, spacedim > > source_vtk_handler
The VTK handler for the source mesh.
std::string source_mesh_name
The name of the source mesh.
MPI_Comm mpi_communicator
The MPI communicator.
SotParameterManager::MeshParameters & source_params
A reference to the source mesh parameters.
void load_hierarchy_data(const std::string &hierarchy_dir, int specific_level=-1)
Loads the hierarchy data from a directory.
const unsigned int n_mpi_processes
The number of MPI processes.
const Vector< double > & get_coarsest_potential() const
Get the coarsest potential from the multilevel solve.
Vector< double > coarsest_potential
The coarsest-level potential for the multilevel solve.
Vector< double > run_target_multilevel(const Vector< double > &initial_potential=Vector< double >())
Run target-only multilevel SOT computation.
SotParameterManager::MultilevelParameters & multilevel_params
A reference to the multilevel parameters.
std::unique_ptr< Mapping< dim, spacedim > > mapping
The mapping.
const SotParameterManager::SolverParameters & get_solver_params() const
Get a reference to the solver parameters.
void load_target_points_at_level(const std::string &points_file, const std::string &density_file)
Loads the target points at a specific level.
void setup_target_points()
Sets up the target points.
DoFHandler< dim, spacedim > dof_handler_source
The DoF handler for the source mesh.
SotParameterManager::SolverParameters & solver_params
A reference to the solver parameters.
std::string & selected_task
A reference to the selected task.
std::vector< Point< spacedim > > target_points_coarse
The coarse level target points.
void save_results(const Vector< double > &potentials, const std::string &filename, bool add_epsilon_prefix=true)
Saves the results of the computation.
std::unique_ptr< FiniteElement< dim, spacedim > > fe_system
The finite element system.
Vector< double > target_density_coarse
The coarse level target densities.
Vector< double > solve(const Vector< double > &initial_potential=Vector< double >())
Run the optimal transport computation based on the current configuration. This method handles single-...
Vector< double > run_sot(const Vector< double > &initial_potential=Vector< double >())
Run single-level SOT computation.
void setup_target_measure(const std::vector< Point< spacedim > > &points, const Vector< double > &weights)
Set up the target measure from a discrete set of points and weights.
SotParameterManager param_manager
The parameter manager.
SotParameterManager::TransportMapParameters & transport_map_params
A reference to the transport map parameters.
void setup_source_finite_elements(bool is_multilevel=false)
Sets up the finite elements for the source mesh.
std::unique_ptr< FiniteElement< dim, spacedim > > fe_system_target
The target finite element system.
std::vector< Point< spacedim > > source_points
The source points.
std::unique_ptr< DoFHandler< dim, spacedim > > initial_fine_dof_handler
The initial fine DoF handler.
std::unique_ptr< MeshManager< dim, spacedim > > mesh_manager
The mesh manager.
std::vector< std::vector< std::vector< size_t > > > child_indices_
The child indices for the multilevel hierarchy.
std::unique_ptr< SotSolver< dim, spacedim > > sot_solver
The semi-discrete optimal transport solver.
void prepare_multilevel_hierarchies()
Pre-computes the multilevel hierarchies for source and/or target. This must be called after setting u...
Vector< double > run_combined_multilevel(const Vector< double > &initial_potential=Vector< double >())
Run combined source and target multilevel SOT computation.
void prepare_source_multilevel()
Pre-computes the multilevel hierarchy for the source.
SotParameterManager::PowerDiagramParameters & power_diagram_params
A reference to the power diagram parameters.
A solver for semi-discrete optimal transport problems.
Definition SotSolver.h:55