1#ifndef TATOOINE_LAPACK_TRTRS_H
2#define TATOOINE_LAPACK_TRTRS_H
/// Fortran LAPACK routine DTRTRS: triangular solve in double precision.
///
/// Declared with C linkage so the identifier is not C++-mangled and binds
/// directly to the library symbol `dtrtrs_`. Every argument, including the
/// scalars, is passed by pointer, as the Fortran interface requires.
/// Per the LAPACK reference, B is overwritten with the solution and INFO
/// receives the status code.
///
/// NOTE(review): some Fortran compilers (e.g. gfortran) pass hidden trailing
/// length arguments for CHARACTER dummies (UPLO, TRANS, DIAG). They are
/// omitted here, as in many C callers; confirm this matches the LAPACK build
/// in use.
extern "C" auto dtrtrs_(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
                        double* A, int* LDA, double* B, int* LDB, int* INFO)
    -> void;
/// Fortran LAPACK routine STRTRS: triangular solve in single precision.
///
/// Declared with C linkage so the identifier is not C++-mangled and binds
/// directly to the library symbol `strtrs_`. Every argument, including the
/// scalars, is passed by pointer, as the Fortran interface requires.
/// Per the LAPACK reference, B is overwritten with the solution and INFO
/// receives the status code.
///
/// NOTE(review): some Fortran compilers (e.g. gfortran) pass hidden trailing
/// length arguments for CHARACTER dummies (UPLO, TRANS, DIAG). They are
/// omitted here, as in many C callers; confirm this matches the LAPACK build
/// in use.
extern "C" auto strtrs_(char* UPLO, char* TRANS, char* DIAG, int* N, int* NRHS,
                        float* A, int* LDA, float* B, int* LDB, int* INFO)
    -> void;
17template <std::
floating_po
int Float>
18auto trtrs(uplo u, op t, diag d,
int N,
int NRHS, Float* A,
int LDA,
19 Float* B,
int LDB) ->
int {
21 if constexpr (std::same_as<Float, double>) {
22 dtrtrs_(
reinterpret_cast<char*
>(&u),
reinterpret_cast<char*
>(&t),
23 reinterpret_cast<char*
>(&d), &N, &NRHS, A, &LDA, B, &LDB, &INFO);
24 }
else if constexpr (std::same_as<Float, float>) {
25 strtrs_(
reinterpret_cast<char*
>(&u),
reinterpret_cast<char*
>(&t),
26 reinterpret_cast<char*
>(&d), &N, &NRHS, A, &LDA, B, &LDB, &INFO);
// NOTE(review): this overload is only partially visible in this listing.
// The line declaring the function name and its A/B parameters (between the
// template header and `op const t, diag const d) {`) is missing, as is the
// closing brace. Several lines also carry fused line-number prefixes
// ("59", "61", "62", "63"). Reconcile against the real header.
//
// Fixed-size convenience overload with NRHS right-hand sides. It forwards to
// the pointer-based trtrs() and returns its result (the INFO code). It uses
// N as the matrix order, NRHS as the number of right-hand sides, and M as
// the leading dimension of both A and B.
// NOTE(review): passing N as the order but M as LDA/LDB presumes A is
// (at least) M x N with M >= N, and that B has M rows. Confirm with callers.
59template <
typename T,
size_t M,
size_t N,
size_t NRHS>
61 op
const t, diag
const d) {
62 return trtrs(u, t, d,
static_cast<int>(N),
static_cast<int>(NRHS), A.
data(),
63 static_cast<int>(M), B.
data(),
static_cast<int>(M));
// NOTE(review): this overload is only partially visible in this listing.
// The line declaring the function name and its A/b parameters, and the
// closing brace, are missing. Several lines also carry fused line-number
// prefixes ("75", "77", "78", "79"). Reconcile against the real header.
//
// Fixed-size convenience overload for a single right-hand side b. It
// forwards to the pointer-based trtrs() with NRHS fixed to 1 and returns its
// result (the INFO code). It uses N as the matrix order and M as the leading
// dimension of both A and b.
// NOTE(review): as with the multi-RHS overload, M as LDA presumes M >= N.
// Confirm with callers.
75template <
typename T,
size_t M,
size_t N>
77 op
const t, diag
const d) {
78 return trtrs(u, t, d,
static_cast<int>(N), 1, A.
data(),
static_cast<int>(M),
79 b.
data(),
static_cast<int>(M));
// NOTE(review): this overload is only partially visible in this listing.
// Missing pieces:
//   - its template header, name, and A/B parameters;
//   - the lines that define M, N and NRHS;
//   - the closing brace.
// Several lines also carry fused line-number prefixes ("93"-"101").
// Reconcile against the real header.
//
// Runtime-shaped overload. It checks that A has rank 2 and that B has
// rank 1 or 2, then forwards to the pointer-based trtrs() and returns its
// result (the INFO code). It uses N as the order, NRHS right-hand sides, and
// M as the leading dimension of both A and B.
// The rank checks use assert; assuming this is the standard <cassert> macro,
// they vanish when NDEBUG is defined.
93 op
const t, diag
const d) {
94 assert(A.
rank() == 2);
95 assert(B.
rank() == 1 || B.
rank() == 2);
// NOTE(review): M, N and NRHS below are not defined in the visible lines.
// They are presumably derived from A's/B's dimensions; confirm.
100 return trtrs(u, t, d,
static_cast<int>(N),
static_cast<int>(NRHS), A.
data(),
101 static_cast<int>(M), B.
data(),
static_cast<int>(M));
constexpr auto data() -> ValueType *
Definition: static_multidim_array.h:260
auto trtrs(uplo u, op t, diag d, int N, int NRHS, Float *A, int LDA, Float *B, int LDB) -> int
Definition: trtrs.h:18
static auto constexpr rank()
Definition: base_tensor.h:41
static auto constexpr dimension(std::size_t const i)
Definition: base_tensor.h:49
auto dtrtrs_(char *UPLO, char *TRANS, char *DIAG, int *N, int *NRHS, double *A, int *LDA, double *B, int *LDB, int *INFO) -> void
auto strtrs_(char *UPLO, char *TRANS, char *DIAG, int *N, int *NRHS, float *A, int *LDA, float *B, int *LDB, int *INFO) -> void