1#ifndef TATOOINE_TENSOR_OPERATIONS_OPERATOR_OVERLOADS_H
2#define TATOOINE_TENSOR_OPERATIONS_OPERATOR_OVERLOADS_H
9template <static_tensor Lhs, static_tensor Rhs>
10requires(same_dimensions<Lhs, Rhs>())
auto constexpr operator==(
11 Lhs
const& lhs, Rhs
const& rhs) {
14 [&](
auto const... is) {
15 if (lhs(is...) != rhs(is...)) {
20 tensor_dimensions<Lhs>);
25template <static_tensor Lhs, static_tensor Rhs>
26requires(same_dimensions<Lhs, Rhs>())
auto constexpr operator!=(
27 Lhs
const& lhs, Rhs
const& rhs) {
37 return unary_operation([scalar](
auto const& c) {
return c + scalar; }, lhs);
41template <static_mat Lhs, static_mat Rhs>
42requires(Lhs::dimension(1) == Rhs::dimension(0))
43auto constexpr operator*(Lhs
const& lhs, Rhs
const& rhs) {
46 Lhs::dimension(0), Rhs::dimension(1)>{};
47 for (std::size_t r = 0; r < Lhs::dimension(0); ++r) {
48 for (std::size_t c = 0; c < Rhs::dimension(1); ++c) {
49 for (std::size_t i = 0; i < Lhs::dimension(1); ++i) {
50 product(r, c) += lhs(r, i) * rhs(i, c);
58template <static_mat Lhs, static_vec Rhs>
59requires(Lhs::dimension(1) == Rhs::dimension(0))
60auto constexpr operator*(Lhs
const& lhs, Rhs
const& rhs) {
62 vec<common_type<typename Lhs::value_type, typename Rhs::value_type>,
64 for (std::size_t j = 0; j < Lhs::dimension(0); ++j) {
65 for (std::size_t i = 0; i < Lhs::dimension(1); ++i) {
66 product(j) += lhs(j, i) * rhs(i);
73template <static_vec Lhs, static_mat Rhs>
74requires(Lhs::dimension(0) == Rhs::dimension(0))
75auto constexpr operator*(Lhs
const& lhs, Rhs
const& rhs) {
80 for (std::size_t j = 0; j < Rhs::dimension(0); ++j) {
81 for (std::size_t i = 0; i < Rhs::dimension(1); ++i) {
82 product += lhs(i) * rhs(i, j);
89template <static_tensor Lhs, static_tensor Rhs>
90requires(same_dimensions<Lhs, Rhs>() && Lhs::rank() != 2 && Rhs::rank() != 2)
91auto constexpr operator*(Lhs
const& lhs, Rhs
const& rhs) {
93 [](
auto&& l,
auto&& r) {
95 common_type<std::decay_t<
decltype(l)>, std::decay_t<
decltype(r)>>;
96 return static_cast<out_type
>(l) *
static_cast<out_type
>(r);
102template <static_tensor Lhs, static_tensor Rhs>
103requires(same_dimensions<Lhs, Rhs>())
auto constexpr operator/(Lhs
const& lhs,
106 [](
auto&& l,
auto&& r) {
108 common_type<std::decay_t<
decltype(l)>, std::decay_t<
decltype(r)>>;
109 return static_cast<out_type
>(l) /
static_cast<out_type
>(r);
115template <static_tensor Lhs, static_tensor Rhs>
116requires(same_dimensions<Lhs, Rhs>())
auto constexpr operator+(Lhs
const& lhs,
119 [](
auto&& l,
auto&& r) {
121 common_type<std::decay_t<
decltype(l)>, std::decay_t<
decltype(r)>>;
122 return static_cast<out_type
>(l) +
static_cast<out_type
>(r);
128template <static_tensor Lhs, static_tensor Rhs>
129requires(same_dimensions<Lhs, Rhs>())
auto constexpr operator-(Lhs
const& lhs,
132 [](
auto&& l,
auto&& r) {
134 common_type<std::decay_t<
decltype(l)>, std::decay_t<
decltype(r)>>;
135 return static_cast<out_type
>(l) -
static_cast<out_type
>(r);
143 [scalar](
auto const& component) {
return component * scalar; }, t);
149 [scalar](
auto const& component) {
return component * scalar; }, t);
154 return t * (1 / scalar);
160 [scalar](
auto const& component) {
return scalar / component; }, t);
163template <dynamic_tensor Lhs, dynamic_tensor Rhs>
166 std::common_type_t<typename Lhs::value_type, typename Rhs::value_type>>;
169 if (lhs.rank() == 2 && rhs.rank() == 2 &&
170 lhs.dimension(1) == rhs.dimension(0)) {
171 auto out = out_t::zeros(lhs.dimension(0), rhs.dimension(1));
172 for (std::size_t r = 0; r < lhs.dimension(0); ++r) {
173 for (std::size_t c = 0; c < rhs.dimension(1); ++c) {
174 for (std::size_t i = 0; i < lhs.dimension(1); ++i) {
175 out(r, c) += lhs(r, i) * rhs(i, c);
182 else if (lhs.rank() == 2 && rhs.rank() == 1 &&
183 lhs.dimension(1) == rhs.dimension(0)) {
184 auto out = out_t::zeros(lhs.dimension(0));
185 for (std::size_t r = 0; r < lhs.dimension(0); ++r) {
186 for (std::size_t i = 0; i < lhs.dimension(1); ++i) {
187 out(r) += lhs(r, i) * rhs(i);
194 A <<
"[ " << lhs.dimension(0);
195 for (std::size_t i = 1; i < lhs.rank(); ++i) {
196 A <<
" x " << lhs.dimension(i);
200 B <<
"[ " << rhs.dimension(0);
201 for (std::size_t i = 1; i < rhs.rank(); ++i) {
202 B <<
" x " << rhs.dimension(i);
205 throw std::runtime_error{
"Cannot contract given dynamic tensors. (A:" +
206 A.str() +
"; B" + B.str() +
")"};
Definition: concepts.h:36
Definition: tensor_concepts.h:20
Definition: algorithm.h:6
typename common_type_impl< Ts... >::type common_type
Definition: common_type.h:23
auto constexpr operator/(Lhs const &lhs, Rhs const &rhs)
component-wise division
Definition: operator_overloads.h:103
auto for_loop_unpacked(Iteration &&iteration, execution_policy_tag auto policy, std::array< Int, N > const &sizes)
Definition: for_loop.h:480
constexpr auto operator*(diag_static_tensor< TensorA, M, N > const &A, static_vec auto const &b) -> vec< common_type< tatooine::value_type< TensorA >, tatooine::value_type< decltype(b)> >, M > requires(N==decltype(b)::dimension(0))
Definition: diag_tensor.h:141
auto constexpr operator-(static_tensor auto const &t)
Definition: operator_overloads.h:31
constexpr auto unary_operation(invocable< tatooine::value_type< Tensor > > auto &&op, Tensor const &t)
Definition: unary_operation.h:9
auto constexpr operator+(static_tensor auto const &lhs, arithmetic_or_complex auto const scalar)
Definition: operator_overloads.h:35
constexpr auto binary_operation(F &&f, Lhs const &lhs, Rhs const &rhs, std::index_sequence< Seq... >)
Definition: binary_operation.h:11