Tatooine
indexed_dynamic_tensor.h
Go to the documentation of this file.
1#ifndef TATOOINE_EINSTEIN_NOTATION_INDEXED_DYNAMIC_TENSOR_H
2#define TATOOINE_EINSTEIN_NOTATION_INDEXED_DYNAMIC_TENSOR_H
3//==============================================================================
4#if TATOOINE_BLAS_AND_LAPACK_AVAILABLE
5#include <tatooine/blas.h>
6#endif
8#include <tatooine/for_loop.h>
12#include <tatooine/type_list.h>
13#include <tatooine/type_set.h>
14
15#include <map>
16#include <tuple>
17//==============================================================================
19//==============================================================================
20template <typename... IndexedTensors>
21struct contracted_dynamic_tensor;
22//==============================================================================
23template <typename... ContractedTensors>
24struct added_contracted_dynamic_tensor;
25//==============================================================================
26template <dynamic_tensor Tensor, index... Indices>
27requires(sizeof...(Indices) > 0)
29 using tensor_type = std::decay_t<Tensor>;
31
32 private:
34
35 public:
36 explicit indexed_dynamic_tensor(Tensor t) : m_tensor{t} {}
37
38 auto tensor() const -> auto const& { return m_tensor; }
39 auto tensor() -> auto& { return m_tensor; }
40
// Compile-time list of the index types this tensor is labelled with, in
// declaration order.
41 using indices = type_list<Indices...>;
// Index type stored at position I of `indices`.
42 template <std::size_t I>
43 using index_at = typename indices::template at<I>;
44
45 static auto index_map() {
46 return index_map(std::make_index_sequence<rank()>{});
47 }
48 template <std::size_t... Seq>
49 static auto constexpr index_map(std::index_sequence<Seq...> /*seq*/) {
50 return std::array{Indices::get()...};
51 }
52 static auto constexpr rank() { return sizeof...(Indices); }
53 auto dimension(std::size_t const i) {
54 return m_tensor.dimension(i);
55 }
56
57 private:
58 template <std::size_t I, index E, index HeadIndex,
59 index... TailIndices>
60 auto dimension_() const {
61 if constexpr (is_same<E, HeadIndex>) {
62 return m_tensor.dimension(I);
63 } else {
64 return dimension_<I + 1, E, TailIndices...>();
65 }
66 }
67
68 public:
69 template <index E>
70 auto dimension() const {
71 return dimension_<0, E, Indices...>();
72 }
73
74 template <index E>
75 static auto constexpr contains() -> bool {
76 return type_list<Indices...>::template contains<E>;
77 }
78 //============================================================================
// Accumulates every term of a sum of contractions into this tensor.
// For each position in Seq the term `other.at<Seq>()` is added through
// operator+=; the terms are processed in pack order by the fold over the
// comma operator.
// Only available when the wrapped tensor type is not const.
79 template <typename... ContractedTensors, std::size_t... Seq>
81 std::index_sequence<Seq...> /*seq*/)
82 requires(!is_const<std::remove_reference_t<Tensor>>)
83 {
// One immediately-invoked lambda per term.
84 ([&] { *this += other.template at<Seq>(); }(), ...);
85 }
86 //----------------------------------------------------------------------------
// Adds the contraction `other` to m_tensor element by element.
//
// The outer for_loop runs over m_tensor.dimension(FreeIndexSequence)... and
// provides the free indices. If contracted_indices is not empty, an inner
// for_loop runs over the extents `other` reports for each contracted index.
// In every iteration the per-factor index arrays are filled in and the product
// of the addressed factor elements is added to m_tensor(free_indices...).
//
// FreeIndexSequence         positions into free_indices
// ContractedIndexSequence   positions into contracted_indices
// ContractedTensorsSequence positions into the factors of `other`
// Only available when the wrapped tensor type is not const.
87 template <typename... IndexedTensors, std::size_t... FreeIndexSequence,
88 std::size_t... ContractedIndexSequence,
89 std::size_t... ContractedTensorsSequence>
91 std::index_sequence<FreeIndexSequence...> /*seq*/,
92 std::index_sequence<ContractedIndexSequence...> /*seq*/,
93 std::index_sequence<ContractedTensorsSequence...> /*seq*/)
94 requires(!is_const<std::remove_reference_t<Tensor>>)
95 {
96 using map_t = std::map<std::size_t, std::size_t>;
100
// Maps the value of each free index's get() to its position among the free
// indices.
101 auto const free_indices_map = map_t{
102 map_t::value_type{free_indices::template at<FreeIndexSequence>::get(),
103 FreeIndexSequence}...};
// NOTE(review): `c` is not read anywhere in the visible body - candidate for
// removal.
104 auto c = std::array{ContractedTensorsSequence...};
// Same mapping for the contracted indices.
105 auto const contracted_indices_map = map_t{map_t::value_type{
106 contracted_indices::template at<ContractedIndexSequence>::get(),
107 ContractedIndexSequence,
108 }...};
// One index map and one scratch index array per factor of the contraction.
109 auto const tensor_index_maps = std::tuple{IndexedTensors::index_map()...};
110 auto index_arrays =
111 std::tuple{make_array<std::size_t, IndexedTensors::rank()>()...};
112
113 for_loop(
114 [&](auto const... free_indices) {
115 // setup indices of single tensors for free indices
116 {
117 auto const free_index_array = std::array{free_indices...};
118 (
119 [&] {
120 auto& index_array =
121 std::get<ContractedTensorsSequence>(index_arrays);
122 auto const& tensor_index_map =
123 std::get<ContractedTensorsSequence>(tensor_index_maps);
124 auto index_arr_it = begin(index_array);
125 auto tensor_index_map_it = begin(tensor_index_map);
126
// Walk the factor's index map and its index array in lock-step; slots whose
// index is a free index receive the current free-index value.
127 for (; tensor_index_map_it != end(tensor_index_map);
128 ++tensor_index_map_it, ++index_arr_it) {
129 if (free_indices_map.contains(*tensor_index_map_it)) {
130 *index_arr_it = free_index_array[free_indices_map.at(
131 *tensor_index_map_it)];
132 }
133 }
134 }(),
135 ...);
136 }
// Without contracted indices every factor is fully addressed by the free
// indices, so the product can be accumulated right away.
137 if constexpr (contracted_indices::empty) {
138 m_tensor(free_indices...) +=
139 (other.template at<ContractedTensorsSequence>().tensor()(
140 std::get<ContractedTensorsSequence>(index_arrays)) *
141 ...);
142 } else {
143 for_loop(
144 [&](auto const... contracted_indices) {
145 // setup indices of single tensors for contracted indices
146 {
147 auto const contracted_index_array =
148 std::array{contracted_indices...};
149 (
150 [&] {
151 auto& index_array =
152 std::get<ContractedTensorsSequence>(index_arrays);
153 auto const& tensor_index_map =
154 std::get<ContractedTensorsSequence>(
155 tensor_index_maps);
156 auto index_arr_it = begin(index_array);
157 auto tensor_index_map_it = begin(tensor_index_map);
158
// Same lock-step walk as above, now filling the slots that belong to
// contracted indices.
159 for (; tensor_index_map_it != end(tensor_index_map);
160 ++tensor_index_map_it, ++index_arr_it) {
161 if (contracted_indices_map.contains(
162 *tensor_index_map_it)) {
163 *index_arr_it = contracted_index_array
164 [contracted_indices_map.at(
165 *tensor_index_map_it)];
166 }
167 }
168 }(),
169 ...);
170 }
171
// NOTE(review): `f` and `g` are not read afterwards - candidates for removal.
// `g` additionally copies every index array on each inner iteration.
172 auto const f = std::array{free_indices...};
173 auto const g = std::array{
174 std::get<ContractedTensorsSequence>(index_arrays)...};
// Accumulate the product of all factor elements for this index combination.
175 m_tensor(free_indices...) +=
176 (other.template at<ContractedTensorsSequence>().tensor()(
177 std::get<ContractedTensorsSequence>(index_arrays)) *
178 ...);
179 },
// Ranges of the inner loop: extent of each contracted index as reported by
// `other`.
180 other.template dimension<typename contracted_indices::template at<ContractedIndexSequence>>()...);
181 }
182 },
// Ranges of the outer loop: extents of the target tensor.
183 m_tensor.dimension(FreeIndexSequence)...);
184 }
185 //----------------------------------------------------------------------------
186 template <typename... IndexedTensors>
189 requires(!is_const<std::remove_reference_t<Tensor>>)
190 {
191 return add(other, std::make_index_sequence<rank()>{},
192 std::make_index_sequence<contracted_dynamic_tensor<
193 IndexedTensors...>::contracted_indices::size>{},
194 std::make_index_sequence<rank()>{});
195 }
196 //----------------------------------------------------------------------------
// Compound addition of a contraction: forwards to add(other) and returns its
// result. Only available when the wrapped tensor type is not const.
197 template <typename... IndexedTensors>
200 requires(!is_const<std::remove_reference_t<Tensor>>) {
201 return add(other);
202 }
203 //----------------------------------------------------------------------------
// Recursive step of the size computation: appends the extent that `other`
// reports for index type T to `size`, then continues with the remaining index
// types Ts.
204 template <typename... IndexedTensors, index T, index... Ts>
206 type_set_impl<T, Ts...> const /*ts*/,
207 std::vector<std::size_t>& size) {
208 size.push_back(other.template dimension<T>());
209 resize_internal_tensor(other, type_set_impl<Ts...>{}, size);
210 }
211 //----------------------------------------------------------------------------
// End of the recursion: no index types are left, so the wrapped tensor is
// replaced by tensor_type::zeros(size) built from the collected extents.
212 template <typename... IndexedTensors>
215 type_set_impl<> const /*ts*/, std::vector<std::size_t>& size) {
216 m_tensor = tensor_type::zeros(size);
217 }
218 //----------------------------------------------------------------------------
// Entry point of the size computation: starts with an empty extent list and
// hands it to the recursive overloads.
// NOTE(review): the argument between `other` and `size` is not visible in this
// listing; presumably it is the set of index types to iterate over - confirm
// against the original header.
219 template <typename... IndexedTensors>
222 auto size = std::vector<std::size_t>{};
223 resize_internal_tensor(
224 other,
226 size);
227 }
228 //----------------------------------------------------------------------------
229 template <typename... IndexedTensors>
232 requires(!is_const<std::remove_reference_t<Tensor>>) {
233 resize_internal_tensor(
234 other);
235 return add(other, std::make_index_sequence<rank()>{},
236 std::make_index_sequence<
237 contracted_dynamic_tensor<IndexedTensors...>::contracted_indices::size>{},
238 std::make_index_sequence<rank()>{});
239 }
240 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
// Assignment from a contraction: delegates to assign(other) and returns
// *this. Only available when the wrapped tensor type is not const.
241 template <typename... IndexedTensors>
244 requires(!is_const<std::remove_reference_t<Tensor>>) {
245 assign(other);
246 return *this;
247 }
248 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
// Assignment from a sum of contractions: the wrapped tensor is replaced by
// tensor_type::zeros(), then every term of `other` is accumulated through the
// index-sequence overload of assign.
// Only available when the wrapped tensor type is not const.
249 template <typename... ContractedTensors>
252 requires(!is_const<std::remove_reference_t<Tensor>>) {
253 m_tensor = tensor_type::zeros();
// One sequence entry per term of the sum.
254 assign(other, std::make_index_sequence<sizeof...(ContractedTensors)>{});
255 return *this;
256 }
257 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
// Assignment from another indexed tensor: the wrapped tensor is replaced by
// tensor_type::zeros(), then `other` is wrapped in a contracted_dynamic_tensor
// and accumulated through operator+=.
// Only available when the wrapped tensor type is not const.
258 template <typename Tensors, typename... Is>
261 requires(!is_const<std::remove_reference_t<Tensor>>) {
262 m_tensor = tensor_type::zeros();
263 *this += contracted_dynamic_tensor{other};
264 return *this;
265 }
266 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
267#if TATOOINE_BLAS_AND_LAPACK_AVAILABLE
// BLAS overload of += for a product of two rank-2 factors whose first factor
// carries this tensor's first index and whose second factor carries this
// tensor's second index. Requires all three tensors to share one value type
// and the wrapped tensor type to be non-const.
273 template <typename LHS, typename RHS, same_as<index_at<0>> I, index J,
274 same_as<index_at<1>> K>
277 other)
278 -> indexed_dynamic_tensor& requires(
279 !is_const<std::remove_reference_t<Tensor>> &&
280 is_same<value_type, tatooine::value_type<LHS>> &&
281 is_same<value_type, tatooine::value_type<RHS>>) {
// NOTE(review): the target is resized before values are accumulated into it;
// confirm that resize keeps the existing contents, otherwise the accumulation
// starts from unspecified data.
282 m_tensor.resize(other.template at<0>().tensor().dimension(0),
283 other.template at<1>().tensor().dimension(1));
// Both scalar arguments are 1; presumably the second one is the factor applied
// to the existing content of m_tensor - TODO confirm against blas::gemm.
284 blas::gemm(value_type(1), other.template at<0>().tensor(),
285 other.template at<1>().tensor(), value_type(1), m_tensor);
286 return *this;
287 }
288 // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
// BLAS overload of = for a product of two rank-2 factors whose first factor
// carries this tensor's first index and whose second factor carries this
// tensor's second index. Requires all three tensors to share one value type
// and the wrapped tensor type to be non-const.
294 template <typename LHS, typename RHS, same_as<index_at<0>> I, index J,
295 same_as<index_at<1>> K>
298 other)
299 ->indexed_dynamic_tensor& requires(
300 !is_const<std::remove_reference_t<Tensor>> &&
301 is_same<value_type, tatooine::value_type<LHS>> &&
302 is_same<value_type, tatooine::value_type<RHS>>) {
// Size the target from the first extent of the left factor and the second
// extent of the right factor.
303 m_tensor.resize(other.template at<0>().tensor().dimension(0),
304 other.template at<1>().tensor().dimension(1));
// Second scalar argument is 0 here (1 in the += overload); presumably this
// discards the previous content of m_tensor - TODO confirm against blas::gemm.
305 blas::gemm(value_type(1), other.template at<0>().tensor(),
306 other.template at<1>().tensor(), value_type(0), m_tensor);
307 return *this;
308 }
309 //----------------------------------------------------------------------------
// BLAS overload of += for a rank-2 factor times a rank-1 factor, where the
// first index of the left factor is this tensor's first index. Requires all
// three tensors to share one value type and the wrapped tensor type to be
// non-const.
311 template <typename LHS, typename RHS, typename I, typename J>
314 other)
316 requires(!is_const<std::remove_reference_t<Tensor>> &&
317 is_same<value_type, tatooine::value_type<LHS>> &&
318 is_same<value_type, tatooine::value_type<RHS>> &&
319 is_same<I, index_at<0>>) {
// The target is not resized here; its first extent must already match the
// left factor.
320 assert(m_tensor.dimension(0) ==
321 other.template at<0>().tensor().dimension(0));
// NOTE(review): this calls blas::gemm although the assignment counterpart of
// this overload calls blas::gemv - confirm that gemm is intended here.
322 blas::gemm(value_type(1), other.template at<0>().tensor(),
323 other.template at<1>().tensor(), value_type(1), m_tensor);
324 return *this;
325 }
326 //----------------------------------------------------------------------------
// BLAS overload of = for a rank-2 factor times a rank-1 factor, where the
// first index of the left factor is this tensor's first index. Requires all
// three tensors to share one value type and the wrapped tensor type to be
// non-const.
328 template <typename LHS, typename RHS, typename I, typename J>
331 other)
332 -> indexed_dynamic_tensor& requires(
333 !is_const<std::remove_reference_t<Tensor>> &&
334 is_same<value_type, tatooine::value_type<LHS>> &&
335 is_same<value_type, tatooine::value_type<RHS>> &&
336 is_same<I, index_at<0>>) {
// Size the target from the first extent of the left factor.
337 m_tensor.resize(other.template at<0>().tensor().dimension(0));
// Second scalar argument is 0; presumably this discards the previous content
// of m_tensor - TODO confirm against blas::gemv.
338 blas::gemv(value_type(1), other.template at<0>().tensor(),
339 other.template at<1>().tensor(), value_type(0), m_tensor);
340 return *this;
341 }
342#endif
343};
344//==============================================================================
345} // namespace tatooine::einstein_notation
346//==============================================================================
347#endif
auto gemm(op TRANSA, op TRANSB, int M, int N, int K, Float ALPHA, Float const *A, int LDA, Float const *B, int LDB, Float BETA, Float *C, int LDC) -> void
Definition: gemm.h:41
auto gemv(op TRANS, int M, int N, Float ALPHA, Float const *A, int LDA, Float const *X, int INCX, Float BETA, Float *Y, int INCY) -> void
Definition: gemv.h:37
Definition: added_contracted_dynamic_tensor.h:4
static auto constexpr i
Definition: index.h:13
auto add(added_contracted_static_tensor< ContractedTensorsLHS... > lhs, contracted_static_tensor< TensorsRHS... > rhs, std::index_sequence< Seq... >)
Definition: operator_overloads.h:128
typename contracted_indices_aux< indexed_tensors_to_index_list< Indices... > >::type contracted_indices
Definition: type_traits.h:90
typename free_indices_aux< indexed_tensors_to_index_list< Indices... > >::type free_indices
Definition: type_traits.h:59
static auto constexpr t
Definition: index.h:24
auto begin(Range &&range)
Definition: iterator_facade.h:318
typename value_type_impl< T >::type value_type
Definition: type_traits.h:280
auto end(Range &&range)
Definition: iterator_facade.h:322
tensor< real_number, Dimensions... > Tensor
Definition: tensor.h:184
auto constexpr index(handle< Child, Int > const h)
Definition: handle.h:119
auto size(vec< ValueType, N > const &v)
Definition: vec.h:148
constexpr auto for_loop(Iteration &&iteration, execution_policy::sequential_t, Ranges(&&... ranges)[2]) -> void
Use this function for creating a sequential nested loop.
Definition: for_loop.h:336
constexpr auto rank()
Definition: rank.h:10
Definition: added_contracted_dynamic_tensor.h:7
Definition: contracted_dynamic_tensor.h:9
tatooine::einstein_notation::contracted_indices< IndexedTensors... > contracted_indices
Definition: contracted_dynamic_tensor.h:18
tatooine::einstein_notation::free_indices< IndexedTensors... > free_indices
Definition: contracted_dynamic_tensor.h:16
Definition: indexed_dynamic_tensor.h:28
auto operator=(contracted_dynamic_tensor< IndexedTensors... > other) -> indexed_dynamic_tensor &requires(!is_const< std::remove_reference_t< Tensor > >)
Definition: indexed_dynamic_tensor.h:242
auto assign(added_contracted_dynamic_tensor< ContractedTensors... > other, std::index_sequence< Seq... >)
Definition: indexed_dynamic_tensor.h:80
auto operator+=(contracted_dynamic_tensor< indexed_dynamic_tensor< LHS, I, J >, indexed_dynamic_tensor< RHS, J, K > > other) -> indexed_dynamic_tensor &requires(!is_const< std::remove_reference_t< Tensor > > &&is_same< value_type, tatooine::value_type< LHS > > &&is_same< value_type, tatooine::value_type< RHS > >)
Definition: indexed_dynamic_tensor.h:275
auto assign(contracted_dynamic_tensor< IndexedTensors... > other) -> indexed_dynamic_tensor &requires(!is_const< std::remove_reference_t< Tensor > >)
Definition: indexed_dynamic_tensor.h:230
static auto index_map()
Definition: indexed_dynamic_tensor.h:45
auto tensor() -> auto &
Definition: indexed_dynamic_tensor.h:39
auto tensor() const -> auto const &
Definition: indexed_dynamic_tensor.h:38
tatooine::value_type< tensor_type > value_type
Definition: indexed_dynamic_tensor.h:30
auto operator=(indexed_dynamic_tensor< Tensors, Is... > other) -> indexed_dynamic_tensor &requires(!is_const< std::remove_reference_t< Tensor > >)
Definition: indexed_dynamic_tensor.h:259
auto dimension() const
Definition: indexed_dynamic_tensor.h:70
static auto constexpr index_map(std::index_sequence< Seq... >)
Definition: indexed_dynamic_tensor.h:49
auto operator=(contracted_dynamic_tensor< indexed_dynamic_tensor< LHS, I, J >, indexed_dynamic_tensor< RHS, J, K > > other) -> indexed_dynamic_tensor &requires(!is_const< std::remove_reference_t< Tensor > > &&is_same< value_type, tatooine::value_type< LHS > > &&is_same< value_type, tatooine::value_type< RHS > >)
Definition: indexed_dynamic_tensor.h:296
auto resize_internal_tensor(contracted_dynamic_tensor< IndexedTensors... > other)
Definition: indexed_dynamic_tensor.h:220
auto operator=(added_contracted_dynamic_tensor< ContractedTensors... > other) -> indexed_dynamic_tensor &requires(!is_const< std::remove_reference_t< Tensor > >)
Definition: indexed_dynamic_tensor.h:250
auto add(contracted_dynamic_tensor< IndexedTensors... > other) -> indexed_dynamic_tensor &requires(!is_const< std::remove_reference_t< Tensor > >)
Definition: indexed_dynamic_tensor.h:187
auto dimension_() const
Definition: indexed_dynamic_tensor.h:60
Tensor m_tensor
Definition: indexed_dynamic_tensor.h:33
auto operator+=(contracted_dynamic_tensor< indexed_dynamic_tensor< LHS, I, J >, indexed_dynamic_tensor< RHS, J > > other) -> indexed_dynamic_tensor &requires(!is_const< std::remove_reference_t< Tensor > > &&is_same< value_type, tatooine::value_type< LHS > > &&is_same< value_type, tatooine::value_type< RHS > > &&is_same< I, index_at< 0 > >)
Definition: indexed_dynamic_tensor.h:312
auto add(contracted_dynamic_tensor< IndexedTensors... > other, std::index_sequence< FreeIndexSequence... >, std::index_sequence< ContractedIndexSequence... >, std::index_sequence< ContractedTensorsSequence... >)
Definition: indexed_dynamic_tensor.h:90
auto resize_internal_tensor(contracted_dynamic_tensor< IndexedTensors... > other, type_set_impl< T, Ts... > const, std::vector< std::size_t > &size)
Definition: indexed_dynamic_tensor.h:205
auto operator+=(contracted_dynamic_tensor< IndexedTensors... > other) -> indexed_dynamic_tensor &requires(!is_const< std::remove_reference_t< Tensor > >)
Definition: indexed_dynamic_tensor.h:198
indexed_dynamic_tensor(Tensor t)
Definition: indexed_dynamic_tensor.h:36
typename indices::template at< I > index_at
Definition: indexed_dynamic_tensor.h:43
std::decay_t< Tensor > tensor_type
Definition: indexed_dynamic_tensor.h:29
auto resize_internal_tensor(contracted_dynamic_tensor< IndexedTensors... >, type_set_impl<> const, std::vector< std::size_t > &size)
Definition: indexed_dynamic_tensor.h:213
static auto constexpr rank()
Definition: indexed_dynamic_tensor.h:52
auto operator=(contracted_dynamic_tensor< indexed_dynamic_tensor< LHS, I, J >, indexed_dynamic_tensor< RHS, J > > other) -> indexed_dynamic_tensor &requires(!is_const< std::remove_reference_t< Tensor > > &&is_same< value_type, tatooine::value_type< LHS > > &&is_same< value_type, tatooine::value_type< RHS > > &&is_same< I, index_at< 0 > >)
Definition: indexed_dynamic_tensor.h:329
auto dimension(std::size_t const i)
Definition: indexed_dynamic_tensor.h:53
static auto constexpr contains() -> bool
Definition: indexed_dynamic_tensor.h:75
static auto constexpr dimension(std::size_t const i)
Definition: base_tensor.h:49
An empty struct that holds types.
Definition: type_list.h:248
Inherits from a type_list with only unique types.
Definition: type_set.h:138