1#ifndef TATOOINE_LAPACK_GEMM_H
2#define TATOOINE_LAPACK_GEMM_H
/// Declaration of the double-precision `dgemm_` routine.
///
/// Every argument, including the scalar ones, is handed over through a
/// pointer.
///
/// NOTE(review): the trailing underscore suggests that this is the Fortran
/// BLAS symbol computing C := alpha * op(A) * op(B) + beta * C. If so, the
/// declaration needs C linkage from an enclosing `extern "C"` block, which is
/// not visible in this part of the file -- confirm.
auto dgemm_(char* transA, char* transB, int* m, int* n, int* k, double* alpha,
            double* A, int* lda, double* B, int* ldb, double* beta, double* C,
            int* ldc) -> void;
/// Declaration of the single-precision `sgemm_` routine.
///
/// Every argument, including the scalar ones, is handed over through a
/// pointer.
///
/// NOTE(review): the trailing underscore suggests that this is the Fortran
/// BLAS symbol computing C := alpha * op(A) * op(B) + beta * C. If so, the
/// declaration needs C linkage from an enclosing `extern "C"` block, which is
/// not visible in this part of the file -- confirm.
auto sgemm_(char* transA, char* transB, int* m, int* n, int* k, float* alpha,
            float* A, int* lda, float* B, int* ldb, float* beta, float* C,
            int* ldc) -> void;
40template <std::
floating_po
int Float>
41auto gemm(
op TRANSA,
op TRANSB,
int M,
int N,
int K, Float ALPHA, Float
const* A,
42 int LDA, Float
const* B,
int LDB, Float BETA, Float* C,
int LDC) ->
void {
43 if constexpr (std::same_as<Float, double>) {
44 dgemm_(
reinterpret_cast<char*
>(&TRANSA),
reinterpret_cast<char*
>(&TRANSB),
45 &M, &N, &K, &ALPHA,
const_cast<Float*
>(A), &LDA,
46 const_cast<Float*
>(B), &LDB, &BETA, C, &LDC);
47 }
else if constexpr (std::same_as<Float, float>) {
48 sgemm_(
reinterpret_cast<char*
>(&TRANSA),
reinterpret_cast<char*
>(&TRANSB),
49 &M, &N, &K, &ALPHA,
const_cast<Float*
>(A), &LDA,
50 const_cast<Float*
>(B), &LDB, &BETA, C, &LDC);
55template <std::
floating_po
int Float, std::
size_t M, std::
size_t N, std::
size_t K>
64template <std::
floating_po
int Float>
68 assert(A.
rank() == 2);
69 assert(B.
rank() == 1 || B.
rank() == 2);
78 return gemm<Float>(trans_A, trans_B, M, N, K, alpha, A.
data(), M, B.
data(), K,
83template <std::
floating_po
int Float>
// Cross-references (presumably tooltip text left over from the rendered
// source listing):
//
//   constexpr auto data() -> ValueType *
//     Definition: static_multidim_array.h:260
//
//   auto dgemm_(char *transA, char *transB, int *m, int *n, int *k, double *alpha, double *A, int *lda, double *B, int *ldb, double *beta, double *C, int *ldc) -> void
//
//   auto sgemm_(char *transA, char *transB, int *m, int *n, int *k, float *alpha, float *A, int *lda, float *B, int *ldb, float *beta, float *C, int *ldc) -> void
//
//   auto gemm(op TRANSA, op TRANSB, int M, int N, int K, Float ALPHA, Float const *A, int LDA, Float const *B, int LDB, Float BETA, Float *C, int LDC) -> void
//     Definition: gemm.h:41
//
//   static auto constexpr rank()
//     Definition: base_tensor.h:41
//
//   static auto constexpr dimension(std::size_t const i)
//     Definition: base_tensor.h:49