1#ifndef TATOOINE_EINSTEIN_NOTATION_INDEXED_STATIC_TENSOR_H
2#define TATOOINE_EINSTEIN_NOTATION_INDEXED_STATIC_TENSOR_H
7#if TATOOINE_BLAS_AND_LAPACK_AVAILABLE
15template <
typename... IndexedTensors>
16struct contracted_static_tensor;
18template <
typename... ContractedTensors>
19struct added_contracted_static_tensor;
23 static_assert(std::decay_t<Tensor>::rank() ==
sizeof...(Indices));
37 template <std::
size_t I>
38 using index_at =
typename indices::template at<I>;
43 template <std::size_t... Seq>
44 static auto constexpr index_map(std::index_sequence<Seq...> ) {
45 return std::array{Indices::get()...};
47 static auto constexpr rank() {
return tensor_type::rank(); }
48 template <std::
size_t I>
49 static auto constexpr size() {
50 return tensor_type::template size<I>();
54 template <std::size_t I,
typename E,
typename HeadIndex,
55 typename... TailIndices>
57 if constexpr (is_same<E, HeadIndex>) {
58 return tensor_type::dimension(I);
60 return size_<I + 1, E, TailIndices...>();
66 static auto constexpr size() {
67 return size_<0, E, Indices...>();
72 return type_list<Indices...>::template contains<E>;
75 template <
typename... ContractedTensors, std::size_t... Seq>
76 requires(!is_const<std::remove_reference_t<Tensor>>)
78 std::index_sequence<Seq...> ) {
79 ([&] { *
this += other.template at<Seq>(); }(), ...);
82 template <
typename... IndexedTensors, std::size_t... FreeIndexSequence,
83 std::size_t... ContractedIndexSequence,
84 std::size_t... ContractedTensorsSequence>
85 requires(!is_const<std::remove_reference_t<Tensor>>)
87 std::index_sequence<FreeIndexSequence...> ,
88 std::index_sequence<ContractedIndexSequence...> ,
89 std::index_sequence<ContractedTensorsSequence...> ) {
90 using map_t = std::map<std::size_t, std::size_t>;
97 auto const free_indices_map = map_t{
98 map_t::value_type{free_indices::template at<FreeIndexSequence>::get(),
99 FreeIndexSequence}...};
100 auto const contracted_indices_map = map_t{map_t::value_type{
101 contracted_indices::template at<ContractedIndexSequence>::get(),
102 ContractedIndexSequence,
104 auto const tensor_index_maps = std::tuple{IndexedTensors::index_map()...};
106 std::tuple{make_array<std::size_t, IndexedTensors::rank()>()...};
112 auto const free_index_array = std::array{
free_indices...};
116 std::get<ContractedTensorsSequence>(index_arrays);
117 auto const& tensor_index_map =
118 std::get<ContractedTensorsSequence>(tensor_index_maps);
119 auto index_arr_it =
begin(index_array);
120 auto tensor_index_map_it =
begin(tensor_index_map);
122 for (; tensor_index_map_it !=
end(tensor_index_map);
123 ++tensor_index_map_it, ++index_arr_it) {
124 if (free_indices_map.contains(*tensor_index_map_it)) {
125 *index_arr_it = free_index_array[free_indices_map.at(
126 *tensor_index_map_it)];
132 if constexpr (contracted_indices::empty) {
134 (other.template at<ContractedTensorsSequence>().tensor()(
135 std::get<ContractedTensorsSequence>(index_arrays)) *
142 auto const contracted_index_array =
147 std::get<ContractedTensorsSequence>(index_arrays);
148 auto const& tensor_index_map =
149 std::get<ContractedTensorsSequence>(
151 auto index_arr_it =
begin(index_array);
152 auto tensor_index_map_it =
begin(tensor_index_map);
154 for (; tensor_index_map_it !=
end(tensor_index_map);
155 ++tensor_index_map_it, ++index_arr_it) {
156 if (contracted_indices_map.contains(
157 *tensor_index_map_it)) {
158 *index_arr_it = contracted_index_array
159 [contracted_indices_map.at(
160 *tensor_index_map_it)];
168 (other.template at<ContractedTensorsSequence>().tensor()(
169 std::get<ContractedTensorsSequence>(index_arrays)) *
172 contracted_static_tensor::template
size<
173 typename contracted_indices::template at<
174 ContractedIndexSequence>>()...);
177 tensor_type::dimension(FreeIndexSequence)...);
180 template <
typename... IndexedTensors>
181 requires(!is_const<std::remove_reference_t<Tensor>>)
184 add(other, std::make_index_sequence<
rank()>{},
186 IndexedTensors...>::contracted_indices::size>{},
187 std::make_index_sequence<
sizeof...(IndexedTensors)>{});
191 template <
typename... IndexedTensors>
192 requires(!is_const<std::remove_reference_t<Tensor>>)
195 add(other, std::make_index_sequence<
rank()>{},
197 IndexedTensors...>::contracted_indices::size>{},
198 std::make_index_sequence<
sizeof...(IndexedTensors)>{});
201 template <
typename... IndexedTensors>
202 requires(!is_const<std::remove_reference_t<Tensor>>)
209 template <
typename... ContractedTensors>
210 requires(!is_const<std::remove_reference_t<Tensor>>)
214 assign(other, std::make_index_sequence<
sizeof...(ContractedTensors)>{});
218 template <
typename Tensors,
typename... Is>
219 requires(!is_const<std::remove_reference_t<Tensor>>)
226#if TATOOINE_BLAS_AND_LAPACK_AVAILABLE
229 template <
typename LHS,
typename RHS,
typename I,
typename J,
typename K>
233 if constexpr (is_same<I, index_at<0>> && is_same<K, index_at<1>>) {
234 using comp_type =
typename tensor_type::value_type;
235 static_assert(is_same<comp_type, typename std::decay_t<LHS>::value_type>);
236 static_assert(is_same<comp_type, typename std::decay_t<RHS>::value_type>);
auto gemm(op TRANSA, op TRANSB, int M, int N, int K, Float ALPHA, Float const *A, int LDA, Float const *B, int LDB, Float BETA, Float *C, int LDC) -> void
Definition: gemm.h:41
Definition: added_contracted_dynamic_tensor.h:4
typename contracted_indices_aux< indexed_tensors_to_index_list< Indices... > >::type contracted_indices
Definition: type_traits.h:90
typename free_indices_aux< indexed_tensors_to_index_list< Indices... > >::type free_indices
Definition: type_traits.h:59
static auto constexpr t
Definition: index.h:24
auto begin(Range &&range)
Definition: iterator_facade.h:318
auto end(Range &&range)
Definition: iterator_facade.h:322
tensor< real_number, Dimensions... > Tensor
Definition: tensor.h:184
auto constexpr index(handle< Child, Int > const h)
Definition: handle.h:119
constexpr auto for_loop(Iteration &&iteration, execution_policy::sequential_t, Ranges(&&... ranges)[2]) -> void
Use this function for creating a sequential nested loop.
Definition: for_loop.h:336
Definition: added_contracted_static_tensor.h:7
Definition: contracted_static_tensor.h:9
tatooine::einstein_notation::free_indices< IndexedTensors... > free_indices
Definition: contracted_static_tensor.h:16
tatooine::einstein_notation::contracted_indices< IndexedTensors... > contracted_indices
Definition: contracted_static_tensor.h:18
Definition: indexed_static_tensor.h:22
typename tensor_type::value_type value_type
Definition: indexed_static_tensor.h:25
static auto constexpr rank()
Definition: indexed_static_tensor.h:47
auto tensor() const -> auto const &
Definition: indexed_static_tensor.h:33
static auto index_map()
Definition: indexed_static_tensor.h:40
typename indices::template at< I > index_at
Definition: indexed_static_tensor.h:38
auto assign(contracted_static_tensor< IndexedTensors... > other)
Definition: indexed_static_tensor.h:193
static auto constexpr size()
Definition: indexed_static_tensor.h:49
auto tensor() -> auto &
Definition: indexed_static_tensor.h:34
auto assign(added_contracted_static_tensor< ContractedTensors... > other, std::index_sequence< Seq... >)
Definition: indexed_static_tensor.h:77
static auto constexpr index_map(std::index_sequence< Seq... >)
Definition: indexed_static_tensor.h:44
auto add(contracted_static_tensor< IndexedTensors... > other, std::index_sequence< FreeIndexSequence... >, std::index_sequence< ContractedIndexSequence... >, std::index_sequence< ContractedTensorsSequence... >)
Definition: indexed_static_tensor.h:86
std::decay_t< Tensor > tensor_type
Definition: indexed_static_tensor.h:24
auto operator=(contracted_static_tensor< indexed_static_tensor< LHS, I, J >, indexed_static_tensor< RHS, J, K > > other) -> indexed_static_tensor &
A(i,k) = B(i,j) * C(j, k)
Definition: indexed_static_tensor.h:230
static auto constexpr contains() -> bool
Definition: indexed_static_tensor.h:71
static auto constexpr size_()
Definition: indexed_static_tensor.h:56
Tensor m_tensor
Definition: indexed_static_tensor.h:28
indexed_static_tensor(Tensor t)
Definition: indexed_static_tensor.h:31
An empty struct that holds types.
Definition: type_list.h:248