Tatooine
bidiagonal_system_solver.h
Go to the documentation of this file.
1#ifndef TATOOINE_BIDIAGONAL_SYSTEM_QR_SOLVER_H
2#define TATOOINE_BIDIAGONAL_SYSTEM_QR_SOLVER_H
3//==============================================================================
4#include <cassert>
5#include <cmath>
6#include <cstdlib>
7//==============================================================================
8namespace tatooine {
9//==============================================================================
10// General solvers
11//-------------------------------------------------------------------------------
/// Solve an upper bidiagonal system in place by back substitution.
///
/// The matrix has `_d[0.._n-1]` on its diagonal and `_du[0.._n-2]` on its
/// superdiagonal. Every right-hand side is overwritten by its solution.
///
/// \param _n     order of the system
/// \param _nrhs  number of right-hand sides (must be positive)
/// \param _d     diagonal, `_n` entries
/// \param _du    superdiagonal, `_n - 1` entries
/// \param _b     right-hand sides on entry, solutions on return; entry `i` of
///               column `j` lives at `_b[i + j * _ldb]`
/// \param _ldb   stride between consecutive columns of `_b`
/// \return `true` on success (an empty system, `_n == 0`, is trivially
///         solved). `false` if `_n` is negative or any diagonal entry is
///         exactly zero; in that case `_b` is left untouched.
template <typename T>
bool bdsvu(int _n, int _nrhs, const T* _d, const T* _du, T* _b, int _ldb) {
  assert(_nrhs > 0);
  assert(_ldb >= _n || _nrhs == 1);

  // Nothing to solve for an empty system; a negative order is invalid.
  // (Without this guard `_d[_n - 1]` would be read out of bounds.)
  if (_n <= 0) return _n == 0;

  // Reject singular systems up front so that a failed call never leaves
  // `_b` half-overwritten.
  for (int i = 0; i < _n; ++i)
    if (_d[i] == T(0)) return false;  // shall we handle inf?

  for (int j = 0; j < _nrhs; ++j) {
    T* col = _b + j * _ldb;
    col[_n - 1] /= _d[_n - 1];
    for (int i = _n - 2; i >= 0; --i)
      col[i] = (col[i] - _du[i] * col[i + 1]) / _d[i];
  }
  return true;
}
38//------------------------------------------------------------------------------
/// Solve a lower bidiagonal system in place by forward substitution.
///
/// The matrix has `_d[0.._n-1]` on its diagonal and `_dl[0.._n-2]` on its
/// subdiagonal. Every right-hand side is overwritten by its solution.
///
/// \param _n     order of the system
/// \param _nrhs  number of right-hand sides (must be positive)
/// \param _dl    subdiagonal, `_n - 1` entries
/// \param _d     diagonal, `_n` entries
/// \param _b     right-hand sides on entry, solutions on return; entry `i` of
///               column `j` lives at `_b[i + j * _ldb]`
/// \param _ldb   stride between consecutive columns of `_b`
/// \return `true` on success (an empty system, `_n == 0`, is trivially
///         solved). `false` if `_n` is negative or any diagonal entry is
///         exactly zero; in that case `_b` is left untouched.
template <typename T>
bool bdsvl(int _n, int _nrhs, const T* _dl, const T* _d, T* _b, int _ldb) {
  assert(_nrhs > 0);
  assert(_ldb >= _n || _nrhs == 1);

  // Nothing to solve for an empty system; a negative order is invalid.
  // (Without this guard `_d[0]` and `_b[0]` would be accessed regardless.)
  if (_n <= 0) return _n == 0;

  // Reject singular systems up front so that a failed call never leaves
  // `_b` half-overwritten.
  for (int i = 0; i < _n; ++i)
    if (_d[i] == T(0)) return false;  // shall we handle inf?

  for (int j = 0; j < _nrhs; ++j) {
    T* col = _b + j * _ldb;
    col[0] /= _d[0];
    for (int i = 1; i < _n; ++i)
      col[i] = (col[i] - _dl[i - 1] * col[i - 1]) / _d[i];
  }
  return true;
}
67//-------------------------------------------------------------------------------
68// Special solvers tailored to problem
69//-------------------------------------------------------------------------------
/// Solve a bordered upper bidiagonal system in place by block elimination.
///
/// With `m = _n - 1`, the elimination steps below correspond to the
/// `_n x _n` system
/// \code
///   [ B    _du[m-1] * e ] [x]   [ _b[0..m-1] ]
///   [ 1^T       1       ] [y] = [ _b[m]      ]
/// \endcode
/// where `B` is the `m x m` upper bidiagonal matrix with diagonal
/// `_d[0..m-1]` and superdiagonal `_du[0..m-2]`, `e` is the last unit vector
/// of length `m`, and `1^T` is a row of ones.
///
/// The scalar unknown `y` is obtained from the Schur complement
/// `s = 1 - 1^T B^{-1} (_du[m-1] e)`, then `x` follows from one more
/// bidiagonal solve.
///
/// \param _n   order of the full system (must be at least 2 and, because a
///             temporary of `_n` entries is taken from the stack, below 2048)
/// \param _d   diagonal of `B`, `_n - 1` entries
/// \param _du  superdiagonal of `B` followed by the border entry,
///             `_n - 1` entries
/// \param _b   right-hand side on entry, solution on return, `_n` entries
/// \return `false` if `_n < 2`, if a diagonal entry of `B` is exactly zero,
///         or if the Schur complement `s` is exactly zero; `_b` is left
///         untouched in these cases. `true` otherwise.
template <typename T>
bool solve_blockwise(int _n, const T* _d, const T* _du, T* _b) {
  // The bordered structure needs at least one row of B plus the border row;
  // smaller sizes would index `_d[-1]`.
  if (_n < 2) return false;

  const int m = _n - 1;

  // B must be regular; checking here keeps every division below well defined
  // and guarantees that `_b` is not modified on failure.
  for (int i = 0; i < m; ++i)
    if (_d[i] == T(0)) return false;

  // get s: accumulate 1 - sum(B^{-1} * (_du[m-1] * e)) by back substitution
  T s(1);
  {
    T x = _du[m - 1] / _d[m - 1];
    s -= x;
    for (int i = m - 2; i >= 0; --i) {
      x *= -_du[i] / _d[i];
      s -= x;
    }
  }
  if (s == T(0)) return false;  // singular Schur complement

  // get y: needs 1^T * B^{-1} * _b[0..m-1], computed on a scratch copy so
  // that `_b` still holds the original right-hand side afterwards
  assert(_n < 2048 && "limit temporary buffer size using alloca");
  T* scratch = static_cast<T*>(alloca(_n * sizeof(T)));
  for (int i = 0; i < _n; ++i)
    scratch[i] = _b[i];

  // B is *upper* bidiagonal (diagonal _d, superdiagonal _du), so this has to
  // be the same upper solve that is used for x below.
  if (!bdsvu(m, 1, _d, _du, scratch, 1)) return false;

  T y(scratch[m]);
  for (int i = 0; i < m; ++i)
    y -= scratch[i];
  y /= s;

  // get x: move the border column to the right-hand side and solve with B
  _b[m - 1] -= _du[m - 1] * y;

  if (!bdsvu(m, 1, _d, _du, _b, 1)) return false;

  _b[m] = y;

  return true;
}
122//-------------------------------------------------------------------------------
/// Apply a plane rotation to the two consecutive values `_x[0]` and `_x[1]`.
///
/// The pair is replaced in place by
/// `(_c * x0 - _s * x1, _s * x0 + _c * x1)`.
///
/// \param _c  first rotation coefficient (cosine)
/// \param _s  second rotation coefficient (sine)
/// \param _x  pointer to the two values that are rotated
template <typename T>
void _planerot(T _c, T _s, T* _x) {
  // Snapshot both inputs first: each output depends on both of them.
  const T first  = _x[0];
  const T second = _x[1];

  _x[0] = _c * first - _s * second;
  _x[1] = _s * first + _c * second;
}
135//------------------------------------------------------------------------------
145
// When _Q_IN_NULL is defined, solve_qr() stores the rotation sines in its
// `_null` output array instead of a second temporary. It is switched on here
// whenever DOXYGEN_SKIP is defined.
#ifdef DOXYGEN_SKIP
#ifndef _Q_IN_NULL
#define _Q_IN_NULL
#endif
#endif
/// Solve an underdetermined bidiagonal system using Givens rotations.
///
/// `_d` and `_du` describe the `_n x (_n+1)` matrix `A` with `_d[i]` at
/// position `(i, i)` and `_du[i]` at position `(i, i+1)`. The routine
///  1. factors `A' = Q * R` with `_n` Givens rotations, where rotation `j`
///     annihilates `_du[j]` against `_d[j]`,
///  2. solves `R' * y = _b` by forward substitution, and
///  3. forms `Q * [y; 0]` (the minimum-norm solution of `A * x = _b`) and
///     the last column of `Q` (a unit vector spanning the null space of `A`).
///
/// \param _n     number of equations (must be at least 1 and, because
///               temporaries are taken from the stack, below 2048)
/// \param _d     diagonal, `_n` entries; overwritten with the diagonal of `R`
/// \param _du    superdiagonal, `_n` entries; overwritten with the
///               superdiagonal of `R` (`_du[_n-1]` is set to zero)
/// \param _b     `_n + 1` entries; the first `_n` hold the right-hand side
///               on entry, all `_n + 1` hold the solution on return
/// \param _null  `_n + 1` entries; receives the last column of `Q`
/// \return `false` if `_n < 1` (nothing is modified) or if `R` has an exactly
///         zero diagonal entry (`_d` and `_du` are already overwritten then);
///         `true` otherwise.
template <typename T>
bool solve_qr(int _n, T* _d, T* _du, T* _b, T* _null) {
  // At least one equation is required; `_du[_n - 1]` is written below.
  if (_n < 1) return false;

  assert(_n < 2048 && "limit temporary buffer size using alloca");
#ifndef _Q_IN_NULL
  // store givens rotations
  T* q = static_cast<T*>(alloca(2 * _n * sizeof(T)));
#else
  // store cosines of givens rotations
  T* qc = static_cast<T*>(alloca(_n * sizeof(T)));
#endif

  //
  // 1. Compute givens rotations and construct R.
  //

  for (int j = 0; j < _n; ++j) {
    T c, s;

    if (_du[j] == T(0)) {
      // Nothing to annihilate: identity rotation.
      c = 1;
      s = 0;
    } else {
      // Always divide the smaller magnitude by the larger one, so that
      // tau * tau cannot overflow.
      if (std::fabs(_du[j]) > std::fabs(_d[j])) {
        T tau = _d[j] / _du[j];
        s     = 1 / std::sqrt(T(1) + tau * tau);
        c     = s * tau;
      } else {
        T tau = _du[j] / _d[j];
        c     = 1 / std::sqrt(T(1) + tau * tau);
        s     = c * tau;
      }
    }

    _d[j] = c * _d[j] + s * _du[j];

    if (j + 1 < _n) {
      // Rotate the pair (0, _d[j + 1]) of the next column.
      T a       = _d[j + 1];
      _du[j]    = s * a;  // first component...
      _d[j + 1] = c * a;  // ... is zero
    }
#ifndef _Q_IN_NULL
    q[2 * j]     = c;
    q[2 * j + 1] = s;
#else
    qc[j]    = c;
    _null[j] = s;  // store sine in _null
#endif
  }
  _du[_n - 1] = T(0);

  // Now we have Q' as givens rotations in q and
  // R in _d (diagonal), _du (superdiagonal).

  //
  // 2. Solve y=R'\_b
  //

  // R' is lower bidiagonal: diagonal _d, subdiagonal _du.
  if (!bdsvl(_n, 1, _du, _d, _b, 1)) return false;
  _b[_n] = T(0);

  //
  // 3. Compute xln=Q*y and _null=Q(:,end).
  //

  _null[_n] = T(1);

  // Apply the transposed rotations in reverse order.
  for (int j = _n - 1; j >= 0; --j) {
#ifndef _Q_IN_NULL
    T c = q[2 * j];
    T s = q[2 * j + 1];
#else
    T c = qc[j];
    T s = _null[j];  // fetch the sine before the slot is reused below
#endif

    _null[j] = T(0);
    _planerot(c, s, _null + j);
    _planerot(c, s, _b + j);
  }

  return true;
}
251//==============================================================================
252} // namespace tatooine
253//==============================================================================
254#endif
Definition: algorithm.h:6
void _planerot(T _c, T _s, T *_x)
Definition: bidiagonal_system_solver.h:129
bool bdsvl(int _n, int _nrhs, const T *_dl, const T *_d, T *_b, int _ldb)
Definition: bidiagonal_system_solver.h:49
bool solve_blockwise(int _n, const T *_d, const T *_du, T *_b)
Definition: bidiagonal_system_solver.h:82
bool solve_qr(int _n, T *_d, T *_du, T *_b, T *_null)
Definition: bidiagonal_system_solver.h:172
bool bdsvu(int _n, int _nrhs, const T *_d, const T *_du, T *_b, int _ldb)
Definition: bidiagonal_system_solver.h:20