SemiDiscreteOT 1.0
Semi-Discrete Optimal Transport Library
Loading...
Searching...
No Matches
OptimalTransportPlan.h
Go to the documentation of this file.
1#ifndef OPTIMAL_TRANSPORT_PLAN_H
2#define OPTIMAL_TRANSPORT_PLAN_H
3
#include <algorithm>
#include <cmath>
#include <filesystem>
#include <fstream>
#include <functional>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>
13
14#include <deal.II/base/point.h>
15#include <deal.II/base/quadrature_lib.h>
16#include <deal.II/base/utilities.h>
17#include <deal.II/base/function.h>
18#include <deal.II/base/parameter_acceptor.h>
19#include <deal.II/lac/vector.h>
20#include <deal.II/grid/tria.h>
21#include <deal.II/fe/fe_values.h>
22#include <deal.II/numerics/rtree.h>
23#include <deal.II/numerics/vector_tools.h>
24#include <deal.II/numerics/data_out.h>
25
28
29
// NOTE(review): a using-directive in a header is visible to every file that
// includes it, pulling all of dealii into their lookup. It is kept here because
// the declarations below use unqualified Point, Vector and ParameterAcceptor.
30using namespace dealii;
31
33
34// Forward declarations
// MapApproximationStrategy is defined further down in this header, but
// OptimalTransportPlan (declared next) already names it as the element type of
// its std::unique_ptr strategy member and factory return type.
35template <int spacedim> class MapApproximationStrategy;
36
45template <int spacedim>
46class OptimalTransportPlan : public ParameterAcceptor {
47public:
52 OptimalTransportPlan(const std::string& strategy_name = "modal");
53
61 void set_distance_function(const std::function<double(const Point<spacedim>&, const Point<spacedim>&)>& dist)
62 {
63 distance_function = dist;
64 }
65
71 void set_source_measure(const std::vector<Point<spacedim>>& points,
72 const std::vector<double>& density);
73
79 void set_target_measure(const std::vector<Point<spacedim>>& points,
80 const std::vector<double>& density);
81
87 void set_potential(const Vector<double>& potential,
88 const double regularization_param = 0.0);
89
96 void set_truncation_radius(double radius);
97
101 void compute_map();
102
107 void save_map(const std::string& output_dir) const;
108
113 void set_strategy(const std::string& strategy_name);
114
119 static std::vector<std::string> get_available_strategies();
120
121private:
122 // Data members
123 std::vector<Point<spacedim>> source_points;
124 std::vector<double> source_density;
125 std::vector<Point<spacedim>> target_points;
126 std::vector<double> target_density;
127 Vector<double> transport_potential;
128 double epsilon;
129 double truncation_radius = -1.0; // Negative means no truncation
130
131 // Distance function
132 std::function<double(const Point<spacedim>&, const Point<spacedim>&)> distance_function;
133
134 // Strategy pattern implementation
135 std::unique_ptr<MapApproximationStrategy<spacedim>> strategy;
136
137 // Factory method to create strategies
138 static std::unique_ptr<MapApproximationStrategy<spacedim>>
139 create_strategy(const std::string& name);
140};
141
147template <int spacedim>
149public:
150 virtual ~MapApproximationStrategy() = default;
151
163 virtual void compute_map(
164 const std::function<double(const Point<spacedim>&, const Point<spacedim>&)> distance_function,
165 const std::vector<Point<spacedim>>& source_points,
166 const std::vector<double>& source_density,
167 const std::vector<Point<spacedim>>& target_points,
168 const std::vector<double>& target_density,
169 const Vector<double>& potential,
170 const double regularization_param,
171 const double truncation_radius) = 0;
172
177 virtual void save_results(const std::string& output_dir) const = 0;
178
179protected:
180 std::vector<Point<spacedim>> source_points;
181 std::vector<Point<spacedim>> mapped_points;
182 std::vector<double> transport_density;
183};
184
193 template <int spacedim>
194 class ModalStrategy : public MapApproximationStrategy<spacedim> {
195 public:
207 void compute_map(
208 const std::function<double(const Point<spacedim>&, const Point<spacedim>&)> distance_function,
209 const std::vector<Point<spacedim>>& source_points,
210 const std::vector<double>& source_density,
211 const std::vector<Point<spacedim>>& target_points,
212 const std::vector<double>& target_density,
213 const Vector<double>& potential,
214 const double regularization_param,
215 const double truncation_radius) override;
216
221 void save_results(const std::string& output_dir) const override;
222 };
223
// Barycentric interpolation strategy for map approximation (per the generated
// member summary for this class).
// NOTE(review): the class-head line (source line 233) is missing from this
// listing, so neither the class name nor its base-class clause is visible.
// The `override` signatures match MapApproximationStrategy<spacedim>, so it
// presumably derives from it. Restore the head from the original header.
232template <int spacedim>
234public:
// Computes the transport map (same parameters as
// MapApproximationStrategy::compute_map).
246 void compute_map(
247 const std::function<double(const Point<spacedim>&, const Point<spacedim>&)> distance_function,
248 const std::vector<Point<spacedim>>& source_points,
249 const std::vector<double>& source_density,
250 const std::vector<Point<spacedim>>& target_points,
251 const std::vector<double>& target_density,
252 const Vector<double>& potential,
253 const double regularization_param,
254 const double truncation_radius) override;
255
// Saves the results to a file under output_dir.
260 void save_results(const std::string& output_dir) const override;
261};
262
263} // namespace OptimalTransportPlanSpace
264
265#endif // OPTIMAL_TRANSPORT_PLAN_H
A collection of distance functions, their gradients, and exponential maps.
Barycentric interpolation strategy for map approximation.
void compute_map(const std::function< double(const Point< spacedim > &, const Point< spacedim > &)> distance_function, const std::vector< Point< spacedim > > &source_points, const std::vector< double > &source_density, const std::vector< Point< spacedim > > &target_points, const std::vector< double > &target_density, const Vector< double > &potential, const double regularization_param, const double truncation_radius) override
Computes the transport map.
void save_results(const std::string &output_dir) const override
Saves the results to a file.
Abstract base class for map approximation strategies.
virtual void save_results(const std::string &output_dir) const =0
Saves the results to a file.
std::vector< Point< spacedim > > source_points
The source points.
std::vector< Point< spacedim > > mapped_points
The mapped points.
virtual void compute_map(const std::function< double(const Point< spacedim > &, const Point< spacedim > &)> distance_function, const std::vector< Point< spacedim > > &source_points, const std::vector< double > &source_density, const std::vector< Point< spacedim > > &target_points, const std::vector< double > &target_density, const Vector< double > &potential, const double regularization_param, const double truncation_radius)=0
Computes the transport map.
std::vector< double > transport_density
The transported density.
Modal strategy for map approximation.
void compute_map(const std::function< double(const Point< spacedim > &, const Point< spacedim > &)> distance_function, const std::vector< Point< spacedim > > &source_points, const std::vector< double > &source_density, const std::vector< Point< spacedim > > &target_points, const std::vector< double > &target_density, const Vector< double > &potential, const double regularization_param, const double truncation_radius) override
Computes the transport map.
void save_results(const std::string &output_dir) const override
Saves the results to a file.
A class for computing and managing optimal transport map approximations.
void compute_map()
Compute the optimal transport map approximation using the current strategy.
static std::unique_ptr< MapApproximationStrategy< spacedim > > create_strategy(const std::string &name)
static std::vector< std::string > get_available_strategies()
Get available strategy names.
void set_strategy(const std::string &strategy_name)
Change the approximation strategy.
void set_distance_function(const std::function< double(const Point< spacedim > &, const Point< spacedim > &)> &dist)
Set the distance function used to compute distances between points. This function accepts a callable ...
std::function< double(const Point< spacedim > &, const Point< spacedim > &)> distance_function
void set_target_measure(const std::vector< Point< spacedim > > &points, const std::vector< double > &density)
Set the target measure data.
void set_source_measure(const std::vector< Point< spacedim > > &points, const std::vector< double > &density)
Set the source measure data.
void set_truncation_radius(double radius)
Set the truncation radius for map computation. Points outside this radius will not be considered in t...
std::unique_ptr< MapApproximationStrategy< spacedim > > strategy
void set_potential(const Vector< double > &potential, const double regularization_param=0.0)
Set the optimal transport potential.
void save_map(const std::string &output_dir) const
Save the computed transport map to files.