Tatooine
indexed_static_tensor.h
Go to the documentation of this file.
1#ifndef TATOOINE_EINSTEIN_NOTATION_INDEXED_STATIC_TENSOR_H
2#define TATOOINE_EINSTEIN_NOTATION_INDEXED_STATIC_TENSOR_H
3//==============================================================================
6
7#if TATOOINE_BLAS_AND_LAPACK_AVAILABLE
8#include <tatooine/blas.h>
9#endif
10#include <map>
11#include <tuple>
12//==============================================================================
14//==============================================================================
/// Forward declaration. The variadic template is used further down as the
/// parameter type of the contraction-based add/assign members and to look up
/// `contracted_indices`.
15template <typename... IndexedTensors>
16struct contracted_static_tensor;
17//==============================================================================
/// Forward declaration. Instances are consumed below through
/// `other.template at<Seq>()`, i.e. as a pack of individually accessible
/// summands.
18template <typename... ContractedTensors>
19struct added_contracted_static_tensor;
20//==============================================================================
21template <static_tensor Tensor, index... Indices>
23 static_assert(std::decay_t<Tensor>::rank() == sizeof...(Indices));
24 using tensor_type = std::decay_t<Tensor>;
25 using value_type = typename tensor_type::value_type;
26
27 private:
29
30 public:
32
/// Read-only access to the wrapped tensor member `m_tensor`.
33 auto tensor() const -> auto const& { return m_tensor; }
/// Mutable access to the wrapped tensor member `m_tensor`.
34 auto tensor() -> auto& { return m_tensor; }
35
/// All index types of this indexed tensor, collected in a type_list.
36 using indices = type_list<Indices...>;
/// The I-th entry of `indices`.
37 template <std::size_t I>
38 using index_at = typename indices::template at<I>;
39
/// Convenience overload: forwards to index_map(std::index_sequence) with the
/// sequence 0, ..., rank() - 1.
40 static auto index_map() {
41 return index_map(std::make_index_sequence<rank()>{});
42 }
/// Builds a std::array that holds `Indices::get()` for every index type, in
/// the order of the `Indices` pack. The sequence argument is unnamed and
/// `Seq` is not referenced in the body.
43 template <std::size_t... Seq>
44 static auto constexpr index_map(std::index_sequence<Seq...> /*seq*/) {
45 return std::array{Indices::get()...};
46 }
/// Forwards to `tensor_type::rank()`, where tensor_type is the decayed
/// `Tensor` template argument.
47 static auto constexpr rank() { return tensor_type::rank(); }
/// Forwards to `tensor_type::size<I>()` -- presumably the extent of the I-th
/// dimension of the wrapped tensor; confirm against the tensor class.
48 template <std::size_t I>
49 static auto constexpr size() {
50 return tensor_type::template size<I>();
51 }
52
53 private:
/// Recursive helper for size<E>(): walks the index pack starting at position
/// I and returns `tensor_type::dimension(I)` for the first position whose
/// index type is the same as E.
/// NOTE(review): when E is not part of the pack the recursion ends with an
/// empty tail, for which no overload is visible in this listing -- presumably
/// a compile error; confirm.
54 template <std::size_t I, typename E, typename HeadIndex,
55 typename... TailIndices>
56 static auto constexpr size_() {
57 if constexpr (is_same<E, HeadIndex>) {
58 return tensor_type::dimension(I);
59 } else {
60 return size_<I + 1, E, TailIndices...>();
61 }
62 }
63
64 public:
/// Returns `tensor_type::dimension` at the first position where the index
/// type E occurs in `Indices`; the search is delegated to size_.
65 template <typename E>
66 static auto constexpr size() {
67 return size_<0, E, Indices...>();
68 }
69
/// Whether the index type E is one of `Indices`, as reported by
/// `type_list<Indices...>::contains<E>`.
70 template <typename E>
71 static auto constexpr contains() -> bool {
72 return type_list<Indices...>::template contains<E>;
73 }
74 //============================================================================
/// Applies `*this += other.template at<Seq>()` for every element of Seq, in
/// pack order. It does not reset `m_tensor` itself. Only available when
/// `Tensor` is not const.
/// NOTE(review): the declaration line (listing line 77) is missing from this
/// extract. The reference list at the end of the listing gives it as
/// `auto assign(added_contracted_static_tensor<ContractedTensors...> other,
/// std::index_sequence<Seq...>)` -- confirm against the original header.
75 template <typename... ContractedTensors, std::size_t... Seq>
76 requires(!is_const<std::remove_reference_t<Tensor>>)
78 std::index_sequence<Seq...> /*seq*/) {
// Fold over the comma operator: one immediately invoked lambda per summand.
79 ([&] { *this += other.template at<Seq>(); }(), ...);
80 }
81 //----------------------------------------------------------------------------
/// Evaluates the product of all factor tensors of `other` for every
/// combination of free indices and, if there are any, every combination of
/// contracted indices. Only available when `Tensor` is not const.
/// NOTE(review): several lines are missing from this extract (listing lines
/// 86, 91, 93, 95, 133 and 167). The reference list at the end of the listing
/// gives the declaration as `auto add(contracted_static_tensor<
/// IndexedTensors...> other, std::index_sequence<FreeIndexSequence...>,
/// std::index_sequence<ContractedIndexSequence...>,
/// std::index_sequence<ContractedTensorsSequence...>)`. The statements that
/// receive the two products (lines 133 and 167) are not visible; presumably
/// they accumulate into `m_tensor` at the current free indices -- confirm
/// against the original header.
82 template <typename... IndexedTensors, std::size_t... FreeIndexSequence,
83 std::size_t... ContractedIndexSequence,
84 std::size_t... ContractedTensorsSequence>
85 requires(!is_const<std::remove_reference_t<Tensor>>)
87 std::index_sequence<FreeIndexSequence...> /*seq*/,
88 std::index_sequence<ContractedIndexSequence...> /*seq*/,
89 std::index_sequence<ContractedTensorsSequence...> /*seq*/) {
90 using map_t = std::map<std::size_t, std::size_t>;
92 contracted_static_tensor<IndexedTensors...>;
94 using contracted_indices =
96
// Key: value of `get()` of a free index type; value: its position in the
// free index list.
97 auto const free_indices_map = map_t{
98 map_t::value_type{free_indices::template at<FreeIndexSequence>::get(),
99 FreeIndexSequence}...};
// Same mapping for the contracted index types.
100 auto const contracted_indices_map = map_t{map_t::value_type{
101 contracted_indices::template at<ContractedIndexSequence>::get(),
102 ContractedIndexSequence,
103 }...};
// One index_map() array per factor tensor.
104 auto const tensor_index_maps = std::tuple{IndexedTensors::index_map()...};
// One multi-index buffer per factor tensor, sized by that tensor's rank().
// The buffers are overwritten slot by slot inside the loops below.
105 auto index_arrays =
106 std::tuple{make_array<std::size_t, IndexedTensors::rank()>()...};
107
// Outer loop: bounds are tensor_type::dimension(FreeIndexSequence)...
108 for_loop(
109 [&](auto const... free_indices) {
110 // setup indices of single tensors for free indices
111 {
112 auto const free_index_array = std::array{free_indices...};
113 (
114 [&] {
115 auto& index_array =
116 std::get<ContractedTensorsSequence>(index_arrays);
117 auto const& tensor_index_map =
118 std::get<ContractedTensorsSequence>(tensor_index_maps);
119 auto index_arr_it = begin(index_array);
120 auto tensor_index_map_it = begin(tensor_index_map);
121
// Walk the factor's index identifiers and its multi-index buffer in
// lockstep; only slots that belong to a free index are written here.
122 for (; tensor_index_map_it != end(tensor_index_map);
123 ++tensor_index_map_it, ++index_arr_it) {
124 if (free_indices_map.contains(*tensor_index_map_it)) {
125 *index_arr_it = free_index_array[free_indices_map.at(
126 *tensor_index_map_it)];
127 }
128 }
129 }(),
130 ...);
131 }
// Without contracted indices a single product of the addressed factor
// elements is formed per free-index combination.
132 if constexpr (contracted_indices::empty) {
134 (other.template at<ContractedTensorsSequence>().tensor()(
135 std::get<ContractedTensorsSequence>(index_arrays)) *
136 ...);
137 } else {
// Inner loop over all contracted-index combinations.
138 for_loop(
139 [&](auto const... contracted_indices) {
140 // setup indices of single tensors for contracted indices
141 {
142 auto const contracted_index_array =
143 std::array{contracted_indices...};
144 (
145 [&] {
146 auto& index_array =
147 std::get<ContractedTensorsSequence>(index_arrays);
148 auto const& tensor_index_map =
149 std::get<ContractedTensorsSequence>(
150 tensor_index_maps);
151 auto index_arr_it = begin(index_array);
152 auto tensor_index_map_it = begin(tensor_index_map);
153
// Same lockstep walk as above, now writing the slots that belong to a
// contracted index.
154 for (; tensor_index_map_it != end(tensor_index_map);
155 ++tensor_index_map_it, ++index_arr_it) {
156 if (contracted_indices_map.contains(
157 *tensor_index_map_it)) {
158 *index_arr_it = contracted_index_array
159 [contracted_indices_map.at(
160 *tensor_index_map_it)];
161 }
162 }
163 }(),
164 ...);
165 }
166
// Product of all factor elements at the current multi-indices.
168 (other.template at<ContractedTensorsSequence>().tensor()(
169 std::get<ContractedTensorsSequence>(index_arrays)) *
170 ...);
171 },
// Inner loop bounds: one size<...>() value per contracted index type.
172 contracted_static_tensor::template size<
173 typename contracted_indices::template at<
174 ContractedIndexSequence>>()...);
175 }
176 },
177 tensor_type::dimension(FreeIndexSequence)...);
178 }
179 //----------------------------------------------------------------------------
/// Calls add() for `other` with three index sequences -- 0..rank()-1 for the
/// free indices, one entry per contracted index of
/// `contracted_static_tensor<IndexedTensors...>`, and one entry per factor
/// tensor -- and returns `*this`. `m_tensor` is not reset beforehand. Only
/// available when `Tensor` is not const.
/// NOTE(review): the declaration (listing lines 182-183) is missing from this
/// extract; presumably this is `operator+=` taking a
/// `contracted_static_tensor<IndexedTensors...>` -- confirm.
180 template <typename... IndexedTensors>
181 requires(!is_const<std::remove_reference_t<Tensor>>)
184 add(other, std::make_index_sequence<rank()>{},
185 std::make_index_sequence<contracted_static_tensor<
186 IndexedTensors...>::contracted_indices::size>{},
187 std::make_index_sequence<sizeof...(IndexedTensors)>{});
188 return *this;
189 }
190 //----------------------------------------------------------------------------
/// Overwrites `m_tensor` with `tensor_type::zeros()` and then calls add() for
/// `other` with the same three index sequences used by the accumulating
/// variant. Only available when `Tensor` is not const.
/// NOTE(review): the declaration (listing lines 192-193) is missing from this
/// extract. The reference list at the end of the listing gives it as
/// `auto assign(contracted_static_tensor<IndexedTensors...> other)`.
191 template <typename... IndexedTensors>
192 requires(!is_const<std::remove_reference_t<Tensor>>)
194 m_tensor = tensor_type::zeros();
195 add(other, std::make_index_sequence<rank()>{},
196 std::make_index_sequence<contracted_static_tensor<
197 IndexedTensors...>::contracted_indices::size>{},
198 std::make_index_sequence<sizeof...(IndexedTensors)>{});
199 }
200 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
/// Delegates to `assign(other)` and returns `*this`. Only available when
/// `Tensor` is not const.
/// NOTE(review): the declaration (listing lines 203-204) is missing from this
/// extract; presumably this is `operator=` taking a
/// `contracted_static_tensor<IndexedTensors...>` -- confirm.
201 template <typename... IndexedTensors>
202 requires(!is_const<std::remove_reference_t<Tensor>>)
205 assign(other);
206 return *this;
207 }
208 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
/// Overwrites `m_tensor` with `tensor_type::zeros()`, then calls the
/// two-argument assign() with one sequence entry per element of
/// `ContractedTensors`, and returns `*this`. Only available when `Tensor` is
/// not const.
/// NOTE(review): the declaration (listing lines 211-212) is missing from this
/// extract; presumably this is `operator=` taking an
/// `added_contracted_static_tensor<ContractedTensors...>` -- confirm.
209 template <typename... ContractedTensors>
210 requires(!is_const<std::remove_reference_t<Tensor>>)
213 m_tensor = tensor_type::zeros();
214 assign(other, std::make_index_sequence<sizeof...(ContractedTensors)>{});
215 return *this;
216 }
217 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
/// Overwrites `m_tensor` with `tensor_type::zeros()`, wraps `other` in a
/// `contracted_static_tensor`, adds it to `*this` with `+=`, and returns
/// `*this`. Only available when `Tensor` is not const.
/// NOTE(review): the declaration (listing lines 220-221) is missing from this
/// extract; presumably this is `operator=` taking an
/// `indexed_static_tensor<Tensors, Is...>` -- confirm.
218 template <typename Tensors, typename... Is>
219 requires(!is_const<std::remove_reference_t<Tensor>>)
222 m_tensor = tensor_type::zeros();
223 *this += contracted_static_tensor{other};
224 return *this;
225 }
226#if TATOOINE_BLAS_AND_LAPACK_AVAILABLE
227 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
/// Assignment from a contraction of two indexed tensors over a shared index,
/// A(i,k) = B(i,j) * C(j,k). It is only compiled when
/// TATOOINE_BLAS_AND_LAPACK_AVAILABLE is set.
/// If this tensor's own indices are exactly (I, K) in that order, the work is
/// handed to `blas::gemm` with the factors comp_type(1) and comp_type(0), the
/// two operand tensors and `m_tensor` -- presumably computing
/// m_tensor = 1 * B * C + 0 * m_tensor; confirm against tatooine/blas.h.
/// For any other index order it falls back to the generic `assign(other)`.
/// NOTE(review): the first lines of the declaration (listing lines 230-231)
/// are missing from this extract. The reference list at the end of the
/// listing gives the parameter as `contracted_static_tensor<
/// indexed_static_tensor<LHS, I, J>, indexed_static_tensor<RHS, J, K>>`.
229 template <typename LHS, typename RHS, typename I, typename J, typename K>
232 other) -> indexed_static_tensor& {
233 if constexpr (is_same<I, index_at<0>> && is_same<K, index_at<1>>) {
234 using comp_type = typename tensor_type::value_type;
// Both operands must have the same value type as this tensor on the gemm
// path.
235 static_assert(is_same<comp_type, typename std::decay_t<LHS>::value_type>);
236 static_assert(is_same<comp_type, typename std::decay_t<RHS>::value_type>);
237 blas::gemm(comp_type(1), other.template at<0>().tensor(),
238 other.template at<1>().tensor(), comp_type(0), m_tensor);
239 } else {
240 assign(other);
241 }
242 return *this;
243 }
244#endif
245};
246//==============================================================================
247} // namespace tatooine::einstein_notation
248//==============================================================================
249#endif
auto gemm(op TRANSA, op TRANSB, int M, int N, int K, Float ALPHA, Float const *A, int LDA, Float const *B, int LDB, Float BETA, Float *C, int LDC) -> void
Definition: gemm.h:41
Definition: added_contracted_dynamic_tensor.h:4
typename contracted_indices_aux< indexed_tensors_to_index_list< Indices... > >::type contracted_indices
Definition: type_traits.h:90
typename free_indices_aux< indexed_tensors_to_index_list< Indices... > >::type free_indices
Definition: type_traits.h:59
static auto constexpr t
Definition: index.h:24
auto begin(Range &&range)
Definition: iterator_facade.h:318
auto end(Range &&range)
Definition: iterator_facade.h:322
tensor< real_number, Dimensions... > Tensor
Definition: tensor.h:184
auto constexpr index(handle< Child, Int > const h)
Definition: handle.h:119
constexpr auto for_loop(Iteration &&iteration, execution_policy::sequential_t, Ranges(&&... ranges)[2]) -> void
Use this function for creating a sequential nested loop.
Definition: for_loop.h:336
Definition: added_contracted_static_tensor.h:7
Definition: contracted_static_tensor.h:9
tatooine::einstein_notation::free_indices< IndexedTensors... > free_indices
Definition: contracted_static_tensor.h:16
tatooine::einstein_notation::contracted_indices< IndexedTensors... > contracted_indices
Definition: contracted_static_tensor.h:18
Definition: indexed_static_tensor.h:22
typename tensor_type::value_type value_type
Definition: indexed_static_tensor.h:25
static auto constexpr rank()
Definition: indexed_static_tensor.h:47
auto tensor() const -> auto const &
Definition: indexed_static_tensor.h:33
static auto index_map()
Definition: indexed_static_tensor.h:40
typename indices::template at< I > index_at
Definition: indexed_static_tensor.h:38
auto assign(contracted_static_tensor< IndexedTensors... > other)
Definition: indexed_static_tensor.h:193
static auto constexpr size()
Definition: indexed_static_tensor.h:49
auto tensor() -> auto &
Definition: indexed_static_tensor.h:34
auto assign(added_contracted_static_tensor< ContractedTensors... > other, std::index_sequence< Seq... >)
Definition: indexed_static_tensor.h:77
static auto constexpr index_map(std::index_sequence< Seq... >)
Definition: indexed_static_tensor.h:44
auto add(contracted_static_tensor< IndexedTensors... > other, std::index_sequence< FreeIndexSequence... >, std::index_sequence< ContractedIndexSequence... >, std::index_sequence< ContractedTensorsSequence... >)
Definition: indexed_static_tensor.h:86
std::decay_t< Tensor > tensor_type
Definition: indexed_static_tensor.h:24
auto operator=(contracted_static_tensor< indexed_static_tensor< LHS, I, J >, indexed_static_tensor< RHS, J, K > > other) -> indexed_static_tensor &
A(i,k) = B(i,j) * C(j, k)
Definition: indexed_static_tensor.h:230
static auto constexpr contains() -> bool
Definition: indexed_static_tensor.h:71
static auto constexpr size_()
Definition: indexed_static_tensor.h:56
Tensor m_tensor
Definition: indexed_static_tensor.h:28
indexed_static_tensor(Tensor t)
Definition: indexed_static_tensor.h:31
An empty struct that holds types.
Definition: type_list.h:248