1#ifndef TATOOINE_BIDIAGONAL_SYSTEM_QR_SOLVER_H
2#define TATOOINE_BIDIAGONAL_SYSTEM_QR_SOLVER_H
20bool bdsvu(
int _n,
int _nrhs,
const T* _d,
const T* _du, T* _b,
int _ldb) {
23 assert(_ldb >= _n || _nrhs == 1);
25 if (_d[_n - 1] == T(0))
return false;
27 for (
int j = 0; j < _nrhs; ++j)
28 _b[(_n - 1) + j * _ldb] /= _d[_n - 1];
30 for (
int i = _n - 2; i >= 0; --i) {
31 if (_d[i] == T(0))
return false;
32 for (
int j = 0; j < _nrhs; ++j)
34 (_b[i + j * _ldb] - _du[i] * _b[(i + 1) + j * _ldb]) / _d[i];
49bool bdsvl(
int _n,
int _nrhs,
const T* _dl,
const T* _d, T* _b,
int _ldb) {
52 assert(_ldb >= _n || _nrhs == 1);
54 if (_d[0] == T(0))
return false;
56 for (
int j = 0; j < _nrhs; ++j)
57 _b[0 + j * _ldb] /= _d[0];
59 for (
int i = 1; i < _n; ++i) {
60 if (_d[i] == T(0))
return false;
61 for (
int j = 0; j < _nrhs; ++j)
63 (_b[i + j * _ldb] - _dl[i - 1] * _b[(i - 1) + j * _ldb]) / _d[i];
90 if (_d[m - 1] == T(0))
return false;
92 T x = _du[m - 1] / _d[m - 1];
94 for (
int i = m - 2; i >= 0; --i) {
101 assert(_n < 2048 &&
"limit temporary buffer size using alloca");
102 T* b = (T*)alloca(_n *
sizeof(T));
103 memcpy(b, _b, _n *
sizeof(T));
105 bool regular =
bdsvl(m, 1, _d, _du, b, 1);
109 for (
int i = 0; i < m; ++i)
113 _b[m - 1] -= _du[m - 1] * y;
115 regular =
bdsvu(m, 1, _d, _du, _b, 1);
132 _x[0] = _c * x - _s * y;
133 _x[1] = _s * x + _c * y;
172bool solve_qr(
int _n, T* _d, T* _du, T* _b, T* _null) {
173 assert(_n < 2048 &&
"limit temporary buffer size using alloca");
175 T* q = (T*)alloca(2 * _n *
sizeof(T));
177 T* qc = (T*)alloca(_n *
sizeof(T));
184 for (
int j = 0; j < _n; ++j) {
187 if (_du[j] == T(0)) {
191 if (fabs(_du[j]) > fabs(_d[j])) {
192 T tau = _d[j] / _du[j];
193 s = 1 / std::sqrt(T(1) + tau * tau);
196 T tau = _du[j] / _d[j];
197 c = 1 / std::sqrt(T(1) + tau * tau);
202 _d[j] = c * _d[j] + s * _du[j];
226 if (!
bdsvl(_n, 1, _du, _d, _b, 1))
return false;
235 for (
int j = _n - 1; j >= 0; --j) {
Definition: algorithm.h:6
void _planerot(T _c, T _s, T *_x)
Definition: bidiagonal_system_solver.h:129
bool bdsvl(int _n, int _nrhs, const T *_dl, const T *_d, T *_b, int _ldb)
Definition: bidiagonal_system_solver.h:49
bool solve_blockwise(int _n, const T *_d, const T *_du, T *_b)
Definition: bidiagonal_system_solver.h:82
bool solve_qr(int _n, T *_d, T *_du, T *_b, T *_null)
Definition: bidiagonal_system_solver.h:172
bool bdsvu(int _n, int _nrhs, const T *_d, const T *_du, T *_b, int _ldb)
Definition: bidiagonal_system_solver.h:20