SemiDiscreteOT 1.0
Semi-Discrete Optimal Transport Library
Loading...
Searching...
No Matches
SoftmaxRefinement.cc
Go to the documentation of this file.
2#include <deal.II/base/quadrature_lib.h>
3
4template <int dim, int spacedim>
6 MPI_Comm mpi_comm,
7 const DoFHandler<dim, spacedim>& dof_handler,
8 const Mapping<dim, spacedim>& mapping,
9 const FiniteElement<dim, spacedim>& fe,
10 const LinearAlgebra::distributed::Vector<double>& source_density,
11 unsigned int quadrature_order,
12 double distance_threshold,
13 bool use_log_sum_exp_trick)
14 : mpi_communicator(mpi_comm)
15 , n_mpi_processes(Utilities::MPI::n_mpi_processes(mpi_comm))
16 , this_mpi_process(Utilities::MPI::this_mpi_process(mpi_comm))
17 , pcout(std::cout, this_mpi_process == 0)
18 , dof_handler(dof_handler)
19 , mapping(mapping)
20 , fe(fe)
21 , source_density(source_density)
22 , quadrature_order(quadrature_order)
23 , current_distance_threshold(distance_threshold)
24 , use_log_sum_exp_trick(use_log_sum_exp_trick)
25{
26 distance_function = [](const Point<spacedim> x, const Point<spacedim> y) { return euclidean_distance<spacedim>(x, y); };
27}
28
29template <int dim, int spacedim>
31{
32 namespace bgi = boost::geometry::index;
33 std::vector<IndexedPoint> indexed_points;
34 indexed_points.reserve(current_target_points_coarse->size());
35
36 for (std::size_t i = 0; i < current_target_points_coarse->size(); ++i) {
37 indexed_points.emplace_back((*current_target_points_coarse)[i], i);
38 }
39
40 target_points_rtree = RTree(indexed_points.begin(), indexed_points.end());
41}
42
43template <int dim, int spacedim>
45 const Point<spacedim>& query_point) const
46{
47 namespace bgi = boost::geometry::index;
48 std::vector<std::size_t> indices;
49
50 for (const auto& indexed_point : target_points_rtree |
51 bgi::adaptors::queried(bgi::satisfies([&](const IndexedPoint& p) {
52 return distance_function(p.first, query_point) <= current_distance_threshold;
53 })))
54 {
55 indices.push_back(indexed_point.second);
56 }
57
58 return indices;
59}
60
61template <int dim, int spacedim>
63 const typename DoFHandler<dim, spacedim>::active_cell_iterator &cell,
64 ScratchData &scratch_data,
65 CopyData &copy_data)
66{
67 if (!cell->is_locally_owned())
68 return;
69
70 scratch_data.fe_values.reinit(cell);
71 const std::vector<Point<spacedim>> &q_points = scratch_data.fe_values.get_quadrature_points();
72 scratch_data.fe_values.get_function_values(source_density, scratch_data.density_values);
73
74 copy_data.potential_values = 0;
75
76 const unsigned int n_q_points = q_points.size();
77 const double lambda_inv = 1.0 / current_lambda;
78 const double threshold_sq = current_distance_threshold * current_distance_threshold;
79
80 // Get relevant coarse target points for this cell
81 std::vector<std::size_t> cell_target_indices_coarse = find_nearest_target_points(cell->center());
82
83 if (cell_target_indices_coarse.empty()) return;
84
85 const unsigned int n_target_points_coarse = cell_target_indices_coarse.size();
86 std::vector<Point<spacedim>> target_positions_coarse(n_target_points_coarse);
87 std::vector<double> target_densities_coarse(n_target_points_coarse);
88 std::vector<double> potential_values_coarse(n_target_points_coarse);
89
90 // Load coarse target point data
91 for (size_t i = 0; i < n_target_points_coarse; ++i) {
92 const size_t idx = cell_target_indices_coarse[i];
93 target_positions_coarse[i] = (*current_target_points_coarse)[idx];
94 target_densities_coarse[i] = (*current_target_density_coarse)[idx];
95 potential_values_coarse[i] = (*current_potential_coarse)[idx];
96 }
97
98 // Get fine points that are children of the coarse points
99 std::vector<std::size_t> cell_target_indices_fine;
100 std::vector<Point<spacedim>> target_positions_fine;
101
102 // Add bounds checking for child_indices_ access
103 if (current_level < 0 || current_level >= static_cast<int>(current_child_indices->size())) {
104 std::cerr << "Error: Invalid level " << current_level << " for child_indices_ of size "
105 << current_child_indices->size() << std::endl;
106 return;
107 }
108
109 for (size_t i = 0; i < n_target_points_coarse; ++i) {
110 const size_t coarse_idx = cell_target_indices_coarse[i];
111 if (coarse_idx >= (*current_child_indices)[current_level].size()) {
112 std::cerr << "Error: Invalid coarse index " << coarse_idx << " for child_indices_["
113 << current_level << "] of size " << (*current_child_indices)[current_level].size()
114 << std::endl;
115 continue;
116 }
117 const auto& children = (*current_child_indices)[current_level][coarse_idx];
118
119 for (const auto& child_idx : children) {
120 if (child_idx >= current_target_points_fine->size()) {
121 std::cerr << "Error: Invalid child index " << child_idx << " for target_points_fine of size "
122 << current_target_points_fine->size() << std::endl;
123 continue;
124 }
125 cell_target_indices_fine.push_back(child_idx);
126 target_positions_fine.push_back((*current_target_points_fine)[child_idx]);
127 }
128 }
129
130 const unsigned int n_target_points_fine = cell_target_indices_fine.size();
131 if (n_target_points_fine == 0) {
132 std::cerr << "Warning: No valid fine points found for coarse points at level " << current_level << std::endl;
133 return;
134 }
135
136 // For each quadrature point
137 for (unsigned int q = 0; q < n_q_points; ++q) {
138 const Point<spacedim> &x = q_points[q];
139 const double density_value = scratch_data.density_values[q];
140 const double JxW = scratch_data.fe_values.JxW(q);
141
142 // First compute normalization using coarse points
143 double total_sum_exp = 0.0;
144 double max_exponent = -std::numeric_limits<double>::max();
145 std::vector<double> exp_terms_coarse(n_target_points_coarse);
146
147 if (use_log_sum_exp_trick) {
148 // First pass: find maximum exponent
149 #pragma omp simd reduction(max:max_exponent)
150 for (size_t i = 0; i < n_target_points_coarse; ++i) {
151 const double local_dist2 = std::pow(distance_function(x, target_positions_coarse[i]), 2);
152 if (local_dist2 <= threshold_sq) {
153 const double exponent = (potential_values_coarse[i] - 0.5 * local_dist2) * lambda_inv;
154 max_exponent = std::max(max_exponent, exponent);
155 }
156 }
157
158 // Second pass: compute shifted exponentials
159 #pragma omp simd reduction(+:total_sum_exp)
160 for (size_t i = 0; i < n_target_points_coarse; ++i) {
161 const double local_dist2 = std::pow(distance_function(x, target_positions_coarse[i]), 2);
162 if (local_dist2 <= threshold_sq) {
163 const double shifted_exp = std::exp((potential_values_coarse[i] - 0.5 * local_dist2) * lambda_inv - max_exponent);
164 exp_terms_coarse[i] = target_densities_coarse[i] * shifted_exp;
165 total_sum_exp += exp_terms_coarse[i];
166 }
167 }
168 } else {
169 // Original computation method
170 #pragma omp simd reduction(+:total_sum_exp)
171 for (size_t i = 0; i < n_target_points_coarse; ++i) {
172 const double local_dist2 = std::pow(distance_function(x, target_positions_coarse[i]), 2);
173 if (local_dist2 <= threshold_sq) {
174 exp_terms_coarse[i] = target_densities_coarse[i] *
175 std::exp((potential_values_coarse[i] - 0.5 * local_dist2) * lambda_inv);
176 total_sum_exp += exp_terms_coarse[i];
177 }
178 }
179 }
180
181 if (total_sum_exp <= 0.0) continue;
182
183 // Now update potential for fine points using their parent's exp term for normalization
184 double scale = density_value * JxW / total_sum_exp;
185 if (use_log_sum_exp_trick) {
186 scale *= std::exp(-max_exponent);
187 }
188
189 #pragma omp simd
190 for (size_t i = 0; i < n_target_points_fine; ++i) {
191 const double local_dist2_fine = std::pow(distance_function(x, target_positions_fine[i]), 2);
192 if (local_dist2_fine <= threshold_sq) {
193 const double exp_term_fine = std::exp((- 0.5 * local_dist2_fine) * lambda_inv);
194 copy_data.potential_values[cell_target_indices_fine[i]] += scale * exp_term_fine;
195 }
196 }
197 }
198}
199
200template <int dim, int spacedim>
202 const std::vector<Point<spacedim>>& target_points_fine,
203 const Vector<double>& target_density_fine,
204 const std::vector<Point<spacedim>>& target_points_coarse,
205 const Vector<double>& target_density_coarse,
206 const Vector<double>& potential_coarse,
207 double regularization_param,
208 int level,
209 const std::vector<std::vector<std::vector<size_t>>>& child_indices)
210{
211 // Store computation parameters
212 current_target_points_fine = &target_points_fine;
213 current_target_density_fine = &target_density_fine;
214 current_target_points_coarse = &target_points_coarse;
215 current_target_density_coarse = &target_density_coarse;
216 current_potential_coarse = &potential_coarse;
217 current_child_indices = &child_indices;
218 current_level = level;
219 current_lambda = regularization_param;
220
221 // Initialize RTree for spatial queries
222 setup_rtree();
223
224 // Initialize output potential
225 Vector<double> potential_fine(target_points_fine.size());
226 Vector<double> local_process_potential(target_points_fine.size());
227
228 // Create appropriate quadrature
229 std::unique_ptr<Quadrature<dim>> quadrature;
230 const bool use_simplex = (dynamic_cast<const FE_SimplexP<dim, spacedim>*>(&fe) != nullptr);
231 if (use_simplex) {
232 quadrature = std::make_unique<QGaussSimplex<dim>>(quadrature_order);
233 } else {
234 quadrature = std::make_unique<QGauss<dim>>(quadrature_order);
235 }
236
237 // Create scratch and copy data objects
238 ScratchData scratch_data(fe, mapping, *quadrature);
239 CopyData copy_data(target_points_fine.size());
240
241 // Create filtered iterator for locally owned cells
242 FilteredIterator<typename DoFHandler<dim, spacedim>::active_cell_iterator>
243 begin_filtered(IteratorFilters::LocallyOwnedCell(),
244 dof_handler.begin_active()),
245 end_filtered(IteratorFilters::LocallyOwnedCell(),
246 dof_handler.end());
247
248 // Parallel assembly
249 WorkStream::run(
250 begin_filtered,
251 end_filtered,
252 [this](const typename DoFHandler<dim, spacedim>::active_cell_iterator &cell,
253 ScratchData &scratch_data,
254 CopyData &copy_data) {
255 this->local_assemble(cell, scratch_data, copy_data);
256 },
257 [&local_process_potential](const CopyData &copy_data) {
258 local_process_potential += copy_data.potential_values;
259 },
260 scratch_data,
261 copy_data);
262
263 // Sum up contributions across all MPI processes
264 potential_fine = 0;
265 Utilities::MPI::sum(local_process_potential, mpi_communicator, potential_fine);
266
267 // Apply epsilon scaling to potential
268 if (Utilities::MPI::this_mpi_process(mpi_communicator) == 0) {
269 for (unsigned int i = 0; i < target_points_fine.size(); ++i) {
270 if (potential_fine[i] > 0.0) {
271 potential_fine[i] = -regularization_param * std::log(potential_fine[i]);
272 }
273 }
274 }
275
276 // Broadcast final potential to all processes
277 potential_fine = Utilities::MPI::broadcast(mpi_communicator, potential_fine, 0);
278
279 return potential_fine;
280}
281
282// Explicit instantiation
283template class SoftmaxRefinement<2>;
284template class SoftmaxRefinement<3>;
285template class SoftmaxRefinement<2, 3>;
A class for refining the optimal transport potential using a softmax operation.
std::vector< std::size_t > find_nearest_target_points(const Point< spacedim > &query_point) const
Finds all coarse target points whose distance to a query point does not exceed the current distance threshold.
void setup_rtree()
Sets up the R-tree.
Vector< double > compute_refinement(const std::vector< Point< spacedim > > &target_points_fine, const Vector< double > &target_density_fine, const std::vector< Point< spacedim > > &target_points_coarse, const Vector< double > &target_density_coarse, const Vector< double > &potential_coarse, double regularization_param, int current_level, const std::vector< std::vector< std::vector< size_t > > > &child_indices)
Computes the refined potential.
SoftmaxRefinement(MPI_Comm mpi_comm, const DoFHandler< dim, spacedim > &dof_handler, const Mapping< dim, spacedim > &mapping, const FiniteElement< dim, spacedim > &fe, const LinearAlgebra::distributed::Vector< double > &source_density, unsigned int quadrature_order, double distance_threshold, bool use_log_sum_exp_trick=true)
Constructor for the SoftmaxRefinement class.
void local_assemble(const typename DoFHandler< dim, spacedim >::active_cell_iterator &cell, ScratchData &scratch_data, CopyData &copy_data)
Assembles the local contributions to the refined potential.
std::function< double(const Point< spacedim > &, const Point< spacedim > &)> distance_function
The distance function.
boost::geometry::index::rtree< IndexedPoint, RTreeParams > RTree
std::pair< Point< spacedim >, std::size_t > IndexedPoint
A struct to hold copy data for parallel assembly.
Vector< double > potential_values
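Per-cell contributions to the fine-level potential sums, one entry per fine target point; they are summed over cells and MPI processes before the logarithm is applied.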
A struct to hold scratch data for parallel assembly.
FEValues< dim, spacedim > fe_values
FEValues object for the current cell.
std::vector< double > density_values
The density values at the quadrature points of the current cell.