Tatooine
gemv.h
Go to the documentation of this file.
1#ifndef TATOOINE_BLAS_GEMV_H
2#define TATOOINE_BLAS_GEMV_H
3//==============================================================================
extern "C" {
//==============================================================================
// Fortran BLAS ?GEMV entry points, computing y := alpha*op(A)*x + beta*y.
// Fortran passes every argument by reference, which is why every parameter
// (including scalars) is a pointer; the routines only read everything except Y.
// NOTE(review): gfortran-built BLAS formally expects an extra hidden trailing
// string-length argument for the CHARACTER parameter TRANS. Omitting it is the
// traditional C convention and usually works — confirm against the BLAS in use.
void dgemv_(char* TRANS, int* M, int* N, double* ALPHA, double* A, int* LDA,
            double* X, int* INCX, double* BETA, double* Y, int* INCY);
void sgemv_(char* TRANS, int* M, int* N, float* ALPHA, float* A, int* LDA,
            float* X, int* INCX, float* BETA, float* Y, int* INCY);
//==============================================================================
}
12//==============================================================================
13#include <tatooine/blas/base.h>
14
15#include <cassert>
16//==============================================================================
17namespace tatooine::blas {
18//==============================================================================
35//==============================================================================
36template <std::floating_point Float>
37auto gemv(op TRANS, int M, int N, Float ALPHA, Float const* A, int LDA,
38 Float const* X, int INCX, Float BETA, Float* Y, int INCY) -> void {
39 if constexpr (std::same_as<Float, double>) {
40 dgemv_(reinterpret_cast<char*>(&TRANS), &M, &N, &ALPHA,
41 const_cast<Float*>(A), &LDA, const_cast<Float*>(X), &INCX, &BETA, Y,
42 &INCY);
43 } else if constexpr (std::same_as<Float, float>) {
44 sgemv_(reinterpret_cast<char*>(&TRANS), &M, &N, &ALPHA,
45 const_cast<Float*>(A), &LDA, const_cast<Float*>(X), &INCX, &BETA, Y,
46 &INCY);
47 }
48}
49//==============================================================================
51template <typename Real>
52auto gemv(op trans, Real const alpha, tensor<Real> const& A,
53 tensor<Real> const& x, Real const beta, tensor<Real>& y) {
54 assert(A.rank() == 2);
55 assert(x.rank() == 1);
56 assert(y.rank() == 1);
57 auto const M = A.dimension(0);
58 auto const N = A.dimension(1);
59 assert(N == x.dimension(0));
60 assert(y.dimension(0) == M);
61
62 return gemv(trans, static_cast<int>(M), static_cast<int>(N), alpha, A.data(),
63 static_cast<int>(M), x.data(), 1, beta, y.data(), 1);
64}
65//------------------------------------------------------------------------------
67template <typename Real>
68auto gemv(Real const alpha, tensor<Real> const& A, tensor<Real> const& x,
69 Real const beta, tensor<Real>& y) {
70 return gemv(op::no_transpose, alpha, A, x, beta, y);
71}
73//==============================================================================
74} // namespace tatooine::blas
75//==============================================================================
76#endif
constexpr auto data() -> ValueType *
Definition: static_multidim_array.h:260
auto dgemv_(char *TRANS, int *M, int *N, double *ALPHA, double *A, int *LDA, double *X, int *INCX, double *BETA, double *Y, int *INCY) -> void
auto sgemv_(char *TRANS, int *M, int *N, float *ALPHA, float *A, int *LDA, float *X, int *INCX, float *BETA, float *Y, int *INCY) -> void
auto gemv(op TRANS, int M, int N, Float ALPHA, Float const *A, int LDA, Float const *X, int INCX, Float BETA, Float *Y, int INCY) -> void
Definition: gemv.h:37
Definition: base.h:13
op
Definition: base.h:46
Definition: tensor.h:17
static auto constexpr rank()
Definition: base_tensor.h:41
static auto constexpr dimension(std::size_t const i)
Definition: base_tensor.h:49