nmatrix 0.0.2 → 0.0.3
Sign up to get free protection for your applications and to get access to all the features.
- data/Gemfile +1 -1
- data/History.txt +31 -3
- data/Manifest.txt +5 -0
- data/README.rdoc +29 -27
- data/ext/nmatrix/binary_format.txt +53 -0
- data/ext/nmatrix/data/data.cpp +18 -18
- data/ext/nmatrix/data/data.h +38 -7
- data/ext/nmatrix/data/rational.h +13 -0
- data/ext/nmatrix/data/ruby_object.h +10 -0
- data/ext/nmatrix/extconf.rb +2 -0
- data/ext/nmatrix/nmatrix.cpp +655 -103
- data/ext/nmatrix/nmatrix.h +26 -14
- data/ext/nmatrix/ruby_constants.cpp +4 -0
- data/ext/nmatrix/ruby_constants.h +2 -0
- data/ext/nmatrix/storage/dense.cpp +99 -41
- data/ext/nmatrix/storage/dense.h +3 -3
- data/ext/nmatrix/storage/list.cpp +36 -14
- data/ext/nmatrix/storage/list.h +4 -4
- data/ext/nmatrix/storage/storage.cpp +19 -19
- data/ext/nmatrix/storage/storage.h +11 -11
- data/ext/nmatrix/storage/yale.cpp +17 -20
- data/ext/nmatrix/storage/yale.h +13 -11
- data/ext/nmatrix/util/io.cpp +25 -23
- data/ext/nmatrix/util/io.h +5 -5
- data/ext/nmatrix/util/math.cpp +634 -17
- data/ext/nmatrix/util/math.h +958 -9
- data/ext/nmatrix/util/sl_list.cpp +7 -7
- data/ext/nmatrix/util/sl_list.h +2 -2
- data/lib/nmatrix.rb +9 -0
- data/lib/nmatrix/blas.rb +4 -4
- data/lib/nmatrix/io/market.rb +227 -0
- data/lib/nmatrix/io/mat_reader.rb +7 -7
- data/lib/nmatrix/lapack.rb +80 -0
- data/lib/nmatrix/nmatrix.rb +78 -52
- data/lib/nmatrix/shortcuts.rb +486 -0
- data/lib/nmatrix/version.rb +1 -1
- data/spec/2x2_dense_double.mat +0 -0
- data/spec/blas_spec.rb +59 -9
- data/spec/elementwise_spec.rb +25 -12
- data/spec/io_spec.rb +69 -1
- data/spec/lapack_spec.rb +53 -4
- data/spec/math_spec.rb +9 -0
- data/spec/nmatrix_list_spec.rb +95 -0
- data/spec/nmatrix_spec.rb +10 -53
- data/spec/nmatrix_yale_spec.rb +17 -15
- data/spec/shortcuts_spec.rb +154 -0
- metadata +22 -15
data/ext/nmatrix/util/math.h
CHANGED
@@ -70,7 +70,10 @@
|
|
70
70
|
|
71
71
|
extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
|
72
72
|
#include <cblas.h>
|
73
|
-
|
73
|
+
|
74
|
+
#ifdef HAVE_CLAPACK_H
|
75
|
+
#include <clapack.h>
|
76
|
+
#endif
|
74
77
|
}
|
75
78
|
|
76
79
|
#include <algorithm> // std::min, std::max
|
@@ -85,6 +88,7 @@ extern "C" { // These need to be in an extern "C" block or you'll get all kinds
|
|
85
88
|
/*
|
86
89
|
* Macros
|
87
90
|
*/
|
91
|
+
#define REAL_RECURSE_LIMIT 4
|
88
92
|
|
89
93
|
/*
|
90
94
|
* Data
|
@@ -95,7 +99,7 @@ extern "C" {
|
|
95
99
|
/*
|
96
100
|
* C accessors.
|
97
101
|
*/
|
98
|
-
void nm_math_det_exact(const int M, const void* elements, const int lda, dtype_t dtype, void* result);
|
102
|
+
void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result);
|
99
103
|
void nm_math_transpose_generic(const size_t M, const size_t N, const void* A, const int lda, void* B, const int ldb, size_t element_size);
|
100
104
|
void nm_math_init_blas(void);
|
101
105
|
}
|
@@ -143,6 +147,9 @@ template <> inline double numeric_inverse<double>(const double& n) { return 1 /
|
|
143
147
|
*
|
144
148
|
* For row major, call trsm<DType> instead. That will handle necessary changes-of-variables
|
145
149
|
* and parameter checks.
|
150
|
+
*
|
151
|
+
* Note that some of the boundary conditions here may be incorrect. Very little has been tested!
|
152
|
+
* This was converted directly from dtrsm.f using f2c, and then rewritten more cleanly.
|
146
153
|
*/
|
147
154
|
template <typename DType>
|
148
155
|
inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
@@ -150,6 +157,9 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
|
150
157
|
const int m, const int n, const DType alpha, const DType* a,
|
151
158
|
const int lda, DType* b, const int ldb)
|
152
159
|
{
|
160
|
+
|
161
|
+
// (row-major) trsm: left upper trans nonunit m=3 n=1 1/1 a 3 b 3
|
162
|
+
|
153
163
|
if (m == 0 || n == 0) return; /* Quick return if possible. */
|
154
164
|
|
155
165
|
if (alpha == 0) { // Handle alpha == 0
|
@@ -210,7 +220,7 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
|
210
220
|
for (int j = 0; j < n; ++j) {
|
211
221
|
for (int i = 0; i < m; ++i) {
|
212
222
|
DType temp = alpha * b[i + j * ldb];
|
213
|
-
for (int k = 0; k < i
|
223
|
+
for (int k = 0; k < i; ++k) { // limit was i-1. Lots of similar bugs in this code, probably.
|
214
224
|
temp -= a[k + i * lda] * b[k + j * ldb];
|
215
225
|
}
|
216
226
|
if (diag == CblasNonUnit) {
|
@@ -339,6 +349,92 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
|
339
349
|
}
|
340
350
|
|
341
351
|
|
352
|
+
template <typename DType>
|
353
|
+
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
354
|
+
const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
|
355
|
+
rb_raise(rb_eNotImpError, "syrk not yet implemented for non-BLAS dtypes");
|
356
|
+
}
|
357
|
+
|
358
|
+
template <typename DType>
|
359
|
+
inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
360
|
+
const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
|
361
|
+
rb_raise(rb_eNotImpError, "herk not yet implemented for non-BLAS dtypes");
|
362
|
+
}
|
363
|
+
|
364
|
+
template <>
|
365
|
+
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
366
|
+
const int K, const float* alpha, const float* A, const int lda, const float* beta, float* C, const int ldc) {
|
367
|
+
cblas_ssyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
|
368
|
+
}
|
369
|
+
|
370
|
+
template <>
|
371
|
+
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
372
|
+
const int K, const double* alpha, const double* A, const int lda, const double* beta, double* C, const int ldc) {
|
373
|
+
cblas_dsyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
|
374
|
+
}
|
375
|
+
|
376
|
+
template <>
|
377
|
+
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
378
|
+
const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
|
379
|
+
cblas_csyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
|
380
|
+
}
|
381
|
+
|
382
|
+
template <>
|
383
|
+
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
384
|
+
const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
|
385
|
+
cblas_zsyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
|
386
|
+
}
|
387
|
+
|
388
|
+
|
389
|
+
template <>
|
390
|
+
inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
391
|
+
const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
|
392
|
+
cblas_cherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
|
393
|
+
}
|
394
|
+
|
395
|
+
template <>
|
396
|
+
inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
397
|
+
const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
|
398
|
+
cblas_zherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
|
399
|
+
}
|
400
|
+
|
401
|
+
|
402
|
+
template <typename DType>
|
403
|
+
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
404
|
+
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const DType* alpha,
|
405
|
+
const DType* A, const int lda, DType* B, const int ldb) {
|
406
|
+
rb_raise(rb_eNotImpError, "trmm not yet implemented for non-BLAS dtypes");
|
407
|
+
}
|
408
|
+
|
409
|
+
template <>
|
410
|
+
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
411
|
+
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const float* alpha,
|
412
|
+
const float* A, const int lda, float* B, const int ldb) {
|
413
|
+
cblas_strmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
|
414
|
+
}
|
415
|
+
|
416
|
+
template <>
|
417
|
+
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
418
|
+
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const double* alpha,
|
419
|
+
const double* A, const int lda, double* B, const int ldb) {
|
420
|
+
cblas_dtrmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
|
421
|
+
}
|
422
|
+
|
423
|
+
template <>
|
424
|
+
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
425
|
+
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex64* alpha,
|
426
|
+
const Complex64* A, const int lda, Complex64* B, const int ldb) {
|
427
|
+
cblas_ctrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
|
428
|
+
}
|
429
|
+
|
430
|
+
template <>
|
431
|
+
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
432
|
+
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex128* alpha,
|
433
|
+
const Complex128* A, const int lda, Complex128* B, const int ldb) {
|
434
|
+
cblas_ztrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
|
435
|
+
}
|
436
|
+
|
437
|
+
|
342
438
|
/*
|
343
439
|
* BLAS' DTRSM function, generalized.
|
344
440
|
*/
|
@@ -349,6 +445,9 @@ inline void trsm(const enum CBLAS_ORDER order,
|
|
349
445
|
const int m, const int n, const DType alpha, const DType* a,
|
350
446
|
const int lda, DType* b, const int ldb)
|
351
447
|
{
|
448
|
+
/*using std::cerr;
|
449
|
+
using std::endl;*/
|
450
|
+
|
352
451
|
int num_rows_a = n;
|
353
452
|
if (side == CblasLeft) num_rows_a = m;
|
354
453
|
|
@@ -368,6 +467,13 @@ inline void trsm(const enum CBLAS_ORDER order,
|
|
368
467
|
enum CBLAS_SIDE side_ = side == CblasLeft ? CblasRight : CblasLeft;
|
369
468
|
enum CBLAS_UPLO uplo_ = uplo == CblasUpper ? CblasLower : CblasUpper;
|
370
469
|
|
470
|
+
/*
|
471
|
+
cerr << "(row-major) trsm: " << (side_ == CblasLeft ? "left " : "right ")
|
472
|
+
<< (uplo_ == CblasUpper ? "upper " : "lower ")
|
473
|
+
<< (trans_a == CblasTrans ? "trans " : "notrans ")
|
474
|
+
<< (diag == CblasNonUnit ? "nonunit " : "unit ")
|
475
|
+
<< n << " " << m << " " << alpha << " a " << lda << " b " << ldb << endl;
|
476
|
+
*/
|
371
477
|
trsm_nothrow<DType>(side_, uplo_, trans_a, diag, n, m, alpha, a, lda, b, ldb);
|
372
478
|
|
373
479
|
} else { // CblasColMajor
|
@@ -376,7 +482,13 @@ inline void trsm(const enum CBLAS_ORDER order,
|
|
376
482
|
fprintf(stderr, "TRSM: M=%d; got ldb=%d\n", m, ldb);
|
377
483
|
rb_raise(rb_eArgError, "TRSM: Expected ldb >= max(1,M)");
|
378
484
|
}
|
379
|
-
|
485
|
+
/*
|
486
|
+
cerr << "(col-major) trsm: " << (side == CblasLeft ? "left " : "right ")
|
487
|
+
<< (uplo == CblasUpper ? "upper " : "lower ")
|
488
|
+
<< (trans_a == CblasTrans ? "trans " : "notrans ")
|
489
|
+
<< (diag == CblasNonUnit ? "nonunit " : "unit ")
|
490
|
+
<< m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
|
491
|
+
*/
|
380
492
|
trsm_nothrow<DType>(side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
|
381
493
|
|
382
494
|
}
|
@@ -390,7 +502,7 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
|
|
390
502
|
const int m, const int n, const float alpha, const float* a,
|
391
503
|
const int lda, float* b, const int ldb)
|
392
504
|
{
|
393
|
-
cblas_strsm(
|
505
|
+
cblas_strsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
|
394
506
|
}
|
395
507
|
|
396
508
|
template <>
|
@@ -399,7 +511,15 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
|
|
399
511
|
const int m, const int n, const double alpha, const double* a,
|
400
512
|
const int lda, double* b, const int ldb)
|
401
513
|
{
|
402
|
-
|
514
|
+
/* using std::cerr;
|
515
|
+
using std::endl;
|
516
|
+
cerr << "(row-major) dtrsm: " << (side == CblasLeft ? "left " : "right ")
|
517
|
+
<< (uplo == CblasUpper ? "upper " : "lower ")
|
518
|
+
<< (trans_a == CblasTrans ? "trans " : "notrans ")
|
519
|
+
<< (diag == CblasNonUnit ? "nonunit " : "unit ")
|
520
|
+
<< m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
|
521
|
+
*/
|
522
|
+
cblas_dtrsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
|
403
523
|
}
|
404
524
|
|
405
525
|
|
@@ -409,7 +529,7 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
|
|
409
529
|
const int m, const int n, const Complex64 alpha, const Complex64* a,
|
410
530
|
const int lda, Complex64* b, const int ldb)
|
411
531
|
{
|
412
|
-
cblas_ctrsm(
|
532
|
+
cblas_ctrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
|
413
533
|
}
|
414
534
|
|
415
535
|
template <>
|
@@ -418,7 +538,7 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
|
|
418
538
|
const int m, const int n, const Complex128 alpha, const Complex128* a,
|
419
539
|
const int lda, Complex128* b, const int ldb)
|
420
540
|
{
|
421
|
-
cblas_ztrsm(
|
541
|
+
cblas_ztrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
|
422
542
|
}
|
423
543
|
|
424
544
|
|
@@ -429,7 +549,7 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
|
|
429
549
|
*/
|
430
550
|
template <typename DType>
|
431
551
|
inline void laswp(const int N, DType* A, const int lda, const int K1, const int K2, const int *piv, const int inci) {
|
432
|
-
const int n = K2 - K1;
|
552
|
+
//const int n = K2 - K1; // not sure why this is declared. commented it out because it's unused.
|
433
553
|
|
434
554
|
int nb = N >> 5;
|
435
555
|
|
@@ -1261,6 +1381,87 @@ inline int getrf_nothrow(const int M, const int N, DType* A, const int lda, int*
|
|
1261
1381
|
return(ierr);
|
1262
1382
|
}
|
1263
1383
|
|
1384
|
+
/*
|
1385
|
+
* Solves a system of linear equations A*X = B with a general NxN matrix A using the LU factorization computed by GETRF.
|
1386
|
+
*
|
1387
|
+
* From ATLAS 3.8.0.
|
1388
|
+
*/
|
1389
|
+
template <typename DType>
|
1390
|
+
int getrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, const int N, const int NRHS, const DType* A,
|
1391
|
+
const int lda, const int* ipiv, DType* B, const int ldb)
|
1392
|
+
{
|
1393
|
+
// enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src.
|
1394
|
+
|
1395
|
+
if (!N || !NRHS) return 0;
|
1396
|
+
|
1397
|
+
const DType ONE = 1;
|
1398
|
+
|
1399
|
+
if (Order == CblasColMajor) {
|
1400
|
+
if (Trans == CblasNoTrans) {
|
1401
|
+
nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
|
1402
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1403
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1404
|
+
} else {
|
1405
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, Trans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1406
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasLower, Trans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1407
|
+
nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
|
1408
|
+
}
|
1409
|
+
} else {
|
1410
|
+
if (Trans == CblasNoTrans) {
|
1411
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1412
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1413
|
+
nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
|
1414
|
+
} else {
|
1415
|
+
nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
|
1416
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1417
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1418
|
+
}
|
1419
|
+
}
|
1420
|
+
return 0;
|
1421
|
+
}
|
1422
|
+
|
1423
|
+
|
1424
|
+
/*
|
1425
|
+
* Solves a system of linear equations A*X = B with a symmetric positive definite matrix A using the Cholesky factorization computed by POTRF.
|
1426
|
+
*
|
1427
|
+
* From ATLAS 3.8.0.
|
1428
|
+
*/
|
1429
|
+
template <typename DType, bool is_complex>
|
1430
|
+
int potrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, const int NRHS, const DType* A,
|
1431
|
+
const int lda, DType* B, const int ldb)
|
1432
|
+
{
|
1433
|
+
// enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src.
|
1434
|
+
|
1435
|
+
CBLAS_TRANSPOSE MyTrans = is_complex ? CblasConjTrans : CblasTrans;
|
1436
|
+
|
1437
|
+
if (!N || !NRHS) return 0;
|
1438
|
+
|
1439
|
+
const DType ONE = 1;
|
1440
|
+
|
1441
|
+
if (Order == CblasColMajor) {
|
1442
|
+
if (Uplo == CblasUpper) {
|
1443
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, MyTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1444
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1445
|
+
} else {
|
1446
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1447
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasLower, MyTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1448
|
+
}
|
1449
|
+
} else {
|
1450
|
+
// There's some kind of scaling operation that normally happens here in ATLAS. Not sure what it does, so we'll only
|
1451
|
+
// worry if something breaks. It probably has to do with their non-templated code and doesn't apply to us.
|
1452
|
+
|
1453
|
+
if (Uplo == CblasUpper) {
|
1454
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1455
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasUpper, MyTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1456
|
+
} else {
|
1457
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasLower, MyTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1458
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1459
|
+
}
|
1460
|
+
}
|
1461
|
+
return 0;
|
1462
|
+
}
|
1463
|
+
|
1464
|
+
|
1264
1465
|
|
1265
1466
|
/*
|
1266
1467
|
* From ATLAS 3.8.0:
|
@@ -1317,6 +1518,617 @@ inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, DType*
|
|
1317
1518
|
}
|
1318
1519
|
|
1319
1520
|
|
1521
|
+
/*
|
1522
|
+
* From ATLAS 3.8.0:
|
1523
|
+
*
|
1524
|
+
* Computes one of two LU factorizations based on the setting of the Order
|
1525
|
+
* parameter, as follows:
|
1526
|
+
* ----------------------------------------------------------------------------
|
1527
|
+
* Order == CblasColMajor
|
1528
|
+
* Column-major factorization of form
|
1529
|
+
* A = P * L * U
|
1530
|
+
* where P is a row-permutation matrix, L is lower triangular with unit
|
1531
|
+
* diagonal elements (lower trapazoidal if M > N), and U is upper triangular
|
1532
|
+
* (upper trapazoidal if M < N).
|
1533
|
+
*
|
1534
|
+
* ----------------------------------------------------------------------------
|
1535
|
+
* Order == CblasRowMajor
|
1536
|
+
* Row-major factorization of form
|
1537
|
+
* A = P * L * U
|
1538
|
+
* where P is a column-permutation matrix, L is lower triangular (lower
|
1539
|
+
* trapazoidal if M > N), and U is upper triangular with unit diagonals (upper
|
1540
|
+
* trapazoidal if M < N).
|
1541
|
+
*
|
1542
|
+
* ============================================================================
|
1543
|
+
* Let IERR be the return value of the function:
|
1544
|
+
* If IERR == 0, successful exit.
|
1545
|
+
* If (IERR < 0) the -IERR argument had an illegal value
|
1546
|
+
* If (IERR > 0 && Order == CblasColMajor)
|
1547
|
+
* U(i-1,i-1) is exactly zero. The factorization has been completed,
|
1548
|
+
* but the factor U is exactly singular, and division by zero will
|
1549
|
+
* occur if it is used to solve a system of equations.
|
1550
|
+
* If (IERR > 0 && Order == CblasRowMajor)
|
1551
|
+
* L(i-1,i-1) is exactly zero. The factorization has been completed,
|
1552
|
+
* but the factor L is exactly singular, and division by zero will
|
1553
|
+
* occur if it is used to solve a system of equations.
|
1554
|
+
*/
|
1555
|
+
template <typename DType>
|
1556
|
+
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) {
|
1557
|
+
#ifdef HAVE_CLAPACK_H
|
1558
|
+
rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes");
|
1559
|
+
#else
|
1560
|
+
rb_raise(rb_eNotImpError, "only LAPACK version implemented thus far");
|
1561
|
+
#endif
|
1562
|
+
return 0;
|
1563
|
+
}
|
1564
|
+
|
1565
|
+
#ifdef HAVE_CLAPACK_H
|
1566
|
+
template <>
|
1567
|
+
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) {
|
1568
|
+
return clapack_spotrf(order, uplo, N, A, lda);
|
1569
|
+
}
|
1570
|
+
|
1571
|
+
template <>
|
1572
|
+
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) {
|
1573
|
+
return clapack_dpotrf(order, uplo, N, A, lda);
|
1574
|
+
}
|
1575
|
+
|
1576
|
+
template <>
|
1577
|
+
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) {
|
1578
|
+
return clapack_cpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda);
|
1579
|
+
}
|
1580
|
+
|
1581
|
+
template <>
|
1582
|
+
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) {
|
1583
|
+
return clapack_zpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda);
|
1584
|
+
}
|
1585
|
+
#endif
|
1586
|
+
|
1587
|
+
|
1588
|
+
// Interchanges the N-element strided vectors X and Y (BLAS xSWAP).
//
// Element k of X lives at X[ix_0 + k*incX] (likewise for Y). As in reference BLAS, a negative increment walks
// the vector backward from its far end, so ix_0 = (1 - N) * incX in that case. N <= 0 is a no-op.
//
// This is the old (reference) BLAS version of this function. ATLAS has an optimized version, but
// it's going to be tough to translate.
template <typename DType>
static void swap(const int N, DType* X, const int incX, DType* Y, const int incY) {
  if (N <= 0) return;

  // Starting offsets; previously these were advanced but never used, so the strides were ignored.
  int ix = incX < 0 ? (1 - N) * incX : 0;
  int iy = incY < 0 ? (1 - N) * incY : 0;

  for (int i = 0; i < N; ++i, ix += incX, iy += incY) {
    DType temp = X[ix];
    X[ix] = Y[iy];
    Y[iy] = temp;
  }
}
|
1604
|
+
|
1605
|
+
|
1606
|
+
// Moves the strictly-upper part of the M x N row-major array U into C and leaves zeros behind in U.
// The diagonal is neither copied nor cleared (U is unit). ldu and ldc are the row strides of U and C.
// Entries of C outside the strictly-upper part are left untouched.
//
// From ATLAS 3.8.0.
template <typename DType>
static inline void trcpzeroU(const int M, const int N, DType* U, const int ldu, DType* C, const int ldc) {
  for (int row = 0; row != M; ++row) {
    DType* u_row = U + row * ldu;
    DType* c_row = C + row * ldc;

    // Only the columns to the right of the diagonal.
    for (int col = row + 1; col < N; ++col) {
      c_row[col] = u_row[col];
      u_row[col] = 0;
    }
  }
}
|
1622
|
+
|
1623
|
+
|
1624
|
+
/*
|
1625
|
+
* Un-comment the following lines when we figure out how to calculate NB for each of the ATLAS-derived
|
1626
|
+
* functions. This is probably really complicated.
|
1627
|
+
*
|
1628
|
+
* Also needed: ATL_MulByNB, ATL_DivByNB (both defined in the build process for ATLAS), and ATL_mmMU.
|
1629
|
+
*
|
1630
|
+
*/
|
1631
|
+
|
1632
|
+
/*
|
1633
|
+
|
1634
|
+
template <bool RowMajor, bool Upper, typename DType>
|
1635
|
+
static int trtri_4(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
|
1636
|
+
|
1637
|
+
if (RowMajor) {
|
1638
|
+
DType *pA0 = A, *pA1 = A+lda, *pA2 = A+2*lda, *pA3 = A+3*lda;
|
1639
|
+
DType tmp;
|
1640
|
+
if (Upper) {
|
1641
|
+
DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
|
1642
|
+
A12 = pA1[2], A13 = pA1[3],
|
1643
|
+
A23 = pA2[3];
|
1644
|
+
|
1645
|
+
if (Diag == CblasNonUnit) {
|
1646
|
+
pA0->inverse();
|
1647
|
+
(pA1+1)->inverse();
|
1648
|
+
(pA2+2)->inverse();
|
1649
|
+
(pA3+3)->inverse();
|
1650
|
+
|
1651
|
+
pA0[1] = -A01 * pA1[1] * pA0[0];
|
1652
|
+
pA1[2] = -A12 * pA2[2] * pA1[1];
|
1653
|
+
pA2[3] = -A23 * pA3[3] * pA2[2];
|
1654
|
+
|
1655
|
+
pA0[2] = -(A01 * pA1[2] + A02 * pA2[2]) * pA0[0];
|
1656
|
+
pA1[3] = -(A12 * pA2[3] + A13 * pA3[3]) * pA1[1];
|
1657
|
+
|
1658
|
+
pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03 * pA3[3]) * pA0[0];
|
1659
|
+
|
1660
|
+
} else {
|
1661
|
+
|
1662
|
+
pA0[1] = -A01;
|
1663
|
+
pA1[2] = -A12;
|
1664
|
+
pA2[3] = -A23;
|
1665
|
+
|
1666
|
+
pA0[2] = -(A01 * pA1[2] + A02);
|
1667
|
+
pA1[3] = -(A12 * pA2[3] + A13);
|
1668
|
+
|
1669
|
+
pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03);
|
1670
|
+
}
|
1671
|
+
|
1672
|
+
} else { // Lower
|
1673
|
+
DType A10 = pA1[0],
|
1674
|
+
A20 = pA2[0], A21 = pA2[1],
|
1675
|
+
A30 = PA3[0], A31 = pA3[1], A32 = pA3[2];
|
1676
|
+
DType *B10 = pA1,
|
1677
|
+
*B20 = pA2,
|
1678
|
+
*B30 = pA3,
|
1679
|
+
*B21 = pA2+1,
|
1680
|
+
*B31 = pA3+1,
|
1681
|
+
*B32 = pA3+2;
|
1682
|
+
|
1683
|
+
|
1684
|
+
if (Diag == CblasNonUnit) {
|
1685
|
+
pA0->inverse();
|
1686
|
+
(pA1+1)->inverse();
|
1687
|
+
(pA2+2)->inverse();
|
1688
|
+
(pA3+3)->inverse();
|
1689
|
+
|
1690
|
+
*B10 = -A10 * pA0[0] * pA1[1];
|
1691
|
+
*B21 = -A21 * pA1[1] * pA2[2];
|
1692
|
+
*B32 = -A32 * pA2[2] * pA3[3];
|
1693
|
+
*B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
|
1694
|
+
*B31 = -(A31 * pA1[1] + A32 * (*B21)) * pA3[3];
|
1695
|
+
*B30 = -(A30 * pA0[0] + A31 * (*B10) + A32 * (*B20)) * pA3;
|
1696
|
+
} else {
|
1697
|
+
*B10 = -A10;
|
1698
|
+
*B21 = -A21;
|
1699
|
+
*B32 = -A32;
|
1700
|
+
*B20 = -(A20 + A21 * (*B10));
|
1701
|
+
*B31 = -(A31 + A32 * (*B21));
|
1702
|
+
*B30 = -(A30 + A31 * (*B10) + A32 * (*B20));
|
1703
|
+
}
|
1704
|
+
}
|
1705
|
+
|
1706
|
+
} else {
|
1707
|
+
rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
1708
|
+
}
|
1709
|
+
|
1710
|
+
return 0;
|
1711
|
+
|
1712
|
+
}
|
1713
|
+
|
1714
|
+
|
1715
|
+
template <bool RowMajor, bool Upper, typename DType>
|
1716
|
+
static int trtri_3(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
|
1717
|
+
|
1718
|
+
if (RowMajor) {
|
1719
|
+
|
1720
|
+
DType tmp;
|
1721
|
+
|
1722
|
+
if (Upper) {
|
1723
|
+
DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
|
1724
|
+
A12 = pA1[2], A13 = pA1[3];
|
1725
|
+
|
1726
|
+
DType *B01 = pA0 + 1,
|
1727
|
+
*B02 = pA0 + 2,
|
1728
|
+
*B12 = pA1 + 2;
|
1729
|
+
|
1730
|
+
if (Diag == CblasNonUnit) {
|
1731
|
+
pA0->inverse();
|
1732
|
+
(pA1+1)->inverse();
|
1733
|
+
(pA2+2)->inverse();
|
1734
|
+
|
1735
|
+
*B01 = -A01 * pA1[1] * pA0[0];
|
1736
|
+
*B12 = -A12 * pA2[2] * pA1[1];
|
1737
|
+
*B02 = -(A01 * (*B12) + A02 * pA2[2]) * pA0[0];
|
1738
|
+
} else {
|
1739
|
+
*B01 = -A01;
|
1740
|
+
*B12 = -A12;
|
1741
|
+
*B02 = -(A01 * (*B12) + A02);
|
1742
|
+
}
|
1743
|
+
|
1744
|
+
} else { // Lower
|
1745
|
+
DType *pA0=A, *pA1=A+lda, *pA2=A+2*lda;
|
1746
|
+
DType A10=pA1[0],
|
1747
|
+
A20=pA2[0], A21=pA2[1];
|
1748
|
+
|
1749
|
+
DType *B10 = pA1,
|
1750
|
+
*B20 = pA2;
|
1751
|
+
*B21 = pA2+1;
|
1752
|
+
|
1753
|
+
if (Diag == CblasNonUnit) {
|
1754
|
+
pA0->inverse();
|
1755
|
+
(pA1+1)->inverse();
|
1756
|
+
(pA2+2)->inverse();
|
1757
|
+
*B10 = -A10 * pA0[0] * pA1[1];
|
1758
|
+
*B21 = -A21 * pA1[1] * pA2[2];
|
1759
|
+
*B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
|
1760
|
+
} else {
|
1761
|
+
*B10 = -A10;
|
1762
|
+
*B21 = -A21;
|
1763
|
+
*B20 = -(A20 + A21 * (*B10));
|
1764
|
+
}
|
1765
|
+
}
|
1766
|
+
|
1767
|
+
|
1768
|
+
} else {
|
1769
|
+
rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
1770
|
+
}
|
1771
|
+
|
1772
|
+
return 0;
|
1773
|
+
|
1774
|
+
}
|
1775
|
+
|
1776
|
+
template <bool RowMajor, bool Upper, bool Real, typename DType>
|
1777
|
+
static void trtri(const enum CBLAS_DIAG Diag, const int N, DType* A, const int lda) {
|
1778
|
+
DType *Age, *Atr;
|
1779
|
+
DType tmp;
|
1780
|
+
int Nleft, Nright;
|
1781
|
+
|
1782
|
+
int ierr = 0;
|
1783
|
+
|
1784
|
+
static const DType ONE = 1;
|
1785
|
+
static const DType MONE -1;
|
1786
|
+
static const DType NONE = -1;
|
1787
|
+
|
1788
|
+
if (RowMajor) {
|
1789
|
+
|
1790
|
+
// FIXME: Use REAL_RECURSE_LIMIT here for float32 and float64 (instead of 1)
|
1791
|
+
if ((Real && N > REAL_RECURSE_LIMIT) || (N > 1)) {
|
1792
|
+
Nleft = N >> 1;
|
1793
|
+
#ifdef NB
|
1794
|
+
if (Nleft > NB) NLeft = ATL_MulByNB(ATL_DivByNB(Nleft));
|
1795
|
+
#endif
|
1796
|
+
|
1797
|
+
Nright = N - Nleft;
|
1798
|
+
|
1799
|
+
if (Upper) {
|
1800
|
+
Age = A + Nleft;
|
1801
|
+
Atr = A + (Nleft * (lda+1));
|
1802
|
+
|
1803
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, Diag,
|
1804
|
+
Nleft, Nright, ONE, Atr, lda, Age, lda);
|
1805
|
+
|
1806
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, Diag,
|
1807
|
+
Nleft, Nright, MONE, A, lda, Age, lda);
|
1808
|
+
|
1809
|
+
} else { // Lower
|
1810
|
+
Age = A + ((Nleft*lda));
|
1811
|
+
Atr = A + (Nleft * (lda+1));
|
1812
|
+
|
1813
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasLower, CblasNoTrans, Diag,
|
1814
|
+
Nright, Nleft, ONE, A, lda, Age, lda);
|
1815
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasLower, CblasNoTrans, Diag,
|
1816
|
+
Nright, Nleft, MONE, Atr, lda, Age, lda);
|
1817
|
+
}
|
1818
|
+
|
1819
|
+
ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nleft, A, lda);
|
1820
|
+
if (ierr) return ierr;
|
1821
|
+
|
1822
|
+
ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nright, Atr, lda);
|
1823
|
+
if (ierr) return ierr + Nleft;
|
1824
|
+
|
1825
|
+
} else {
|
1826
|
+
if (Real) {
|
1827
|
+
if (N == 4) {
|
1828
|
+
return trtri_4<RowMajor,Upper,Real,DType>(Diag, A, lda);
|
1829
|
+
} else if (N == 3) {
|
1830
|
+
return trtri_3<RowMajor,Upper,Real,DType>(Diag, A, lda);
|
1831
|
+
} else if (N == 2) {
|
1832
|
+
if (Diag == CblasNonUnit) {
|
1833
|
+
A->inverse();
|
1834
|
+
(A+(lda+1))->inverse();
|
1835
|
+
|
1836
|
+
if (Upper) {
|
1837
|
+
*(A+1) *= *A; // TRI_MUL
|
1838
|
+
*(A+1) *= *(A+lda+1); // TRI_MUL
|
1839
|
+
} else {
|
1840
|
+
*(A+lda) *= *A; // TRI_MUL
|
1841
|
+
*(A+lda) *= *(A+lda+1); // TRI_MUL
|
1842
|
+
}
|
1843
|
+
}
|
1844
|
+
|
1845
|
+
if (Upper) *(A+1) = -*(A+1); // TRI_NEG
|
1846
|
+
else *(A+lda) = -*(A+lda); // TRI_NEG
|
1847
|
+
} else if (Diag == CblasNonUnit) A->inverse();
|
1848
|
+
} else { // not real
|
1849
|
+
if (Diag == CblasNonUnit) A->inverse();
|
1850
|
+
}
|
1851
|
+
}
|
1852
|
+
|
1853
|
+
} else {
|
1854
|
+
rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
1855
|
+
}
|
1856
|
+
|
1857
|
+
return ierr;
|
1858
|
+
}
|
1859
|
+
|
1860
|
+
|
1861
|
+
template <bool RowMajor, bool Real, typename DType>
|
1862
|
+
int getri(const int N, DType* A, const int lda, const int* ipiv, DType* wrk, const int lwrk) {
|
1863
|
+
|
1864
|
+
if (!RowMajor) rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
1865
|
+
|
1866
|
+
int jb, nb, I, ndown, iret;
|
1867
|
+
|
1868
|
+
const DType ONE = 1, NONE = -1;
|
1869
|
+
|
1870
|
+
int iret = trtri<RowMajor,false,Real,DType>(CblasNonUnit, N, A, lda);
|
1871
|
+
if (!iret && N > 1) {
|
1872
|
+
jb = lwrk / N;
|
1873
|
+
if (jb >= NB) nb = ATL_MulByNB(ATL_DivByNB(jb));
|
1874
|
+
else if (jb >= ATL_mmMU) nb = (jb/ATL_mmMU)*ATL_mmMU;
|
1875
|
+
else nb = jb;
|
1876
|
+
if (!nb) return -6; // need at least 1 row of workspace
|
1877
|
+
|
1878
|
+
// only first iteration will have partial block, unroll it
|
1879
|
+
|
1880
|
+
jb = N - (N/nb) * nb;
|
1881
|
+
if (!jb) jb = nb;
|
1882
|
+
I = N - jb;
|
1883
|
+
A += lda * I;
|
1884
|
+
trcpzeroU<DType>(jb, jb, A+I, lda, wrk, jb);
|
1885
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
|
1886
|
+
jb, N, ONE, wrk, jb, A, lda);
|
1887
|
+
|
1888
|
+
if (I) {
|
1889
|
+
do {
|
1890
|
+
I -= nb;
|
1891
|
+
A -= nb * lda;
|
1892
|
+
ndown = N-I;
|
1893
|
+
trcpzeroU<DType>(nb, ndown, A+I, lda, wrk, ndown);
|
1894
|
+
nm::math::gemm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
|
1895
|
+
nb, N, ONE, wrk, ndown, A, lda);
|
1896
|
+
} while (I);
|
1897
|
+
}
|
1898
|
+
|
1899
|
+
// Apply row interchanges
|
1900
|
+
|
1901
|
+
for (I = N - 2; I >= 0; --I) {
|
1902
|
+
jb = ipiv[I];
|
1903
|
+
if (jb != I) nm::math::swap<DType>(N, A+I*lda, 1, A+jb*lda, 1);
|
1904
|
+
}
|
1905
|
+
}
|
1906
|
+
|
1907
|
+
return iret;
|
1908
|
+
}
|
1909
|
+
*/
|
1910
|
+
|
1911
|
+
|
1912
|
+
// Applies the plane rotation defined by (c, s) in place to N element pairs (x, y) drawn from X and Y:
//   x <- c*x + s*y
//   y <- c*y - s*x
// When c == 1 and s == 0 the rotation is the identity and neither vector is touched.
// Strides are used exactly as given; rot() normalizes negative strides before calling this.
//
// This function is called ATL_rot in ATLAS 3.8.4.
//
// TODO: Test this to see if it works properly on complex. ATLAS has a separate algorithm for complex, which looks like
// TODO: it may actually be the same one.
template <typename DType>
inline void rot_helper(const int N, DType* X, const int incX, DType* Y, const int incY, const DType c, const DType s) {
  const bool is_rotation = (c != 1) || (s != 0);
  if (!is_rotation) return;

  if (incX == 1 && incY == 1) {
    // Unit strides: plain indexing. Both old values are read before either slot is written,
    // and Y is written before X, so aliased inputs behave exactly as in the ATLAS loop.
    for (int k = 0; k != N; ++k) {
      const DType x_old = X[k];
      const DType y_old = Y[k];
      Y[k] = y_old * c - x_old * s;
      X[k] = x_old * c + y_old * s;
    }
    return;
  }

  // General strides (possibly negative): walk both vectors with independent cursors.
  DType* px = X;
  DType* py = Y;
  for (int remaining = N; remaining > 0; --remaining, px += incX, py += incY) {
    const DType x_old = *px;
    const DType y_old = *py;
    *py = y_old * c - x_old * s;
    *px = x_old * c + y_old * s;
  }
}
|
1934
|
+
|
1935
|
+
|
1936
|
+
/*
 * Givens plane rotation. From ATLAS 3.8.4 (ATL_rotg); same contract as BLAS ?rotg.
 *
 * On entry *a and *b hold the vector (a, b). On exit:
 *   *c, *s -- the rotation, with c = a/r and s = b/r, so that c*a + s*b = r and -s*a + c*b = 0
 *   *a     -- r; its sign follows whichever of a, b is larger in magnitude
 *   *b     -- z, the reconstruction parameter: s if |a| > |b|; otherwise 1/c when c != 0, else 1
 * When a == b == 0 the result is c = 1 and s = a = b = 0.
 *
 * The z convention matches cblas_?rotg, which the float/double specializations delegate to.
 *
 * FIXME: Need a specialized algorithm for Rationals. BLAS' algorithm simply will not work for most values due to the
 * FIXME: sqrt.
 */
template <typename DType>
inline void rotg(DType* a, DType* b, DType* c, DType* s) {
  DType aa = std::abs(*a), ab = std::abs(*b);
  DType roe = aa > ab ? *a : *b;
  DType scal = aa + ab;

  if (scal == 0) {
    *c = 1;
    *s = *a = *b = 0;
  } else {
    // Scale by |a| + |b| before squaring to avoid overflow/underflow in the norm.
    DType t0 = aa / scal, t1 = ab / scal;
    DType r = scal * std::sqrt(t0 * t0 + t1 * t1);
    if (roe < 0) r = -r;
    *c = *a / r;
    *s = *b / r;

    // Reconstruction parameter: lets a caller recover (c, s) from a single stored value.
    // When |a| > |b| BLAS stores s itself; previously this case wrongly stored 1/c.
    DType z = DType(1);
    if (aa > ab)      z = *s;
    else if (*c != 0) z = 1 / *c;

    *a = r;
    *b = z;
  }
}
|
1959
|
+
|
1960
|
+
template <>
|
1961
|
+
inline void rotg(float* a, float* b, float* c, float* s) {
|
1962
|
+
cblas_srotg(a, b, c, s);
|
1963
|
+
}
|
1964
|
+
|
1965
|
+
template <>
|
1966
|
+
inline void rotg(double* a, double* b, double* c, double* s) {
|
1967
|
+
cblas_drotg(a, b, c, s);
|
1968
|
+
}
|
1969
|
+
|
1970
|
+
template <>
|
1971
|
+
inline void rotg(Complex64* a, Complex64* b, Complex64* c, Complex64* s) {
|
1972
|
+
cblas_crotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s));
|
1973
|
+
}
|
1974
|
+
|
1975
|
+
template <>
|
1976
|
+
inline void rotg(Complex128* a, Complex128* b, Complex128* c, Complex128* s) {
|
1977
|
+
cblas_zrotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s));
|
1978
|
+
}
|
1979
|
+
|
1980
|
+
template <typename DType>
|
1981
|
+
inline void cblas_rotg(void* a, void* b, void* c, void* s) {
|
1982
|
+
rotg<DType>(reinterpret_cast<DType*>(a), reinterpret_cast<DType*>(b), reinterpret_cast<DType*>(c), reinterpret_cast<DType*>(s));
|
1983
|
+
}
|
1984
|
+
|
1985
|
+
|
1986
|
+
/* Applies a plane rotation. From ATLAS 3.8.4. */
|
1987
|
+
template <typename DType, typename CSDType>
|
1988
|
+
inline void rot(const int N, DType* X, const int incX, DType* Y, const int incY, const CSDType c, const CSDType s) {
|
1989
|
+
DType *x = X, *y = Y;
|
1990
|
+
int incx = incX, incy = incY;
|
1991
|
+
|
1992
|
+
if (N > 0) {
|
1993
|
+
if (incX < 0) {
|
1994
|
+
if (incY < 0) { incx = -incx; incy = -incy; }
|
1995
|
+
else x += -incX * (N-1);
|
1996
|
+
} else if (incY < 0) {
|
1997
|
+
incy = -incy;
|
1998
|
+
incx = -incx;
|
1999
|
+
x += (N-1) * incX;
|
2000
|
+
}
|
2001
|
+
rot_helper<DType>(N, x, incx, y, incy, c, s);
|
2002
|
+
}
|
2003
|
+
}
|
2004
|
+
|
2005
|
+
template <>
|
2006
|
+
inline void rot(const int N, float* X, const int incX, float* Y, const int incY, const float c, const float s) {
|
2007
|
+
cblas_srot(N, X, incX, Y, incY, (float)c, (float)s);
|
2008
|
+
}
|
2009
|
+
|
2010
|
+
template <>
|
2011
|
+
inline void rot(const int N, double* X, const int incX, double* Y, const int incY, const double c, const double s) {
|
2012
|
+
cblas_drot(N, X, incX, Y, incY, c, s);
|
2013
|
+
}
|
2014
|
+
|
2015
|
+
template <>
|
2016
|
+
inline void rot(const int N, Complex64* X, const int incX, Complex64* Y, const int incY, const float c, const float s) {
|
2017
|
+
cblas_csrot(N, X, incX, Y, incY, c, s);
|
2018
|
+
}
|
2019
|
+
|
2020
|
+
template <>
|
2021
|
+
inline void rot(const int N, Complex128* X, const int incX, Complex128* Y, const int incY, const double c, const double s) {
|
2022
|
+
cblas_zdrot(N, X, incX, Y, incY, c, s);
|
2023
|
+
}
|
2024
|
+
|
2025
|
+
|
2026
|
+
template <typename DType, typename CSDType>
|
2027
|
+
inline void cblas_rot(const int N, void* X, const int incX, void* Y, const int incY, const void* c, const void* s) {
|
2028
|
+
rot<DType,CSDType>(N, reinterpret_cast<DType*>(X), incX, reinterpret_cast<DType*>(Y), incY, *reinterpret_cast<const CSDType*>(c), *reinterpret_cast<const CSDType*>(s));
|
2029
|
+
}
|
2030
|
+
|
2031
|
+
|
2032
|
+
template <bool is_complex, typename DType>
|
2033
|
+
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) {
|
2034
|
+
|
2035
|
+
int Nleft, Nright;
|
2036
|
+
const DType ONE = 1;
|
2037
|
+
DType *G, *U0 = A, *U1;
|
2038
|
+
|
2039
|
+
if (N > 1) {
|
2040
|
+
Nleft = N >> 1;
|
2041
|
+
#ifdef NB
|
2042
|
+
if (Nleft > NB) Nleft = ATL_MulByNB(ATL_DivByNB(Nleft));
|
2043
|
+
#endif
|
2044
|
+
|
2045
|
+
Nright = N - Nleft;
|
2046
|
+
|
2047
|
+
// FIXME: There's a simpler way to write this next block, but I'm way too tired to work it out right now.
|
2048
|
+
if (uplo == CblasUpper) {
|
2049
|
+
if (order == CblasRowMajor) {
|
2050
|
+
G = A + Nleft;
|
2051
|
+
U1 = G + Nleft * lda;
|
2052
|
+
} else {
|
2053
|
+
G = A + Nleft * lda;
|
2054
|
+
U1 = G + Nleft;
|
2055
|
+
}
|
2056
|
+
} else {
|
2057
|
+
if (order == CblasRowMajor) {
|
2058
|
+
G = A + Nleft * lda;
|
2059
|
+
U1 = G + Nleft;
|
2060
|
+
} else {
|
2061
|
+
G = A + Nleft;
|
2062
|
+
U1 = G + Nleft * lda;
|
2063
|
+
}
|
2064
|
+
}
|
2065
|
+
|
2066
|
+
lauum<is_complex, DType>(order, uplo, Nleft, U0, lda);
|
2067
|
+
|
2068
|
+
if (is_complex) {
|
2069
|
+
|
2070
|
+
nm::math::herk<DType>(order, uplo,
|
2071
|
+
uplo == CblasLower ? CblasConjTrans : CblasNoTrans,
|
2072
|
+
Nleft, Nright, &ONE, G, lda, &ONE, U0, lda);
|
2073
|
+
|
2074
|
+
nm::math::trmm<DType>(order, CblasLeft, uplo, CblasConjTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda);
|
2075
|
+
} else {
|
2076
|
+
nm::math::syrk<DType>(order, uplo,
|
2077
|
+
uplo == CblasLower ? CblasTrans : CblasNoTrans,
|
2078
|
+
Nleft, Nright, &ONE, G, lda, &ONE, U0, lda);
|
2079
|
+
|
2080
|
+
nm::math::trmm<DType>(order, CblasLeft, uplo, CblasTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda);
|
2081
|
+
}
|
2082
|
+
lauum<is_complex, DType>(order, uplo, Nright, U1, lda);
|
2083
|
+
|
2084
|
+
} else {
|
2085
|
+
*A = *A * *A;
|
2086
|
+
}
|
2087
|
+
}
|
2088
|
+
|
2089
|
+
|
2090
|
+
#ifdef HAVE_CLAPACK_H
/*
 * CLAPACK-backed lauum for the BLAS dtypes.
 *
 * These are explicit specializations of lauum<is_complex, DType>, so calls that name both template
 * arguments (as clapack_lauum does) actually reach them. One-parameter overload templates
 * (`template <bool is_complex> lauum(..., float* ...)`) can never be selected by such a call:
 * supplying two explicit template arguments to a one-parameter template makes deduction fail, and
 * the generic recursive version runs instead.
 *
 * NOTE(review): assumes callers pass is_complex=false for float/double and is_complex=true for
 * Complex64/Complex128; any other combination falls back to the generic template.
 */
template <>
inline void lauum<false, float>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) {
  clapack_slauum(order, uplo, N, A, lda);
}

template <>
inline void lauum<false, double>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) {
  clapack_dlauum(order, uplo, N, A, lda);
}

template <>
inline void lauum<true, Complex64>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) {
  clapack_clauum(order, uplo, N, A, lda);
}

template <>
inline void lauum<true, Complex128>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) {
  clapack_zlauum(order, uplo, N, A, lda);
}
#endif
|
2111
|
+
|
2112
|
+
|
2113
|
+
/*
|
2114
|
+
* Function signature conversion for calling LAPACK's lauum functions as directly as possible.
|
2115
|
+
*
|
2116
|
+
* For documentation: http://www.netlib.org/lapack/double/dlauum.f
|
2117
|
+
*
|
2118
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2119
|
+
*/
|
2120
|
+
template <bool is_complex, typename DType>
|
2121
|
+
inline int clapack_lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
|
2122
|
+
if (n < 0) rb_raise(rb_eArgError, "n cannot be less than zero, is set to %d", n);
|
2123
|
+
if (lda < n || lda < 1) rb_raise(rb_eArgError, "lda must be >= max(n,1); lda=%d, n=%d\n", lda, n);
|
2124
|
+
|
2125
|
+
if (uplo == CblasUpper) lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
|
2126
|
+
else lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
|
2127
|
+
|
2128
|
+
return 0;
|
2129
|
+
}
|
2130
|
+
|
2131
|
+
|
1320
2132
|
|
1321
2133
|
|
1322
2134
|
/*
|
@@ -1357,6 +2169,143 @@ inline int clapack_getrf(const enum CBLAS_ORDER order, const int m, const int n,
|
|
1357
2169
|
}
|
1358
2170
|
|
1359
2171
|
|
2172
|
+
/*
|
2173
|
+
* Function signature conversion for calling LAPACK's potrf functions as directly as possible.
|
2174
|
+
*
|
2175
|
+
* For documentation: http://www.netlib.org/lapack/double/dpotrf.f
|
2176
|
+
*
|
2177
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2178
|
+
*/
|
2179
|
+
template <typename DType>
|
2180
|
+
inline int clapack_potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
|
2181
|
+
return potrf<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
|
2182
|
+
}
|
2183
|
+
|
2184
|
+
|
2185
|
+
/*
|
2186
|
+
* Function signature conversion for calling LAPACK's getrs functions as directly as possible.
|
2187
|
+
*
|
2188
|
+
* For documentation: http://www.netlib.org/lapack/double/dgetrs.f
|
2189
|
+
*
|
2190
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2191
|
+
*/
|
2192
|
+
template <typename DType>
|
2193
|
+
inline int clapack_getrs(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const int n, const int nrhs,
|
2194
|
+
const void* a, const int lda, const int* ipiv, void* b, const int ldb) {
|
2195
|
+
return getrs<DType>(order, trans, n, nrhs, reinterpret_cast<const DType*>(a), lda, ipiv, reinterpret_cast<DType*>(b), ldb);
|
2196
|
+
}
|
2197
|
+
|
2198
|
+
/*
|
2199
|
+
* Function signature conversion for calling LAPACK's potrs functions as directly as possible.
|
2200
|
+
*
|
2201
|
+
* For documentation: http://www.netlib.org/lapack/double/dpotrs.f
|
2202
|
+
*
|
2203
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2204
|
+
*/
|
2205
|
+
template <typename DType, bool is_complex>
|
2206
|
+
inline int clapack_potrs(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, const int nrhs,
|
2207
|
+
const void* a, const int lda, void* b, const int ldb) {
|
2208
|
+
return potrs<DType,is_complex>(order, uplo, n, nrhs, reinterpret_cast<const DType*>(a), lda, reinterpret_cast<DType*>(b), ldb);
|
2209
|
+
}
|
2210
|
+
|
2211
|
+
template <typename DType>
|
2212
|
+
inline int getri(const enum CBLAS_ORDER order, const int n, DType* a, const int lda, const int* ipiv) {
|
2213
|
+
rb_raise(rb_eNotImpError, "getri not yet implemented for non-BLAS dtypes");
|
2214
|
+
return 0;
|
2215
|
+
}
|
2216
|
+
|
2217
|
+
#ifdef HAVE_CLAPACK_H
// CLAPACK-backed getri for the four BLAS dtypes. Each forwards its arguments unchanged and returns
// the CLAPACK status code; the complex variants hand the buffer over as void*.
template <>
inline int getri<float>(const enum CBLAS_ORDER order, const int n, float* a, const int lda, const int* ipiv) {
  return clapack_sgetri(order, n, a, lda, ipiv);
}

template <>
inline int getri<double>(const enum CBLAS_ORDER order, const int n, double* a, const int lda, const int* ipiv) {
  return clapack_dgetri(order, n, a, lda, ipiv);
}

template <>
inline int getri<Complex64>(const enum CBLAS_ORDER order, const int n, Complex64* a, const int lda, const int* ipiv) {
  return clapack_cgetri(order, n, static_cast<void*>(a), lda, ipiv);
}

template <>
inline int getri<Complex128>(const enum CBLAS_ORDER order, const int n, Complex128* a, const int lda, const int* ipiv) {
  return clapack_zgetri(order, n, static_cast<void*>(a), lda, ipiv);
}
#endif
|
2238
|
+
|
2239
|
+
|
2240
|
+
template <typename DType>
|
2241
|
+
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, DType* a, const int lda) {
|
2242
|
+
rb_raise(rb_eNotImpError, "potri not yet implemented for non-BLAS dtypes");
|
2243
|
+
return 0;
|
2244
|
+
}
|
2245
|
+
|
2246
|
+
|
2247
|
+
#ifdef HAVE_CLAPACK_H
// CLAPACK-backed potri for the four BLAS dtypes. Each forwards its arguments unchanged and returns
// the CLAPACK status code; the complex variants hand the buffer over as void*.
template <>
inline int potri<float>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, float* a, const int lda) {
  return clapack_spotri(order, uplo, n, a, lda);
}

template <>
inline int potri<double>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, double* a, const int lda) {
  return clapack_dpotri(order, uplo, n, a, lda);
}

template <>
inline int potri<Complex64>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex64* a, const int lda) {
  return clapack_cpotri(order, uplo, n, static_cast<void*>(a), lda);
}

template <>
inline int potri<Complex128>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex128* a, const int lda) {
  return clapack_zpotri(order, uplo, n, static_cast<void*>(a), lda);
}
#endif
|
2268
|
+
|
2269
|
+
/*
|
2270
|
+
* Function signature conversion for calling LAPACK's getri functions as directly as possible.
|
2271
|
+
*
|
2272
|
+
* For documentation: http://www.netlib.org/lapack/double/dgetri.f
|
2273
|
+
*
|
2274
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2275
|
+
*/
|
2276
|
+
template <typename DType>
|
2277
|
+
inline int clapack_getri(const enum CBLAS_ORDER order, const int n, void* a, const int lda, const int* ipiv) {
|
2278
|
+
return getri<DType>(order, n, reinterpret_cast<DType*>(a), lda, ipiv);
|
2279
|
+
}
|
2280
|
+
|
2281
|
+
|
2282
|
+
/*
|
2283
|
+
* Function signature conversion for calling LAPACK's potri functions as directly as possible.
|
2284
|
+
*
|
2285
|
+
* For documentation: http://www.netlib.org/lapack/double/dpotri.f
|
2286
|
+
*
|
2287
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2288
|
+
*/
|
2289
|
+
template <typename DType>
|
2290
|
+
inline int clapack_potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
|
2291
|
+
return potri<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
|
2292
|
+
}
|
2293
|
+
|
2294
|
+
|
2295
|
+
/*
|
2296
|
+
* Function signature conversion for calling LAPACK's laswp functions as directly as possible.
|
2297
|
+
*
|
2298
|
+
* For documentation: http://www.netlib.org/lapack/double/dlaswp.f
|
2299
|
+
*
|
2300
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2301
|
+
*/
|
2302
|
+
template <typename DType>
|
2303
|
+
inline void clapack_laswp(const int n, void* a, const int lda, const int k1, const int k2, const int* ipiv, const int incx) {
|
2304
|
+
laswp<DType>(n, reinterpret_cast<DType*>(a), lda, k1, k2, ipiv, incx);
|
2305
|
+
}
|
2306
|
+
|
2307
|
+
|
2308
|
+
|
1360
2309
|
}} // end namespace nm::math
|
1361
2310
|
|
1362
2311
|
|