1#ifndef TATOOINE_DIAG_TENSOR_H
2#define TATOOINE_DIAG_TENSOR_H
10template <static_vec Tensor, std::
size_t M, std::
size_t N>
13 static auto constexpr is_diag() {
return true; }
15 static auto constexpr rank() {
return 2; }
16 static auto constexpr dimensions() {
return std::array{M, N}; }
17 static auto constexpr dimension(std::size_t
const i) {
40 if constexpr (
sizeof...(is) == 2) {
41 auto i = std::array{
static_cast<std::size_t
>(is)...};
58 assert(is.size() == 2);
59 return at(is[0], is[1]);
63 assert(is.size() == 2);
64 return at(is[0], is[1]);
73template <static_vec Tensor>
78template <static_vec Tensor>
82template <static_vec Tensor>
87template <arithmetic_or_complex Real, std::
size_t N>
96template <std::
size_t M, std::
size_t N>
98 if constexpr (std::is_rvalue_reference_v<
decltype(t)>) {
100 std::forward<decltype(t)>(t)};
108template <
typename Tensor, std::
size_t N>
112 for (std::size_t i = 0; i < N; ++i) {
113 if (std::abs(A.internal_tensor()(i)) < 1e-10) {
122template <
typename TensorA, static_vec TensorB, std::
size_t N>
123requires(tensor_dimensions<TensorB>[0] == N)
131template <
typename TensorA, static_vec TensorB, std::
size_t N>
132requires(tensor_dimensions<TensorB>[0] == N)
140template <
typename TensorA, std::
size_t M, std::
size_t N>
146requires(N ==
decltype(b)::dimension(0)) {
150 for (std::size_t i = 0; i < N; ++i) {
164template <
typename TensorA, std::
size_t M, std::
size_t N>
168 std::decay_t<
decltype(B)>::dimension(
172 M,
decltype(B)::dimension(1)>;
174 for (std::size_t i = 0; i < M; ++i) {
180template <
typename TensorA, std::
size_t M, std::
size_t N>
184 A)
requires(std::decay_t<
decltype(B)>::dimension(1) == M) {
187 std::decay_t<
decltype(B)>::dimension(0), N>{B};
188 for (std::size_t i = 0; i < N; ++i) {
194template <
typename TensorA, static_mat TensorB, std::
size_t N>
195requires(tensor_dimensions<TensorB>[0] == N)
196constexpr auto solve(diag_static_tensor<TensorA, N, N>&& A, TensorB&& B)
203 for (std::size_t i = 0; i < N; ++i) {
209template <
typename TensorA, static_mat TensorB, std::
size_t N>
210requires(tensor_dimensions<TensorB>[0] == N)
211constexpr auto solve(diag_static_tensor<TensorA, N, N>
const& A, TensorB&& B)
218 for (std::size_t i = 0; i < N; ++i) {
226template <dynamic_tensor Tensor>
230 static auto constexpr is_diag() {
return true; }
238 static auto constexpr rank() {
return 2; }
240 return std::vector{internal_tensor().dimension(0),
241 internal_tensor().dimension(0)};
244 return internal_tensor().dimension(i);
248 if constexpr (
sizeof...(is) == 2) {
249 auto i = std::array{is...};
251 return internal_tensor()(i[0]);
262 assert(is.size() == 2);
263 return at(is[0], is[1]);
267 assert(is.size() == 2);
268 return at(is[0], is[1]);
274template <dynamic_tensor Tensor>
277template <dynamic_tensor Tensor>
280template <dynamic_tensor Tensor>
284 assert(A.
rank() == 1);
288template <dynamic_tensor Lhs, dynamic_tensor Rhs>
289requires diag_tensor<Lhs>
296 if (lhs.rank() == 2 && rhs.rank() == 2 &&
297 lhs.internal_tensor().dimension(0) == rhs.dimension(0)) {
299 out_t::zeros(lhs.internal_tensor().dimension(0), rhs.dimension(1));
300 for (std::size_t r = 0; r < lhs.internal_tensor().dimension(0); ++r) {
301 for (std::size_t c = 0; c < rhs.dimension(1); ++c) {
302 out(r, c) = lhs.internal_tensor()(r) * rhs(r, c);
308 }
else if (lhs.rank() == 2 && rhs.rank() == 1 &&
309 lhs.dimension(1) == rhs.dimension(0)) {
310 auto out = out_t::zeros(lhs.dimension(0));
311 for (std::size_t i = 0; i < rhs.dimension(1); ++i) {
312 out(i) = lhs.internal_tensor()(i) * rhs(i);
319 for (std::size_t i = 1; i < lhs.rank(); ++i) {
320 A <<
" x " << lhs.dimension(i);
324 B <<
"[ " << rhs.dimension(0);
325 for (std::size_t i = 1; i < rhs.rank(); ++i) {
326 B <<
" x " << rhs.dimension(i);
329 throw std::runtime_error{
"Cannot contract given dynamic tensors. (A:" +
330 A.str() +
"; B" + B.str() +
")"};
Definition: tensor_concepts.h:17
Definition: concepts.h:91
Definition: concepts.h:21
Definition: tensor_concepts.h:26
Definition: tensor_concepts.h:20
Definition: tensor_concepts.h:23
Definition: algorithm.h:6
typename common_type_impl< Ts... >::type common_type
Definition: common_type.h:23
typename value_type_impl< T >::type value_type
Definition: type_traits.h:280
tensor< real_number, Dimensions... > Tensor
Definition: tensor.h:184
constexpr auto inv(diag_static_tensor< Tensor, N, N > const &A) -> std::optional< diag_static_tensor< vec< tatooine::value_type< Tensor >, N >, N, N > >
Definition: diag_tensor.h:109
constexpr auto operator*(diag_static_tensor< TensorA, M, N > const &A, static_vec auto const &b) -> vec< common_type< tatooine::value_type< TensorA >, tatooine::value_type< decltype(b)> >, M > requires(N==decltype(b)::dimension(0))
Definition: diag_tensor.h:141
auto solve(polynomial< Real, 1 > const &p) -> std::vector< Real >
solve a + b*x
Definition: polynomial.h:187
constexpr auto diag_rect(static_vec auto &&t)
Definition: diag_tensor.h:97
static constexpr forward_tag forward
Definition: tags.h:9
Definition: diag_tensor.h:227
constexpr auto operator()(integral_range auto const &is) const
Definition: diag_tensor.h:266
auto internal_tensor() -> auto &
Definition: diag_tensor.h:236
static auto constexpr rank()
Definition: diag_tensor.h:238
auto at(integral auto const ... is) const -> value_type
Definition: diag_tensor.h:247
static auto constexpr is_dynamic()
Definition: diag_tensor.h:231
auto operator()(integral auto const ... is) const
Definition: diag_tensor.h:259
static auto constexpr is_tensor()
Definition: diag_tensor.h:229
auto internal_tensor() const -> auto const &
Definition: diag_tensor.h:235
auto dimension(std::size_t const i) const
Definition: diag_tensor.h:243
constexpr auto at(integral_range auto const &is) const -> value_type
Definition: diag_tensor.h:261
static auto constexpr is_diag()
Definition: diag_tensor.h:230
auto dimensions() const
Definition: diag_tensor.h:239
Tensor m_internal_tensor
Definition: diag_tensor.h:233
tatooine::value_type< Tensor > value_type
Definition: diag_tensor.h:228
Definition: diag_tensor.h:11
tatooine::value_type< Tensor > value_type
Definition: diag_tensor.h:29
static auto constexpr is_static()
Definition: diag_tensor.h:14
tensor_type m_internal_tensor
Definition: diag_tensor.h:32
static auto constexpr is_diag()
Definition: diag_tensor.h:13
static auto constexpr rank()
Definition: diag_tensor.h:15
constexpr auto operator()(integral auto const ... is) const
Definition: diag_tensor.h:53
static auto constexpr dimensions()
Definition: diag_tensor.h:16
auto internal_tensor() -> auto &
Definition: diag_tensor.h:68
constexpr auto operator()(integral_range auto const &is) const
Definition: diag_tensor.h:62
constexpr auto at(integral auto const ... is) const -> value_type
Definition: diag_tensor.h:39
auto internal_tensor() const -> const auto &
Definition: diag_tensor.h:67
static auto constexpr is_tensor()
Definition: diag_tensor.h:12
static auto constexpr dimension(std::size_t const i)
Definition: diag_tensor.h:17
constexpr diag_static_tensor(static_vec auto &&v)
Definition: diag_tensor.h:36
constexpr auto at(integral_range auto const &is) const -> value_type
Definition: diag_tensor.h:57
auto constexpr dimension() const
Definition: contracted_dynamic_tensor.h:36
auto constexpr col(std::size_t i)
Definition: mat.h:175