SemiDiscreteOT 1.0
Semi-Discrete Optimal Transport Library
Loading...
Searching...
No Matches
SoftmaxRefinement.h
Go to the documentation of this file.
1#ifndef SOFTMAX_REFINEMENT_H
2#define SOFTMAX_REFINEMENT_H
3
4#include <boost/geometry.hpp>
5#include <boost/geometry/index/rtree.hpp>
6#include <memory>
7#include <vector>
8#include <type_traits>
9
10#include <deal.II/base/conditional_ostream.h>
11#include <deal.II/lac/vector.h>
12#include <deal.II/lac/la_parallel_vector.h>
13#include <deal.II/base/point.h>
14#include <deal.II/fe/fe_values.h>
15#include <deal.II/fe/fe_simplex_p.h>
16#include <deal.II/fe/fe_q.h>
17#include <deal.II/dofs/dof_handler.h>
18#include <deal.II/base/work_stream.h>
19#include <deal.II/base/exceptions.h>
20#include <deal.II/dofs/dof_tools.h>
21#include <deal.II/grid/filtered_iterator.h>
22#include <deal.II/numerics/rtree.h>
23#include <deal.II/base/mpi.h>
24#include <deal.II/base/utilities.h>
25#include <deal.II/base/multithread_info.h>
26
28
29using namespace dealii;
30
42template <int dim, int spacedim = dim>
44public:
50 const std::function<double(const Point<spacedim>&, const Point<spacedim>&)>& dist)
51 {
52 distance_function = dist;
53 }
54
58 struct ScratchData {
59 ScratchData(const FiniteElement<dim, spacedim> &fe,
60 const Mapping<dim, spacedim> &mapping,
61 const Quadrature<dim> &quadrature)
62 : fe_values(mapping, fe, quadrature,
63 update_values | update_quadrature_points | update_JxW_values),
64 density_values(quadrature.size()) {}
65
66 ScratchData(const ScratchData &scratch_data)
67 : fe_values(scratch_data.fe_values.get_mapping(),
68 scratch_data.fe_values.get_fe(),
69 scratch_data.fe_values.get_quadrature(),
70 update_values | update_quadrature_points | update_JxW_values),
71 density_values(scratch_data.density_values) {}
72
73 FEValues<dim, spacedim> fe_values;
74 std::vector<double> density_values;
75 };
76
80 struct CopyData {
81 Vector<double> potential_values;
82
83 CopyData(const unsigned int n_target_points)
84 : potential_values(n_target_points) {}
85 };
86
/**
 * Constructor.
 *
 * @param mpi_comm               MPI communicator. NOTE(review): presumably
 *                               used to initialize n_mpi_processes,
 *                               this_mpi_process and pcout; the definition
 *                               is not visible here, so confirm.
 * @param dof_handler            DoF handler of the source mesh.
 * @param mapping                Mapping of the source mesh.
 * @param fe                     Finite element of the source mesh.
 * @param source_density         Density of the source measure.
 * @param quadrature_order       Order of the quadrature rule used for
 *                               integration.
 * @param distance_threshold     Distance threshold. NOTE(review): presumably
 *                               a cutoff for which target points are
 *                               considered; confirm.
 * @param use_log_sum_exp_trick  Whether to use the log-sum-exp trick
 *                               (defaults to true).
 *
 * NOTE(review): dof_handler, mapping, fe and source_density are held as
 * reference members of this class. If the constructor binds them to these
 * arguments (confirm against the definition), the caller must keep those
 * objects alive for the lifetime of this object.
 */
98 SoftmaxRefinement(MPI_Comm mpi_comm,
99 const DoFHandler<dim, spacedim>& dof_handler,
100 const Mapping<dim, spacedim>& mapping,
101 const FiniteElement<dim, spacedim>& fe,
102 const LinearAlgebra::distributed::Vector<double>& source_density,
103 unsigned int quadrature_order,
104 double distance_threshold,
105 bool use_log_sum_exp_trick = true);
106
/**
 * Computes the refined potential.
 *
 * @param target_points_fine     Target points of the fine level.
 * @param target_density_fine    Target density of the fine level.
 * @param target_points_coarse   Target points of the coarse level.
 * @param target_density_coarse  Target density of the coarse level.
 * @param potential_coarse       Potential on the coarse level.
 * @param regularization_param   Regularization parameter.
 * @param current_level          The current level.
 * @param child_indices          Child indices. NOTE(review): presumably, per
 *                               level, the fine-point indices that are
 *                               children of each coarse point; confirm.
 * @return The refined potential. NOTE(review): presumably one value per
 *         fine target point; confirm.
 *
 * NOTE(review): the class has current_* pointer members of the same types as
 * these arguments, so the arguments are presumably referenced by address
 * during the call. If so, concurrent calls on one object are unsafe; confirm.
 */
119 Vector<double> compute_refinement(
120 const std::vector<Point<spacedim>>& target_points_fine,
121 const Vector<double>& target_density_fine,
122 const std::vector<Point<spacedim>>& target_points_coarse,
123 const Vector<double>& target_density_coarse,
124 const Vector<double>& potential_coarse,
125 double regularization_param,
126 int current_level,
127 const std::vector<std::vector<std::vector<size_t>>>& child_indices);
128
129private:
130 // MPI members
132 const unsigned int n_mpi_processes;
133 const unsigned int this_mpi_process;
134 ConditionalOStream pcout;
135
136 // Distance function
137 std::function<double(const Point<spacedim>&, const Point<spacedim>&)> distance_function;
138
139 // Source mesh and FE data
140 const DoFHandler<dim, spacedim>& dof_handler;
141 const Mapping<dim, spacedim>& mapping;
142 const FiniteElement<dim, spacedim>& fe;
143 const LinearAlgebra::distributed::Vector<double>& source_density;
144 const unsigned int quadrature_order;
145
146 // Spatial search structure
147 using IndexedPoint = std::pair<Point<spacedim>, std::size_t>;
148 using RTreeParams = boost::geometry::index::rstar<8>;
149 using RTree = boost::geometry::index::rtree<IndexedPoint, RTreeParams>;
151
152 // Current computation state
153 double current_lambda{0.0};
156 const std::vector<Point<spacedim>>* current_target_points_fine{nullptr};
157 const Vector<double>* current_target_density_fine{nullptr};
158 const std::vector<Point<spacedim>>* current_target_points_coarse{nullptr};
159 const Vector<double>* current_target_density_coarse{nullptr};
160 const Vector<double>* current_potential_coarse{nullptr};
161 const std::vector<std::vector<std::vector<size_t>>>* current_child_indices{nullptr};
163
164 // Helper methods
168 void setup_rtree();
174 std::vector<std::size_t> find_nearest_target_points(const Point<spacedim>& query_point) const;
175
182 void local_assemble(const typename DoFHandler<dim, spacedim>::active_cell_iterator &cell,
183 ScratchData &scratch_data,
184 CopyData &copy_data);
185
190 {
191 try {
192 const auto* simplex_fe = dynamic_cast<const FE_SimplexP<dim, spacedim>*>(&fe);
193 return (simplex_fe != nullptr);
194 }
195 catch (...) {
196 return false;
197 }
198 }
199};
200
201#endif // SOFTMAX_REFINEMENT_H
A class for refining the optimal transport potential using a softmax operation.
boost::geometry::index::rstar< 8 > RTreeParams
const unsigned int quadrature_order
The order of the quadrature rule to use for integration.
std::vector< std::size_t > find_nearest_target_points(const Point< spacedim > &query_point) const
Finds the nearest target points to a query point.
void setup_rtree()
Sets up the R-tree used for fast spatial queries on the target points.
void set_distance_function(const std::function< double(const Point< spacedim > &, const Point< spacedim > &)> &dist)
Sets the distance function to be used.
const Vector< double > * current_target_density_coarse
The current coarse target density.
Vector< double > compute_refinement(const std::vector< Point< spacedim > > &target_points_fine, const Vector< double > &target_density_fine, const std::vector< Point< spacedim > > &target_points_coarse, const Vector< double > &target_density_coarse, const Vector< double > &potential_coarse, double regularization_param, int current_level, const std::vector< std::vector< std::vector< size_t > > > &child_indices)
Computes the refined potential.
const unsigned int n_mpi_processes
The number of MPI processes.
MPI_Comm mpi_communicator
The MPI communicator.
const DoFHandler< dim, spacedim > & dof_handler
The DoF handler for the source mesh.
const std::vector< std::vector< std::vector< size_t > > > * current_child_indices
The current child indices.
const Mapping< dim, spacedim > & mapping
The mapping for the source mesh.
ConditionalOStream pcout
A conditional output stream for parallel printing.
const FiniteElement< dim, spacedim > & fe
The finite element for the source mesh.
const std::vector< Point< spacedim > > * current_target_points_fine
The current fine target points.
void local_assemble(const typename DoFHandler< dim, spacedim >::active_cell_iterator &cell, ScratchData &scratch_data, CopyData &copy_data)
Assembles the local contributions to the refined potential.
RTree target_points_rtree
An R-tree for fast spatial queries on the target points.
std::function< double(const Point< spacedim > &, const Point< spacedim > &)> distance_function
The distance function.
const LinearAlgebra::distributed::Vector< double > & source_density
The density of the source measure.
boost::geometry::index::rtree< IndexedPoint, RTreeParams > RTree
const unsigned int this_mpi_process
The rank of the current MPI process.
const std::vector< Point< spacedim > > * current_target_points_coarse
The current coarse target points.
std::pair< Point< spacedim >, std::size_t > IndexedPoint
const double current_distance_threshold
The current distance threshold.
const Vector< double > * current_target_density_fine
The current fine target density.
const bool use_log_sum_exp_trick
Whether to use the log-sum-exp trick to evaluate the softmax in a numerically stable way.
double current_lambda
The current regularization parameter.
const Vector< double > * current_potential_coarse
The current coarse potential.
bool is_simplex_element() const
Checks if the finite element is a simplex element.
int current_level
The current level of the coarse-to-fine hierarchy being refined.
A struct to hold copy data for parallel assembly.
CopyData(const unsigned int n_target_points)
Vector< double > potential_values
The potential values at the target points.
A struct to hold scratch data for parallel assembly.
ScratchData(const ScratchData &scratch_data)
FEValues< dim, spacedim > fe_values
FEValues object for the current cell.
ScratchData(const FiniteElement< dim, spacedim > &fe, const Mapping< dim, spacedim > &mapping, const Quadrature< dim > &quadrature)
std::vector< double > density_values
The density values at the quadrature points of the current cell.