Loading [MathJax]/extensions/tex2jax.js
Tatooine
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Modules Pages Concepts
solver.h
Go to the documentation of this file.
1#ifndef TATOOINE_ODE_BOOST_SOLVER_H
2#define TATOOINE_ODE_BOOST_SOLVER_H
3//==============================================================================
5#include <tatooine/tensor.h>
6
7#include <boost/numeric/odeint.hpp>
8//==============================================================================
10//==============================================================================
// Specialization of Boost.Odeint's is_resizeable trait for tatooine's
// vec<Real, N>, whose dimension N is a template parameter.
// Reports the type as not resizeable: `type` is boost::false_type and
// `value` mirrors type::value (false).
11template <typename Real, size_t N>
12struct is_resizeable<tatooine::vec<Real, N>> {
13 using type = boost::false_type;
14 static const bool value = type::value;
15};
16//==============================================================================
17} // namespace boost::numeric::odeint
18//==============================================================================
19namespace tatooine::ode::boost {
20//==============================================================================
// Adapter that solves a tatooine ODE with Boost.Odeint's integrate_adaptive.
// Derives from ode::solver<solver<Real, N, Stepper>, Real, N>, i.e. the base
// is parameterized on this very type (CRTP).
// NOTE(review): this is a Doxygen-rendered listing. Some source lines are not
// shown (e.g. the parent_type alias, the m_stepsize member, the
// StepperCallback template parameter, and the first argument passed to
// integrate_adaptive), so they are not documented here.
21template <typename Real, size_t N, typename Stepper>
22struct solver : ode::solver<solver<Real, N, Stepper>, Real, N> {
23public:
 // Position and vector types re-exported from the ode::solver base.
26 using typename parent_type::pos_type;
27 using typename parent_type::vec_t;
28
29protected:
30 //============================================================================
 // Stepper object supplied to the constructor.
 // NOTE(review): presumably handed to integrate_adaptive on a line that is
 // not shown in this listing - confirm.
31 Stepper m_stepper;
33
34private:
35 //============================================================================
37
38public:
39 //============================================================================
 // Store a copy of (or move in) the stepper, plus the step size. solve()
 // uses the step size as the magnitude of the initial dt it passes to
 // integrate_adaptive.
40 solver(const Stepper &stepper, const Real stepsize)
41 : m_stepper{stepper}, m_stepsize{stepsize} {}
42 solver(Stepper &&stepper, const Real stepsize)
43 : m_stepper{std::move(stepper)}, m_stepsize{stepsize} {}
44 //============================================================================
 // Integrate y' = evaluator(y, t) from t0 to t0 + tau, starting at y0.
 //
 // evaluator(y, t) - returns the derivative at position y and time t.
 // y0              - initial position. It is copied into a mutable pos_type
 //                   (x_copy), which is passed to integrate_adaptive as the
 //                   state.
 // t0, tau         - start time and signed integration length. Both may be
 //                   any arithmetic type and are converted to Real.
 //                   tau < 0 integrates backwards (the initial dt becomes
 //                   -stepsize). tau == 0 returns immediately without ever
 //                   calling callback.
 // callback        - wrapped in the observer passed to integrate_adaptive.
 //                   Called as callback(y, t), or as
 //                   callback(y, t, evaluator(y, t)) when it is invocable
 //                   with (pos_type, Real, vec_t).
 //
 // If a step_adjustment_error escapes integrate_adaptive, integration stops.
 // The error is thrown by the stagnation check below, or presumably by
 // odeint itself. callback is then invoked one final time with a NaN-filled
 // position, a NaN time and, in the three-argument form, a NaN-filled
 // derivative.
45 template <arithmetic Y0Real, typename Evaluator,
47 constexpr void solve(Evaluator &&evaluator, vec<Y0Real, N> const &y0,
48 arithmetic auto const t0, arithmetic auto tau,
49 StepperCallback &&callback) const {
50 using ::boost::numeric::odeint::step_adjustment_error;
 // Compile-time switch between the two-argument and three-argument callback forms.
51 constexpr auto callback_takes_derivative =
52 std::is_invocable_v<StepperCallback, pos_type, Real, vec_t>;
53
54 if (tau == 0) {
55 return;
56 }
57 auto x_copy = pos_type{y0};
58 try {
59 ::boost::numeric::odeint::integrate_adaptive(
 // System function: writes evaluator(y, t) into `sample`. It first runs a
 // stagnation check and throws step_adjustment_error on the 10th
 // consecutive call where
 // distance(prev_y, y) / length(sample) < 1e-12.
 // prev_y starts as NaN, so the first call never counts as "same".
 // NOTE(review): `tau` and `t0` are captured but unused here.
61 [&evaluator, tau, t0, num_same_in_a_row = std::size_t{},
62 prev_y = vec_t::fill(nan<Real>())](pos_type const &y, pos_type &sample,
63 Real t) mutable {
 // NOTE(review): the denominator reads `sample` BEFORE it is overwritten
 // below, i.e. the buffer's prior contents (presumably the previous
 // derivative), not the derivative at y - confirm this is intended.
 // A zero-length buffer makes the ratio inf/NaN, which never counts as
 // "same".
64 auto const delta_pos = euclidean_distance(prev_y, y);
65 auto const rel_error = delta_pos / euclidean_length(sample);
66 if (rel_error < 1e-12) {
67 ++num_same_in_a_row;
68 } else {
69 num_same_in_a_row = 0;
70 }
71 if (num_same_in_a_row == 10) {
72 throw step_adjustment_error{""};
73 }
74 prev_y = y;
75 sample = evaluator(y, t);
76 },
 // State, start time, end time, and initial dt (sign follows tau).
77 x_copy, Real(t0), Real(t0 + tau),
78 Real(tau > 0 ? m_stepsize : -m_stepsize),
 // Observer: forwards each observed (y, t) to callback, adding
 // evaluator(y, t) when callback accepts a derivative.
 // NOTE(review): `tau` and `t0` are captured but unused here.
79 [tau, t0, &callback, &evaluator](const pos_type &y, Real t) {
80 if constexpr (!callback_takes_derivative) {
81 callback(y, t);
82 } else {
83 callback(y, t, evaluator(y, t));
84 }
85 });
86 } catch (step_adjustment_error const &) {
 // Signal the aborted integration to the caller with NaN values.
 // NOTE(review): untyped nan() is used here, while the rest of this block
 // uses nan<Real>(). Its default type is not visible in this listing;
 // consider nan<Real>() for consistency.
87 if constexpr (!callback_takes_derivative) {
88 callback(pos_type::fill(nan()), nan());
89 } else {
 // Derivative type as returned by evaluator; it must provide a static fill().
90 using derivative_type = decltype(evaluator(y0, t0));
91 callback(pos_type::fill(nan()), nan(), derivative_type::fill(nan()));
92 }
93 }
94 }
95 //----------------------------------------------------------------------------
 // Mutable and read-only access to the step size used by solve().
96 auto stepsize() -> auto & { return m_stepsize; }
97 auto stepsize() const { return m_stepsize; }
98};
99
100//==============================================================================
101} // namespace tatooine::ode::boost
102//==============================================================================
103
104#endif
Definition: concepts.h:33
Definition: solver.h:9
Definition: controller_runge_kutta_with_domain_check.h:19
Definition: algorithm.h:6
constexpr auto euclidean_length(base_tensor< Tensor, T, N > const &t_in) -> T
Definition: length.h:12
auto nan(const char *arg="")
Definition: nan.h:26
constexpr auto euclidean_distance(base_tensor< Tensor0, T0, N > const &lhs, base_tensor< Tensor1, T1, N > const &rhs)
Definition: distance.h:19
Definition: solver.h:22
auto stepsize() -> auto &
Definition: solver.h:96
solver(const Stepper &stepper, const Real stepsize)
Definition: solver.h:40
auto stepsize() const
Definition: solver.h:97
Real m_stepsize
Definition: solver.h:32
solver(Stepper &&stepper, const Real stepsize)
Definition: solver.h:42
constexpr void solve(Evaluator &&evaluator, vec< Y0Real, N > const &y0, arithmetic auto const t0, arithmetic auto tau, StepperCallback &&callback) const
Definition: solver.h:47
Stepper m_stepper
Definition: solver.h:31
friend parent_type
Definition: solver.h:36
Definition: solver.h:50
vec< Real, N > vec_t
Definition: solver.h:56
vec_t pos_type
Definition: solver.h:57
Definition: tags.h:96
Definition: vec.h:12