Tatooine
ormqr.h
Go to the documentation of this file.
1#ifndef TATOOINE_LAPACK_ORMQR_H
2#define TATOOINE_LAPACK_ORMQR_H
3//==============================================================================
extern "C" {
/// Fortran LAPACK routine DORMQR (double precision).
/// Overwrites C with Q*C, Q^T*C, C*Q or C*Q^T (selected by SIDE/TRANS), where
/// Q is represented by K elementary reflectors stored in A and TAU.
/// Every argument is passed by pointer, following the Fortran calling
/// convention; the status code is written to *INFO.
void dormqr_(char* SIDE, char* TRANS, int* M, int* N, int* K, double* A,
             int* LDA, double* TAU, double* C, int* LDC, double* WORK,
             int* LWORK, int* INFO);

/// Fortran LAPACK routine SORMQR: single-precision counterpart of dormqr_.
void sormqr_(char* SIDE, char* TRANS, int* M, int* N, int* K, float* A,
             int* LDA, float* TAU, float* C, int* LDC, float* WORK,
             int* LWORK, int* INFO);
}
12//==============================================================================
14
15#include <concepts>
16#include <memory>
17//==============================================================================
18namespace tatooine::lapack {
19//==============================================================================
20template <std::floating_point Float>
21auto ormqr(side SIDE, op TRANS, int M, int N, int K, Float* A, int LDA,
22 Float* TAU, Float* C, int LDC, Float* WORK, int LWORK) -> int {
23 auto INFO = int{};
24 if constexpr (std::same_as<Float, double>) {
25 dormqr_(reinterpret_cast<char*>(&SIDE), reinterpret_cast<char*>(&TRANS), &M,
26 &N, &K, A, &LDA, TAU, C, &LDC, WORK, &LWORK, &INFO);
27 } else if constexpr (std::same_as<Float, float>) {
28 sormqr_(reinterpret_cast<char*>(&SIDE), reinterpret_cast<char*>(&TRANS), &M,
29 &N, &K, A, &LDA, TAU, C, &LDC, WORK, &LWORK, &INFO);
30 }
31 return INFO;
32}
33//==============================================================================
34template <std::floating_point Float>
35auto ormqr(side SIDE, op TRANS, int M, int N, int K, Float* A, int LDA,
36 Float* TAU, Float* C, int LDC) -> int {
37 auto LWORK = int{-1};
38 auto WORK = std::unique_ptr<Float[]>{new Float[1]};
39
40 ormqr<Float>(SIDE, TRANS, M, N, K, A, LDA, TAU, C, LDC, WORK.get(), LWORK);
41 LWORK = static_cast<int>(WORK[0]);
42 WORK = std::unique_ptr<Float[]>{new Float[LWORK]};
43 return ormqr<Float>(SIDE, TRANS, M, N, K, A, LDA, TAU, C, LDC, WORK.get(),
44 LWORK);
45}
46//==============================================================================
71//==============================================================================
72template <typename T, size_t K, size_t M>
73auto ormqr(tensor<T, M, K>& A, tensor<T, M>& c, tensor<T, K>& tau, side const s,
74 op trans) {
75 return ormqr(s, trans, static_cast<int>(M), 1, static_cast<int>(K), A.data(),
76 static_cast<int>(M), tau.data(), c.data(), static_cast<int>(M));
77}
78//==============================================================================
// NOTE(review): the parameter line of this overload (original line 80) is
// missing from this rendering. From the body, it forwards the static extents
// M, N, K as LAPACK's M, N, K and uses M as both leading dimensions (LDA and
// LDC), i.e. it presumably assumes column-major storage without padding —
// confirm in tensor.h.
79template <typename T, size_t K, size_t M, size_t N>
81 side const s, op trans) {
82 return ormqr(s, trans, static_cast<int>(M), static_cast<int>(N),
83 static_cast<int>(K), A.data(), static_cast<int>(M), tau.data(),
84 C.data(), static_cast<int>(M));
85}
86//==============================================================================
87template <typename T>
88auto ormqr(tensor<T>& A, tensor<T>& C, tensor<T>& tau, side const s, op trans) {
89 assert(A.rank() == 2);
90 assert(C.rank() == 1 || C.rank() == 2);
91 assert(tau.rank() == 1);
92 assert(A.dimension(0) == C.dimension(0));
93 assert(A.dimension(1) == tau.dimension(0));
94 auto const M = A.dimension(0);
95 auto const K = A.dimension(1);
96 auto const N = C.rank() == 2 ? C.dimension(1) : 1;
97 return ormqr(s, trans, static_cast<int>(M), static_cast<int>(N),
98 static_cast<int>(K), A.data(), static_cast<int>(M), tau.data(),
99 C.data(), static_cast<int>(M));
100}
101//==============================================================================
103//==============================================================================
104} // namespace tatooine::lapack
105//==============================================================================
106#endif
constexpr auto data() -> ValueType *
Definition: static_multidim_array.h:260
Definition: base.h:6
auto ormqr(side SIDE, op TRANS, int M, int N, int K, Float *A, int LDA, Float *TAU, Float *C, int LDC, Float *WORK, int LWORK) -> int
Definition: ormqr.h:21
auto sormqr_(char *SIDE, char *TRANS, int *M, int *N, int *K, float *A, int *LDA, float *TAU, float *C, int *LDC, float *WORK, int *LWORK, int *INFO) -> void
auto dormqr_(char *SIDE, char *TRANS, int *M, int *N, int *K, double *A, int *LDA, double *TAU, double *C, int *LDC, double *WORK, int *LWORK, int *INFO) -> void
Definition: tensor.h:17
static auto constexpr rank()
Definition: base_tensor.h:41
static auto constexpr dimension(std::size_t const i)
Definition: base_tensor.h:49