Tatooine
trtrs.h
Go to the documentation of this file.
1#ifndef TATOOINE_LAPACK_TRTRS_H
2#define TATOOINE_LAPACK_TRTRS_H
3//==============================================================================
extern "C" {
/// LAPACK DTRTRS (double precision triangular solve).
/// Every argument is passed by pointer, following the Fortran calling
/// convention; the status code is written to `info`.
/// NOTE(review): gfortran-built LAPACK also expects hidden trailing
/// character-length arguments for the three `char` parameters. These
/// prototypes omit them, as much C++ code does. Confirm this is safe for the
/// LAPACK build being linked.
void dtrtrs_(char* uplo, char* trans, char* diag, int* n, int* nrhs,
             double* a, int* lda, double* b, int* ldb, int* info);
//------------------------------------------------------------------------------
/// LAPACK STRTRS: single-precision counterpart of dtrtrs_ (same argument
/// order and pointer-passing convention).
void strtrs_(char* uplo, char* trans, char* diag, int* n, int* nrhs,
             float* a, int* lda, float* b, int* ldb, int* info);
}
11//==============================================================================
#include <cassert>
#include <concepts>
#include <cstddef>
14//==============================================================================
15namespace tatooine::lapack {
16//==============================================================================
17template <std::floating_point Float>
18auto trtrs(uplo u, op t, diag d, int N, int NRHS, Float* A, int LDA,
19 Float* B, int LDB) -> int {
20 auto INFO = int{};
21 if constexpr (std::same_as<Float, double>) {
22 dtrtrs_(reinterpret_cast<char*>(&u), reinterpret_cast<char*>(&t),
23 reinterpret_cast<char*>(&d), &N, &NRHS, A, &LDA, B, &LDB, &INFO);
24 } else if constexpr (std::same_as<Float, float>) {
25 strtrs_(reinterpret_cast<char*>(&u), reinterpret_cast<char*>(&t),
26 reinterpret_cast<char*>(&d), &N, &NRHS, A, &LDA, B, &LDB, &INFO);
27 }
28 return INFO;
29}
30//==============================================================================
49//==============================================================================
59template <typename T, size_t M, size_t N, size_t NRHS>
60auto trtrs(tensor<T, M, N>& A, tensor<T, M, NRHS>& B, uplo const u,
61 op const t, diag const d) {
62 return trtrs(u, t, d, static_cast<int>(N), static_cast<int>(NRHS), A.data(),
63 static_cast<int>(M), B.data(), static_cast<int>(M));
64}
65//------------------------------------------------------------------------------
75template <typename T, size_t M, size_t N>
76auto trtrs(tensor<T, M, N>& A, tensor<T, M>& b, uplo const u,
77 op const t, diag const d) {
78 return trtrs(u, t, d, static_cast<int>(N), 1, A.data(), static_cast<int>(M),
79 b.data(), static_cast<int>(M));
80}
81//------------------------------------------------------------------------------
91template <typename T>
92auto trtrs(tensor<T>& A, tensor<T>& B, uplo const u,
93 op const t, diag const d) {
94 assert(A.rank() == 2);
95 assert(B.rank() == 1 || B.rank() == 2);
96 assert(A.dimension(0) == B.dimension(0));
97 auto const M = A.dimension(0);
98 auto const N = A.dimension(1);
99 auto const NRHS = B.rank() == 2 ? B.dimension(1) : 1;
100 return trtrs(u, t, d, static_cast<int>(N), static_cast<int>(NRHS), A.data(),
101 static_cast<int>(M), B.data(), static_cast<int>(M));
102}
103//==============================================================================
105//==============================================================================
106} // namespace tatooine::lapack
107//==============================================================================
108#endif
constexpr auto data() -> ValueType *
Definition: static_multidim_array.h:260
Definition: base.h:6
auto trtrs(uplo u, op t, diag d, int N, int NRHS, Float *A, int LDA, Float *B, int LDB) -> int
Definition: trtrs.h:18
Definition: tensor.h:17
static auto constexpr rank()
Definition: base_tensor.h:41
static auto constexpr dimension(std::size_t const i)
Definition: base_tensor.h:49
auto dtrtrs_(char *UPLO, char *TRANS, char *DIAG, int *N, int *NRHS, double *A, int *LDA, double *B, int *LDB, int *INFO) -> void
auto strtrs_(char *UPLO, char *TRANS, char *DIAG, int *N, int *NRHS, float *A, int *LDA, float *B, int *LDB, int *INFO) -> void