Tatooine
gemm.h
Go to the documentation of this file.
1#ifndef TATOOINE_LAPACK_GEMM_H
2#define TATOOINE_LAPACK_GEMM_H
3//==============================================================================
5#include <concepts>
6//==============================================================================
// Declarations of the Fortran BLAS level-3 GEMM routines. Following the
// Fortran calling convention every argument, including scalars and the
// transposition flags, is passed by pointer.
extern "C" {
/// Double-precision general matrix-matrix multiplication (BLAS `dgemm`).
auto dgemm_(char* transA, char* transB,
            int* m, int* n, int* k,
            double* alpha,
            double* A, int* lda,
            double* B, int* ldb,
            double* beta,
            double* C, int* ldc) -> void;
/// Single-precision general matrix-matrix multiplication (BLAS `sgemm`).
auto sgemm_(char* transA, char* transB,
            int* m, int* n, int* k,
            float* alpha,
            float* A, int* lda,
            float* B, int* ldb,
            float* beta,
            float* C, int* ldc) -> void;
}  // extern "C"
15//==============================================================================
16namespace tatooine::blas {
17//==============================================================================
39//==============================================================================
40template <std::floating_point Float>
41auto gemm(op TRANSA, op TRANSB, int M, int N, int K, Float ALPHA, Float const* A,
42 int LDA, Float const* B, int LDB, Float BETA, Float* C, int LDC) -> void {
43 if constexpr (std::same_as<Float, double>) {
44 dgemm_(reinterpret_cast<char*>(&TRANSA), reinterpret_cast<char*>(&TRANSB),
45 &M, &N, &K, &ALPHA, const_cast<Float*>(A), &LDA,
46 const_cast<Float*>(B), &LDB, &BETA, C, &LDC);
47 } else if constexpr (std::same_as<Float, float>) {
48 sgemm_(reinterpret_cast<char*>(&TRANSA), reinterpret_cast<char*>(&TRANSB),
49 &M, &N, &K, &ALPHA, const_cast<Float*>(A), &LDA,
50 const_cast<Float*>(B), &LDB, &BETA, C, &LDC);
51 }
52}
53//==============================================================================
55template <std::floating_point Float, std::size_t M, std::size_t N, std::size_t K>
56auto gemm(Float const alpha, tensor<Float, M, K> const& A,
57 tensor<Float, K, N> const& B, Float const beta, tensor<Float, M, N>& C) {
58 return gemm<Float>(op::no_transpose, op::no_transpose, M, N, K, alpha,
59 A.data(), M,
60 B.data(), N, beta, C.data(), M);
61}
62//------------------------------------------------------------------------------
64template <std::floating_point Float>
65auto gemm(blas::op trans_A, blas::op trans_B, Float const alpha,
66 tensor<Float> const& A, tensor<Float> const& B, Float const beta,
67 tensor<Float>& C) {
68 assert(A.rank() == 2);
69 assert(B.rank() == 1 || B.rank() == 2);
70 assert(C.rank() == B.rank());
71 auto const M = A.dimension(0);
72 auto const N = B.rank() == 2 ? B.dimension(1) : 1;
73 assert(A.dimension(1) == B.dimension(0));
74 auto const K = A.dimension(1);
75 assert(C.dimension(0) == M);
76 assert(C.rank() == 1 || C.dimension(1) == N);
77
78 return gemm<Float>(trans_A, trans_B, M, N, K, alpha, A.data(), M, B.data(), K,
79 beta, C.data(), M);
80}
81//------------------------------------------------------------------------------
83template <std::floating_point Float>
84auto gemm(Float const alpha, tensor<Float> const& A, tensor<Float> const& B,
85 Float const beta, tensor<Float>& C) {
86 return gemm<Float>(op::no_transpose, op::no_transpose, alpha, A, B, beta, C);
87}
89//==============================================================================
90} // namespace tatooine::blas
91//==============================================================================
92#endif
constexpr auto data() -> ValueType *
Definition: static_multidim_array.h:260
auto dgemm_(char *transA, char *transB, int *m, int *n, int *k, double *alpha, double *A, int *lda, double *B, int *ldb, double *beta, double *C, int *ldc) -> void
auto sgemm_(char *transA, char *transB, int *m, int *n, int *k, float *alpha, float *A, int *lda, float *B, int *ldb, float *beta, float *C, int *ldc) -> void
auto gemm(op TRANSA, op TRANSB, int M, int N, int K, Float ALPHA, Float const *A, int LDA, Float const *B, int LDB, Float BETA, Float *C, int LDC) -> void
Definition: gemm.h:41
Definition: base.h:13
op
Definition: base.h:46
Definition: tensor.h:17
static auto constexpr rank()
Definition: base_tensor.h:41
static auto constexpr dimension(std::size_t const i)
Definition: base_tensor.h:49