Tatooine
diag_tensor.h
Go to the documentation of this file.
1#ifndef TATOOINE_DIAG_TENSOR_H
2#define TATOOINE_DIAG_TENSOR_H
3//==============================================================================
5//==============================================================================
6#include <tatooine/concepts.h>
7//==============================================================================
8namespace tatooine {
9//==============================================================================
10template <static_vec Tensor, std::size_t M, std::size_t N>
12 static auto constexpr is_tensor() { return true; }
13 static auto constexpr is_diag() { return true; }
14 static auto constexpr is_static() { return true; }
15 static auto constexpr rank() { return 2; }
16 static auto constexpr dimensions() { return std::array{M, N}; }
17 static auto constexpr dimension(std::size_t const i) {
18 switch (i) {
19 default:
20 case 0:
21 return M;
22 case 1:
23 return N;
24 }
25 }
26 //============================================================================
30 //============================================================================
31 private:
33
34 //============================================================================
35 public:
36 constexpr explicit diag_static_tensor(static_vec auto&& v)
37 : m_internal_tensor{std::forward<decltype(v)>(v)} {}
38 //----------------------------------------------------------------------------
39 constexpr auto at(integral auto const... is) const -> value_type {
40 if constexpr (sizeof...(is) == 2) {
41 auto i = std::array{static_cast<std::size_t>(is)...};
42 assert(i[0] < M);
43 assert(i[1] < N);
44 if (i[0] == i[1]) {
45 return internal_tensor()(i[0]);
46 }
47 return 0;
48 } else {
49 return value_type(0) / value_type(0);
50 }
51 }
52 //----------------------------------------------------------------------------
53 constexpr auto operator()(integral auto const... is) const {
54 return at(is...);
55 }
56 //----------------------------------------------------------------------------
57 constexpr auto at(integral_range auto const& is) const -> value_type {
58 assert(is.size() == 2);
59 return at(is[0], is[1]);
60 }
61 //----------------------------------------------------------------------------
62 constexpr auto operator()(integral_range auto const& is) const {
63 assert(is.size() == 2);
64 return at(is[0], is[1]);
65 }
66 //----------------------------------------------------------------------------
67 auto internal_tensor() const -> const auto& { return m_internal_tensor; }
68 auto internal_tensor() -> auto& { return m_internal_tensor; }
69};
70//==============================================================================
71// deduction guides
72//==============================================================================
73template <static_vec Tensor>
77// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
78template <static_vec Tensor>
81// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
82template <static_vec Tensor>
86//==============================================================================
87template <arithmetic_or_complex Real, std::size_t N>
88struct vec;
89//==============================================================================
90// factory functions
91//==============================================================================
92constexpr auto diag(static_vec auto&& t) {
93 return diag_static_tensor{std::forward<decltype(t)>(t)};
94}
95//------------------------------------------------------------------------------
96template <std::size_t M, std::size_t N>
97constexpr auto diag_rect(static_vec auto&& t) {
98 if constexpr (std::is_rvalue_reference_v<decltype(t)>) {
99 return diag_static_tensor<std::decay_t<decltype(t)>, M, N>{
100 std::forward<decltype(t)>(t)};
101 } else {
102 return diag_static_tensor<decltype(t), M, N>{std::forward<decltype(t)>(t)};
103 }
104}
105//==============================================================================
106// free functions
107//==============================================================================
108template <typename Tensor, std::size_t N>
109constexpr auto inv(diag_static_tensor<Tensor, N, N> const& A) -> std::optional<
112 for (std::size_t i = 0; i < N; ++i) {
113 if (std::abs(A.internal_tensor()(i)) < 1e-10) {
114 return {};
115 }
116 }
117 return diag_static_tensor{value_type(1) / A.internal_tensor()};
118}
119//------------------------------------------------------------------------------
120#include <tatooine/vec.h>
121//------------------------------------------------------------------------------
122template <typename TensorA, static_vec TensorB, std::size_t N>
123requires(tensor_dimensions<TensorB>[0] == N)
124constexpr auto solve(diag_static_tensor<TensorA, N, N> const& A, TensorB&& b)
125 -> std::optional<
127 N>> {
128 return A.internal_tensor() / b;
129}
130//------------------------------------------------------------------------------
131template <typename TensorA, static_vec TensorB, std::size_t N>
132requires(tensor_dimensions<TensorB>[0] == N)
133constexpr auto solve(diag_static_tensor<TensorA, N, N>&& A, TensorB&& b)
134 -> std::optional<
136 N>> {
137 return A.internal_tensor() / b;
138}
139//------------------------------------------------------------------------------
140template <typename TensorA, std::size_t M, std::size_t N>
142 static_vec auto const& b)
143 -> vec<
145 M>
146requires(N == decltype(b)::dimension(0)) {
148 M>
149 ret = b;
150 for (std::size_t i = 0; i < N; ++i) {
151 ret(i) *= A.internal_tensor()(i);
152 }
153 return ret;
154}
156// template <typename TensorA, typename TensorB, typename BReal, std::size_t N>
157// constexpr auto operator*(base_tensor<TensorB, BReal, N> const& b,
158// diag_static_tensor<TensorA, N, N> const& A) {
159// return A * b;
160// }
161//------------------------------------------------------------------------------
162#include <tatooine/mat.h>
163//------------------------------------------------------------------------------
164template <typename TensorA, std::size_t M, std::size_t N>
165constexpr auto operator*(
167 static_mat auto const& B) requires(N ==
168 std::decay_t<decltype(B)>::dimension(
169 0)) {
170 using mat_t = mat<
172 M, decltype(B)::dimension(1)>;
173 auto ret = mat_t{B};
174 for (std::size_t i = 0; i < M; ++i) {
175 ret.row(i) *= A.internal_tensor()(i);
176 }
177 return ret;
178}
179//------------------------------------------------------------------------------
180template <typename TensorA, std::size_t M, std::size_t N>
181constexpr auto operator*(
182 static_tensor auto const& B,
184 A) requires(std::decay_t<decltype(B)>::dimension(1) == M) {
185 auto ret = mat<
187 std::decay_t<decltype(B)>::dimension(0), N>{B};
188 for (std::size_t i = 0; i < N; ++i) {
189 ret.col(i) *= A.internal_tensor()(i);
190 }
191 return ret;
192}
193//------------------------------------------------------------------------------
194template <typename TensorA, static_mat TensorB, std::size_t N>
195requires(tensor_dimensions<TensorB>[0] == N)
196constexpr auto solve(diag_static_tensor<TensorA, N, N>&& A, TensorB&& B)
197 -> std::optional<
198 mat<common_type<tatooine::value_type<TensorA>, tatooine::value_type<TensorB>>,
199 N, N>> {
200 auto ret =
201 mat<common_type<tatooine::value_type<TensorA>, tatooine::value_type<TensorB>>,
202 N, N>{B};
203 for (std::size_t i = 0; i < N; ++i) {
204 ret.row(i) /= A.internal_tensor()(i);
205 }
206 return ret;
207}
208//------------------------------------------------------------------------------
209template <typename TensorA, static_mat TensorB, std::size_t N>
210requires(tensor_dimensions<TensorB>[0] == N)
211constexpr auto solve(diag_static_tensor<TensorA, N, N> const& A, TensorB&& B)
212 -> std::optional<
213 mat<common_type<tatooine::value_type<TensorA>, tatooine::value_type<TensorB>>,
214 N, N>> {
215 auto ret =
216 mat<common_type<tatooine::value_type<TensorA>, tatooine::value_type<TensorB>>,
217 N, N>{B};
218 for (std::size_t i = 0; i < N; ++i) {
219 ret.row(i) /= A.internal_tensor()(i);
220 }
221 return ret;
222}
223//==============================================================================
224// dynamic
225//==============================================================================
226template <dynamic_tensor Tensor>
/// Compile-time traits: a runtime-sized diagonal tensor.
static constexpr auto is_tensor() { return true; }
static constexpr auto is_diag() { return true; }
static constexpr auto is_dynamic() { return true; }
232 //============================================================================
234 //----------------------------------------------------------------------------
235 auto internal_tensor() const -> auto const& { return m_internal_tensor; }
236 auto internal_tensor() -> auto& { return m_internal_tensor; }
237 //----------------------------------------------------------------------------
238 static auto constexpr rank() { return 2; }
239 auto dimensions() const {
240 return std::vector{internal_tensor().dimension(0),
241 internal_tensor().dimension(0)};
242 }
243 auto dimension(std::size_t const i) const {
244 return internal_tensor().dimension(i);
245 }
246 //============================================================================
247 auto at(integral auto const... is) const -> value_type {
248 if constexpr (sizeof...(is) == 2) {
249 auto i = std::array{is...};
250 if (i[0] == i[1]) {
251 return internal_tensor()(i[0]);
252 }
253 return 0;
254 } else {
255 return value_type(0) / value_type(0);
256 }
257 }
258 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
259 auto operator()(integral auto const... is) const { return at(is...); }
260 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
261 constexpr auto at(integral_range auto const& is) const -> value_type {
262 assert(is.size() == 2);
263 return at(is[0], is[1]);
264 }
265 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
266 constexpr auto operator()(integral_range auto const& is) const {
267 assert(is.size() == 2);
268 return at(is[0], is[1]);
269 }
270};
271//==============================================================================
272// deduction guides
273//==============================================================================
274template <dynamic_tensor Tensor>
276// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
277template <dynamic_tensor Tensor>
279// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
280template <dynamic_tensor Tensor>
282//==============================================================================
283auto diag(dynamic_tensor auto&& A) {
284 assert(A.rank() == 1);
285 return diag_dynamic_tensor{std::forward<decltype(A)>(A)};
286}
287//------------------------------------------------------------------------------
288template <dynamic_tensor Lhs, dynamic_tensor Rhs>
289requires diag_tensor<Lhs>
290auto operator*(Lhs const& lhs, Rhs const& rhs)
292 using out_t =
294 auto out = out_t{};
295 // matrix-matrix-multiplication
296 if (lhs.rank() == 2 && rhs.rank() == 2 &&
297 lhs.internal_tensor().dimension(0) == rhs.dimension(0)) {
298 auto out =
299 out_t::zeros(lhs.internal_tensor().dimension(0), rhs.dimension(1));
300 for (std::size_t r = 0; r < lhs.internal_tensor().dimension(0); ++r) {
301 for (std::size_t c = 0; c < rhs.dimension(1); ++c) {
302 out(r, c) = lhs.internal_tensor()(r) * rhs(r, c);
303 }
304 }
305 return out;
306
307 // matrix-vector-multiplication
308 } else if (lhs.rank() == 2 && rhs.rank() == 1 &&
309 lhs.dimension(1) == rhs.dimension(0)) {
310 auto out = out_t::zeros(lhs.dimension(0));
311 for (std::size_t i = 0; i < rhs.dimension(1); ++i) {
312 out(i) = lhs.internal_tensor()(i) * rhs(i);
313 }
314 return out;
315 }
316
317 std::stringstream A;
318 A << "[ " << lhs.dimension(0);
319 for (std::size_t i = 1; i < lhs.rank(); ++i) {
320 A << " x " << lhs.dimension(i);
321 }
322 A << " ]";
323 std::stringstream B;
324 B << "[ " << rhs.dimension(0);
325 for (std::size_t i = 1; i < rhs.rank(); ++i) {
326 B << " x " << rhs.dimension(i);
327 }
328 B << " ]";
329 throw std::runtime_error{"Cannot contract given dynamic tensors. (A:" +
330 A.str() + "; B" + B.str() + ")"};
331}
332//==============================================================================
333} // namespace tatooine
334//==============================================================================
335#endif
Definition: tensor_concepts.h:17
Definition: concepts.h:91
Definition: concepts.h:21
Definition: tensor_concepts.h:26
Definition: tensor_concepts.h:20
Definition: tensor_concepts.h:23
Definition: algorithm.h:6
typename common_type_impl< Ts... >::type common_type
Definition: common_type.h:23
typename value_type_impl< T >::type value_type
Definition: type_traits.h:280
tensor< real_number, Dimensions... > Tensor
Definition: tensor.h:184
constexpr auto inv(diag_static_tensor< Tensor, N, N > const &A) -> std::optional< diag_static_tensor< vec< tatooine::value_type< Tensor >, N >, N, N > >
Definition: diag_tensor.h:109
constexpr auto operator*(diag_static_tensor< TensorA, M, N > const &A, static_vec auto const &b) -> vec< common_type< tatooine::value_type< TensorA >, tatooine::value_type< decltype(b)> >, M > requires(N==decltype(b)::dimension(0))
Definition: diag_tensor.h:141
auto solve(polynomial< Real, 1 > const &p) -> std::vector< Real >
solve a + b*x
Definition: polynomial.h:187
constexpr auto diag_rect(static_vec auto &&t)
Definition: diag_tensor.h:97
static constexpr forward_tag forward
Definition: tags.h:9
Definition: diag_tensor.h:227
constexpr auto operator()(integral_range auto const &is) const
Definition: diag_tensor.h:266
auto internal_tensor() -> auto &
Definition: diag_tensor.h:236
static auto constexpr rank()
Definition: diag_tensor.h:238
auto at(integral auto const ... is) const -> value_type
Definition: diag_tensor.h:247
static auto constexpr is_dynamic()
Definition: diag_tensor.h:231
auto operator()(integral auto const ... is) const
Definition: diag_tensor.h:259
static auto constexpr is_tensor()
Definition: diag_tensor.h:229
auto internal_tensor() const -> auto const &
Definition: diag_tensor.h:235
auto dimension(std::size_t const i) const
Definition: diag_tensor.h:243
constexpr auto at(integral_range auto const &is) const -> value_type
Definition: diag_tensor.h:261
static auto constexpr is_diag()
Definition: diag_tensor.h:230
auto dimensions() const
Definition: diag_tensor.h:239
Tensor m_internal_tensor
Definition: diag_tensor.h:233
tatooine::value_type< Tensor > value_type
Definition: diag_tensor.h:228
Definition: diag_tensor.h:11
tatooine::value_type< Tensor > value_type
Definition: diag_tensor.h:29
static auto constexpr is_static()
Definition: diag_tensor.h:14
tensor_type m_internal_tensor
Definition: diag_tensor.h:32
static auto constexpr is_diag()
Definition: diag_tensor.h:13
static auto constexpr rank()
Definition: diag_tensor.h:15
constexpr auto operator()(integral auto const ... is) const
Definition: diag_tensor.h:53
static auto constexpr dimensions()
Definition: diag_tensor.h:16
auto internal_tensor() -> auto &
Definition: diag_tensor.h:68
constexpr auto operator()(integral_range auto const &is) const
Definition: diag_tensor.h:62
constexpr auto at(integral auto const ... is) const -> value_type
Definition: diag_tensor.h:39
auto internal_tensor() const -> const auto &
Definition: diag_tensor.h:67
static auto constexpr is_tensor()
Definition: diag_tensor.h:12
static auto constexpr dimension(std::size_t const i)
Definition: diag_tensor.h:17
constexpr diag_static_tensor(static_vec auto &&v)
Definition: diag_tensor.h:36
constexpr auto at(integral_range auto const &is) const -> value_type
Definition: diag_tensor.h:57
auto constexpr dimension() const
Definition: contracted_dynamic_tensor.h:36
Definition: mat.h:14
auto constexpr col(std::size_t i)
Definition: mat.h:175
Definition: vec.h:12