Tatooine
operator_overloads.h
Go to the documentation of this file.
1#ifndef TATOOINE_TENSOR_OPERATIONS_OPERATOR_OVERLOADS_H
2#define TATOOINE_TENSOR_OPERATIONS_OPERATOR_OVERLOADS_H
3//==============================================================================
6//==============================================================================
7namespace tatooine {
8//==============================================================================
// Exact component-wise equality of two static tensors whose dimensions match
// (enforced by the same_dimensions<Lhs, Rhs>() constraint). Components are
// compared with `!=`, so floating-point components are compared exactly,
// without any tolerance. Returns true iff no compared pair differs.
// NOTE(review): source line 13 -- the call that receives the lambda below
// together with tensor_dimensions<Lhs> (presumably for_loop_unpacked) -- is
// missing from this listing; confirm against the actual header.
9template <static_tensor Lhs, static_tensor Rhs>
10requires(same_dimensions<Lhs, Rhs>()) auto constexpr operator==(
11 Lhs const& lhs, Rhs const& rhs) {
12 bool equal = true;
14 [&](auto const... is) {
15 if (lhs(is...) != rhs(is...)) {
// NOTE(review): this `return` only leaves the current lambda invocation; the
// lambda returns nothing, so nothing visible here ends the iteration early --
// the remaining indices are still visited after the first mismatch.
16 equal = false;
17 return;
18 }
19 },
20 tensor_dimensions<Lhs>);
21
22 return equal;
23}
24//------------------------------------------------------------------------------
25template <static_tensor Lhs, static_tensor Rhs>
26requires(same_dimensions<Lhs, Rhs>()) auto constexpr operator!=(
27 Lhs const& lhs, Rhs const& rhs) {
28 return !(lhs == rhs);
29}
30//------------------------------------------------------------------------------
31auto constexpr operator-(static_tensor auto const& t) {
32 return unary_operation([](auto const& c) { return -c; }, t);
33}
34//------------------------------------------------------------------------------
35auto constexpr operator+(static_tensor auto const& lhs,
36 arithmetic_or_complex auto const scalar) {
37 return unary_operation([scalar](auto const& c) { return c + scalar; }, lhs);
38}
39//------------------------------------------------------------------------------
// Matrix-matrix product of two static matrices. The constraint requires the
// column count of lhs to equal the row count of rhs; the result has
// Lhs::dimension(0) rows and Rhs::dimension(1) columns.
// NOTE(review): source line 45 (the element type of `product`) is missing from
// this listing. The accumulation below relies on `{}` zero-initializing every
// entry -- confirm against the mat type.
41template <static_mat Lhs, static_mat Rhs>
42requires(Lhs::dimension(1) == Rhs::dimension(0))
43auto constexpr operator*(Lhs const& lhs, Rhs const& rhs) {
44 auto product =
46 Lhs::dimension(0), Rhs::dimension(1)>{};
// product(r, c) = sum over i of lhs(r, i) * rhs(i, c)
47 for (std::size_t r = 0; r < Lhs::dimension(0); ++r) {
48 for (std::size_t c = 0; c < Rhs::dimension(1); ++c) {
49 for (std::size_t i = 0; i < Lhs::dimension(1); ++i) {
50 product(r, c) += lhs(r, i) * rhs(i, c);
51 }
52 }
53 }
54 return product;
55}
56//------------------------------------------------------------------------------
58template <static_mat Lhs, static_vec Rhs>
59requires(Lhs::dimension(1) == Rhs::dimension(0))
60auto constexpr operator*(Lhs const& lhs, Rhs const& rhs) {
61 auto product =
62 vec<common_type<typename Lhs::value_type, typename Rhs::value_type>,
63 Lhs::dimension(0)>{};
64 for (std::size_t j = 0; j < Lhs::dimension(0); ++j) {
65 for (std::size_t i = 0; i < Lhs::dimension(1); ++i) {
66 product(j) += lhs(j, i) * rhs(i);
67 }
68 }
69 return product;
70}
71//------------------------------------------------------------------------------
73template <static_vec Lhs, static_mat Rhs>
74requires(Lhs::dimension(0) == Rhs::dimension(0))
75auto constexpr operator*(Lhs const& lhs, Rhs const& rhs) {
76 auto product =
79 Rhs::dimension(1)>{};
80 for (std::size_t j = 0; j < Rhs::dimension(0); ++j) {
81 for (std::size_t i = 0; i < Rhs::dimension(1); ++i) {
82 product += lhs(i) * rhs(i, j);
83 }
84 }
85 return product;
86}
87//------------------------------------------------------------------------------
89template <static_tensor Lhs, static_tensor Rhs>
90requires(same_dimensions<Lhs, Rhs>() && Lhs::rank() != 2 && Rhs::rank() != 2)
91auto constexpr operator*(Lhs const& lhs, Rhs const& rhs) {
92 return binary_operation(
93 [](auto&& l, auto&& r) {
94 using out_type =
95 common_type<std::decay_t<decltype(l)>, std::decay_t<decltype(r)>>;
96 return static_cast<out_type>(l) * static_cast<out_type>(r);
97 },
98 lhs, rhs);
99}
100//------------------------------------------------------------------------------
102template <static_tensor Lhs, static_tensor Rhs>
103requires(same_dimensions<Lhs, Rhs>()) auto constexpr operator/(Lhs const& lhs,
104 Rhs const& rhs) {
105 return binary_operation(
106 [](auto&& l, auto&& r) {
107 using out_type =
108 common_type<std::decay_t<decltype(l)>, std::decay_t<decltype(r)>>;
109 return static_cast<out_type>(l) / static_cast<out_type>(r);
110 },
111 lhs, rhs);
112}
113//------------------------------------------------------------------------------
115template <static_tensor Lhs, static_tensor Rhs>
116requires(same_dimensions<Lhs, Rhs>()) auto constexpr operator+(Lhs const& lhs,
117 Rhs const& rhs) {
118 return binary_operation(
119 [](auto&& l, auto&& r) {
120 using out_type =
121 common_type<std::decay_t<decltype(l)>, std::decay_t<decltype(r)>>;
122 return static_cast<out_type>(l) + static_cast<out_type>(r);
123 },
124 lhs, rhs);
125}
126//------------------------------------------------------------------------------
128template <static_tensor Lhs, static_tensor Rhs>
129requires(same_dimensions<Lhs, Rhs>()) auto constexpr operator-(Lhs const& lhs,
130 Rhs const& rhs) {
131 return binary_operation(
132 [](auto&& l, auto&& r) {
133 using out_type =
134 common_type<std::decay_t<decltype(l)>, std::decay_t<decltype(r)>>;
135 return static_cast<out_type>(l) - static_cast<out_type>(r);
136 },
137 lhs, rhs);
138}
139//------------------------------------------------------------------------------
140auto constexpr operator*(static_tensor auto const& t,
141 arithmetic_or_complex auto const scalar) {
142 return unary_operation(
143 [scalar](auto const& component) { return component * scalar; }, t);
144}
145// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
146auto constexpr operator*(arithmetic_or_complex auto const scalar,
147 static_tensor auto const& t) {
148 return unary_operation(
149 [scalar](auto const& component) { return component * scalar; }, t);
150}
151//------------------------------------------------------------------------------
152auto constexpr operator/(static_tensor auto const& t,
153 arithmetic_or_complex auto const scalar) {
154 return t * (1 / scalar);
155}
156// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
157auto constexpr operator/(arithmetic_or_complex auto const scalar,
158 static_tensor auto const& t) {
159 return unary_operation(
160 [scalar](auto const& component) { return scalar / component; }, t);
161}
162//------------------------------------------------------------------------------
163template <dynamic_tensor Lhs, dynamic_tensor Rhs>
164auto operator*(Lhs const& lhs, Rhs const& rhs) {
165 using out_t = tensor<
166 std::common_type_t<typename Lhs::value_type, typename Rhs::value_type>>;
167 auto out = out_t{};
168 // matrix-matrix-multiplication
169 if (lhs.rank() == 2 && rhs.rank() == 2 &&
170 lhs.dimension(1) == rhs.dimension(0)) {
171 auto out = out_t::zeros(lhs.dimension(0), rhs.dimension(1));
172 for (std::size_t r = 0; r < lhs.dimension(0); ++r) {
173 for (std::size_t c = 0; c < rhs.dimension(1); ++c) {
174 for (std::size_t i = 0; i < lhs.dimension(1); ++i) {
175 out(r, c) += lhs(r, i) * rhs(i, c);
176 }
177 }
178 }
179 return out;
180 }
181 // matrix-vector-multiplication
182 else if (lhs.rank() == 2 && rhs.rank() == 1 &&
183 lhs.dimension(1) == rhs.dimension(0)) {
184 auto out = out_t::zeros(lhs.dimension(0));
185 for (std::size_t r = 0; r < lhs.dimension(0); ++r) {
186 for (std::size_t i = 0; i < lhs.dimension(1); ++i) {
187 out(r) += lhs(r, i) * rhs(i);
188 }
189 }
190 return out;
191 }
192
193 std::stringstream A;
194 A << "[ " << lhs.dimension(0);
195 for (std::size_t i = 1; i < lhs.rank(); ++i) {
196 A << " x " << lhs.dimension(i);
197 }
198 A << " ]";
199 std::stringstream B;
200 B << "[ " << rhs.dimension(0);
201 for (std::size_t i = 1; i < rhs.rank(); ++i) {
202 B << " x " << rhs.dimension(i);
203 }
204 B << " ]";
205 throw std::runtime_error{"Cannot contract given dynamic tensors. (A:" +
206 A.str() + "; B" + B.str() + ")"};
207}
208//==============================================================================
209} // namespace tatooine
210//==============================================================================
211#endif
Definition: concepts.h:36
Definition: tensor_concepts.h:20
Definition: algorithm.h:6
typename common_type_impl< Ts... >::type common_type
Definition: common_type.h:23
auto constexpr operator/(Lhs const &lhs, Rhs const &rhs)
component-wise division
Definition: operator_overloads.h:103
auto for_loop_unpacked(Iteration &&iteration, execution_policy_tag auto policy, std::array< Int, N > const &sizes)
Definition: for_loop.h:480
constexpr auto operator*(diag_static_tensor< TensorA, M, N > const &A, static_vec auto const &b) -> vec< common_type< tatooine::value_type< TensorA >, tatooine::value_type< decltype(b)> >, M > requires(N==decltype(b)::dimension(0))
Definition: diag_tensor.h:141
auto constexpr operator-(static_tensor auto const &t)
Definition: operator_overloads.h:31
constexpr auto unary_operation(invocable< tatooine::value_type< Tensor > > auto &&op, Tensor const &t)
Definition: unary_operation.h:9
auto constexpr operator+(static_tensor auto const &lhs, arithmetic_or_complex auto const scalar)
Definition: operator_overloads.h:35
constexpr auto binary_operation(F &&f, Lhs const &lhs, Rhs const &rhs, std::index_sequence< Seq... >)
Definition: binary_operation.h:11
Definition: mat.h:14
Definition: tensor.h:17
Definition: vec.h:12