Tatooine
binary_operation.h
Go to the documentation of this file.
1#ifndef TATOOINE_TENSOR_OPERATIONS_BINARY_OPERATION_H
2#define TATOOINE_TENSOR_OPERATIONS_BINARY_OPERATION_H
3//==============================================================================
6//==============================================================================
7namespace tatooine {
8//==============================================================================
9template <typename F, static_tensor Lhs, static_tensor Rhs, std::size_t... Seq>
10requires(same_dimensions<Lhs, Rhs>())
11constexpr auto binary_operation(
12 F&& f, Lhs const& lhs, Rhs const& rhs,
13 std::index_sequence<Seq...> /*seq*/) {
14 using TOut =
15 std::invoke_result_t<F, tatooine::value_type<Lhs>, tatooine::value_type<Rhs>>;
16 auto constexpr rank = tensor_rank<Lhs>;
17 auto t_out = [&] {
18 if constexpr (rank == 1) {
19 return vec<TOut, Lhs::dimension(Seq)...>{};
20 } else if constexpr (rank == 2) {
21 return mat<TOut, Lhs::dimension(Seq)...>{};
22 } else {
23 return tensor<TOut, Lhs::dimension(Seq)...>{};
24 };
25 }();
26 t_out.for_indices(
27 [&](auto const... is) { t_out(is...) = f(lhs(is...), rhs(is...)); });
28 return t_out;
29}
30//------------------------------------------------------------------------------
31template <typename F, static_tensor Lhs, static_tensor Rhs>
32requires(same_dimensions<Lhs, Rhs>()) constexpr auto binary_operation(
33 F&& f, Lhs const& lhs, Rhs const& rhs) {
34 return binary_operation(std::forward<F>(f), lhs, rhs,
35 std::make_index_sequence<Lhs::rank()>{});
36}
37//==============================================================================
38} // namespace tatooine
39//==============================================================================
40#endif
Definition: algorithm.h:6
typename value_type_impl< T >::type value_type
Definition: type_traits.h:280
constexpr auto binary_operation(F &&f, Lhs const &lhs, Rhs const &rhs, std::index_sequence< Seq... >)
Definition: binary_operation.h:11
constexpr auto rank()
Definition: rank.h:10
static auto constexpr for_indices(invocable< decltype(Dims)... > auto &&f)
Definition: base_tensor.h:58
Definition: mat.h:14
Definition: tensor.h:17
Definition: vec.h:12