nmatrix 0.0.2 → 0.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- data/Gemfile +1 -1
- data/History.txt +31 -3
- data/Manifest.txt +5 -0
- data/README.rdoc +29 -27
- data/ext/nmatrix/binary_format.txt +53 -0
- data/ext/nmatrix/data/data.cpp +18 -18
- data/ext/nmatrix/data/data.h +38 -7
- data/ext/nmatrix/data/rational.h +13 -0
- data/ext/nmatrix/data/ruby_object.h +10 -0
- data/ext/nmatrix/extconf.rb +2 -0
- data/ext/nmatrix/nmatrix.cpp +655 -103
- data/ext/nmatrix/nmatrix.h +26 -14
- data/ext/nmatrix/ruby_constants.cpp +4 -0
- data/ext/nmatrix/ruby_constants.h +2 -0
- data/ext/nmatrix/storage/dense.cpp +99 -41
- data/ext/nmatrix/storage/dense.h +3 -3
- data/ext/nmatrix/storage/list.cpp +36 -14
- data/ext/nmatrix/storage/list.h +4 -4
- data/ext/nmatrix/storage/storage.cpp +19 -19
- data/ext/nmatrix/storage/storage.h +11 -11
- data/ext/nmatrix/storage/yale.cpp +17 -20
- data/ext/nmatrix/storage/yale.h +13 -11
- data/ext/nmatrix/util/io.cpp +25 -23
- data/ext/nmatrix/util/io.h +5 -5
- data/ext/nmatrix/util/math.cpp +634 -17
- data/ext/nmatrix/util/math.h +958 -9
- data/ext/nmatrix/util/sl_list.cpp +7 -7
- data/ext/nmatrix/util/sl_list.h +2 -2
- data/lib/nmatrix.rb +9 -0
- data/lib/nmatrix/blas.rb +4 -4
- data/lib/nmatrix/io/market.rb +227 -0
- data/lib/nmatrix/io/mat_reader.rb +7 -7
- data/lib/nmatrix/lapack.rb +80 -0
- data/lib/nmatrix/nmatrix.rb +78 -52
- data/lib/nmatrix/shortcuts.rb +486 -0
- data/lib/nmatrix/version.rb +1 -1
- data/spec/2x2_dense_double.mat +0 -0
- data/spec/blas_spec.rb +59 -9
- data/spec/elementwise_spec.rb +25 -12
- data/spec/io_spec.rb +69 -1
- data/spec/lapack_spec.rb +53 -4
- data/spec/math_spec.rb +9 -0
- data/spec/nmatrix_list_spec.rb +95 -0
- data/spec/nmatrix_spec.rb +10 -53
- data/spec/nmatrix_yale_spec.rb +17 -15
- data/spec/shortcuts_spec.rb +154 -0
- metadata +22 -15
data/ext/nmatrix/util/math.h
CHANGED
@@ -70,7 +70,10 @@
|
|
70
70
|
|
71
71
|
extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
|
72
72
|
#include <cblas.h>
|
73
|
-
|
73
|
+
|
74
|
+
#ifdef HAVE_CLAPACK_H
|
75
|
+
#include <clapack.h>
|
76
|
+
#endif
|
74
77
|
}
|
75
78
|
|
76
79
|
#include <algorithm> // std::min, std::max
|
@@ -85,6 +88,7 @@ extern "C" { // These need to be in an extern "C" block or you'll get all kinds
|
|
85
88
|
/*
|
86
89
|
* Macros
|
87
90
|
*/
|
91
|
+
#define REAL_RECURSE_LIMIT 4
|
88
92
|
|
89
93
|
/*
|
90
94
|
* Data
|
@@ -95,7 +99,7 @@ extern "C" {
|
|
95
99
|
/*
|
96
100
|
* C accessors.
|
97
101
|
*/
|
98
|
-
void nm_math_det_exact(const int M, const void* elements, const int lda, dtype_t dtype, void* result);
|
102
|
+
void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result);
|
99
103
|
void nm_math_transpose_generic(const size_t M, const size_t N, const void* A, const int lda, void* B, const int ldb, size_t element_size);
|
100
104
|
void nm_math_init_blas(void);
|
101
105
|
}
|
@@ -143,6 +147,9 @@ template <> inline double numeric_inverse<double>(const double& n) { return 1 /
|
|
143
147
|
*
|
144
148
|
* For row major, call trsm<DType> instead. That will handle necessary changes-of-variables
|
145
149
|
* and parameter checks.
|
150
|
+
*
|
151
|
+
* Note that some of the boundary conditions here may be incorrect. Very little has been tested!
|
152
|
+
* This was converted directly from dtrsm.f using f2c, and then rewritten more cleanly.
|
146
153
|
*/
|
147
154
|
template <typename DType>
|
148
155
|
inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
@@ -150,6 +157,9 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
|
150
157
|
const int m, const int n, const DType alpha, const DType* a,
|
151
158
|
const int lda, DType* b, const int ldb)
|
152
159
|
{
|
160
|
+
|
161
|
+
// (row-major) trsm: left upper trans nonunit m=3 n=1 1/1 a 3 b 3
|
162
|
+
|
153
163
|
if (m == 0 || n == 0) return; /* Quick return if possible. */
|
154
164
|
|
155
165
|
if (alpha == 0) { // Handle alpha == 0
|
@@ -210,7 +220,7 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
|
210
220
|
for (int j = 0; j < n; ++j) {
|
211
221
|
for (int i = 0; i < m; ++i) {
|
212
222
|
DType temp = alpha * b[i + j * ldb];
|
213
|
-
for (int k = 0; k < i
|
223
|
+
for (int k = 0; k < i; ++k) { // limit was i-1. Lots of similar bugs in this code, probably.
|
214
224
|
temp -= a[k + i * lda] * b[k + j * ldb];
|
215
225
|
}
|
216
226
|
if (diag == CblasNonUnit) {
|
@@ -339,6 +349,92 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
|
339
349
|
}
|
340
350
|
|
341
351
|
|
352
|
+
template <typename DType>
|
353
|
+
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
354
|
+
const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
|
355
|
+
rb_raise(rb_eNotImpError, "syrk not yet implemented for non-BLAS dtypes");
|
356
|
+
}
|
357
|
+
|
358
|
+
template <typename DType>
|
359
|
+
inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
360
|
+
const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
|
361
|
+
rb_raise(rb_eNotImpError, "herk not yet implemented for non-BLAS dtypes");
|
362
|
+
}
|
363
|
+
|
364
|
+
template <>
|
365
|
+
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
366
|
+
const int K, const float* alpha, const float* A, const int lda, const float* beta, float* C, const int ldc) {
|
367
|
+
cblas_ssyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
|
368
|
+
}
|
369
|
+
|
370
|
+
template <>
|
371
|
+
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
372
|
+
const int K, const double* alpha, const double* A, const int lda, const double* beta, double* C, const int ldc) {
|
373
|
+
cblas_dsyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
|
374
|
+
}
|
375
|
+
|
376
|
+
template <>
|
377
|
+
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
378
|
+
const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
|
379
|
+
cblas_csyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
|
380
|
+
}
|
381
|
+
|
382
|
+
template <>
|
383
|
+
inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
384
|
+
const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
|
385
|
+
cblas_zsyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
|
386
|
+
}
|
387
|
+
|
388
|
+
|
389
|
+
template <>
|
390
|
+
inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
391
|
+
const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
|
392
|
+
cblas_cherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
|
393
|
+
}
|
394
|
+
|
395
|
+
template <>
|
396
|
+
inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
|
397
|
+
const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
|
398
|
+
cblas_zherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
|
399
|
+
}
|
400
|
+
|
401
|
+
|
402
|
+
template <typename DType>
|
403
|
+
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
404
|
+
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const DType* alpha,
|
405
|
+
const DType* A, const int lda, DType* B, const int ldb) {
|
406
|
+
rb_raise(rb_eNotImpError, "trmm not yet implemented for non-BLAS dtypes");
|
407
|
+
}
|
408
|
+
|
409
|
+
template <>
|
410
|
+
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
411
|
+
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const float* alpha,
|
412
|
+
const float* A, const int lda, float* B, const int ldb) {
|
413
|
+
cblas_strmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
|
414
|
+
}
|
415
|
+
|
416
|
+
template <>
|
417
|
+
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
418
|
+
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const double* alpha,
|
419
|
+
const double* A, const int lda, double* B, const int ldb) {
|
420
|
+
cblas_dtrmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
|
421
|
+
}
|
422
|
+
|
423
|
+
template <>
|
424
|
+
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
425
|
+
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex64* alpha,
|
426
|
+
const Complex64* A, const int lda, Complex64* B, const int ldb) {
|
427
|
+
cblas_ctrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
|
428
|
+
}
|
429
|
+
|
430
|
+
template <>
|
431
|
+
inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
|
432
|
+
const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex128* alpha,
|
433
|
+
const Complex128* A, const int lda, Complex128* B, const int ldb) {
|
434
|
+
cblas_ztrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
|
435
|
+
}
|
436
|
+
|
437
|
+
|
342
438
|
/*
|
343
439
|
* BLAS' DTRSM function, generalized.
|
344
440
|
*/
|
@@ -349,6 +445,9 @@ inline void trsm(const enum CBLAS_ORDER order,
|
|
349
445
|
const int m, const int n, const DType alpha, const DType* a,
|
350
446
|
const int lda, DType* b, const int ldb)
|
351
447
|
{
|
448
|
+
/*using std::cerr;
|
449
|
+
using std::endl;*/
|
450
|
+
|
352
451
|
int num_rows_a = n;
|
353
452
|
if (side == CblasLeft) num_rows_a = m;
|
354
453
|
|
@@ -368,6 +467,13 @@ inline void trsm(const enum CBLAS_ORDER order,
|
|
368
467
|
enum CBLAS_SIDE side_ = side == CblasLeft ? CblasRight : CblasLeft;
|
369
468
|
enum CBLAS_UPLO uplo_ = uplo == CblasUpper ? CblasLower : CblasUpper;
|
370
469
|
|
470
|
+
/*
|
471
|
+
cerr << "(row-major) trsm: " << (side_ == CblasLeft ? "left " : "right ")
|
472
|
+
<< (uplo_ == CblasUpper ? "upper " : "lower ")
|
473
|
+
<< (trans_a == CblasTrans ? "trans " : "notrans ")
|
474
|
+
<< (diag == CblasNonUnit ? "nonunit " : "unit ")
|
475
|
+
<< n << " " << m << " " << alpha << " a " << lda << " b " << ldb << endl;
|
476
|
+
*/
|
371
477
|
trsm_nothrow<DType>(side_, uplo_, trans_a, diag, n, m, alpha, a, lda, b, ldb);
|
372
478
|
|
373
479
|
} else { // CblasColMajor
|
@@ -376,7 +482,13 @@ inline void trsm(const enum CBLAS_ORDER order,
|
|
376
482
|
fprintf(stderr, "TRSM: M=%d; got ldb=%d\n", m, ldb);
|
377
483
|
rb_raise(rb_eArgError, "TRSM: Expected ldb >= max(1,M)");
|
378
484
|
}
|
379
|
-
|
485
|
+
/*
|
486
|
+
cerr << "(col-major) trsm: " << (side == CblasLeft ? "left " : "right ")
|
487
|
+
<< (uplo == CblasUpper ? "upper " : "lower ")
|
488
|
+
<< (trans_a == CblasTrans ? "trans " : "notrans ")
|
489
|
+
<< (diag == CblasNonUnit ? "nonunit " : "unit ")
|
490
|
+
<< m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
|
491
|
+
*/
|
380
492
|
trsm_nothrow<DType>(side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
|
381
493
|
|
382
494
|
}
|
@@ -390,7 +502,7 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
|
|
390
502
|
const int m, const int n, const float alpha, const float* a,
|
391
503
|
const int lda, float* b, const int ldb)
|
392
504
|
{
|
393
|
-
cblas_strsm(
|
505
|
+
cblas_strsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
|
394
506
|
}
|
395
507
|
|
396
508
|
template <>
|
@@ -399,7 +511,15 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
|
|
399
511
|
const int m, const int n, const double alpha, const double* a,
|
400
512
|
const int lda, double* b, const int ldb)
|
401
513
|
{
|
402
|
-
|
514
|
+
/* using std::cerr;
|
515
|
+
using std::endl;
|
516
|
+
cerr << "(row-major) dtrsm: " << (side == CblasLeft ? "left " : "right ")
|
517
|
+
<< (uplo == CblasUpper ? "upper " : "lower ")
|
518
|
+
<< (trans_a == CblasTrans ? "trans " : "notrans ")
|
519
|
+
<< (diag == CblasNonUnit ? "nonunit " : "unit ")
|
520
|
+
<< m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
|
521
|
+
*/
|
522
|
+
cblas_dtrsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
|
403
523
|
}
|
404
524
|
|
405
525
|
|
@@ -409,7 +529,7 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
|
|
409
529
|
const int m, const int n, const Complex64 alpha, const Complex64* a,
|
410
530
|
const int lda, Complex64* b, const int ldb)
|
411
531
|
{
|
412
|
-
cblas_ctrsm(
|
532
|
+
cblas_ctrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
|
413
533
|
}
|
414
534
|
|
415
535
|
template <>
|
@@ -418,7 +538,7 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
|
|
418
538
|
const int m, const int n, const Complex128 alpha, const Complex128* a,
|
419
539
|
const int lda, Complex128* b, const int ldb)
|
420
540
|
{
|
421
|
-
cblas_ztrsm(
|
541
|
+
cblas_ztrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
|
422
542
|
}
|
423
543
|
|
424
544
|
|
@@ -429,7 +549,7 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
|
|
429
549
|
*/
|
430
550
|
template <typename DType>
|
431
551
|
inline void laswp(const int N, DType* A, const int lda, const int K1, const int K2, const int *piv, const int inci) {
|
432
|
-
const int n = K2 - K1;
|
552
|
+
//const int n = K2 - K1; // not sure why this is declared. commented it out because it's unused.
|
433
553
|
|
434
554
|
int nb = N >> 5;
|
435
555
|
|
@@ -1261,6 +1381,87 @@ inline int getrf_nothrow(const int M, const int N, DType* A, const int lda, int*
|
|
1261
1381
|
return(ierr);
|
1262
1382
|
}
|
1263
1383
|
|
1384
|
+
/*
|
1385
|
+
* Solves a system of linear equations A*X = B with a general NxN matrix A using the LU factorization computed by GETRF.
|
1386
|
+
*
|
1387
|
+
* From ATLAS 3.8.0.
|
1388
|
+
*/
|
1389
|
+
template <typename DType>
|
1390
|
+
int getrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, const int N, const int NRHS, const DType* A,
|
1391
|
+
const int lda, const int* ipiv, DType* B, const int ldb)
|
1392
|
+
{
|
1393
|
+
// enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src.
|
1394
|
+
|
1395
|
+
if (!N || !NRHS) return 0;
|
1396
|
+
|
1397
|
+
const DType ONE = 1;
|
1398
|
+
|
1399
|
+
if (Order == CblasColMajor) {
|
1400
|
+
if (Trans == CblasNoTrans) {
|
1401
|
+
nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
|
1402
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1403
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1404
|
+
} else {
|
1405
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, Trans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1406
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasLower, Trans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1407
|
+
nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
|
1408
|
+
}
|
1409
|
+
} else {
|
1410
|
+
if (Trans == CblasNoTrans) {
|
1411
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1412
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1413
|
+
nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
|
1414
|
+
} else {
|
1415
|
+
nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
|
1416
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1417
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1418
|
+
}
|
1419
|
+
}
|
1420
|
+
return 0;
|
1421
|
+
}
|
1422
|
+
|
1423
|
+
|
1424
|
+
/*
|
1425
|
+
* Solves a system of linear equations A*X = B with a symmetric positive definite matrix A using the Cholesky factorization computed by POTRF.
|
1426
|
+
*
|
1427
|
+
* From ATLAS 3.8.0.
|
1428
|
+
*/
|
1429
|
+
template <typename DType, bool is_complex>
|
1430
|
+
int potrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, const int NRHS, const DType* A,
|
1431
|
+
const int lda, DType* B, const int ldb)
|
1432
|
+
{
|
1433
|
+
// enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src.
|
1434
|
+
|
1435
|
+
CBLAS_TRANSPOSE MyTrans = is_complex ? CblasConjTrans : CblasTrans;
|
1436
|
+
|
1437
|
+
if (!N || !NRHS) return 0;
|
1438
|
+
|
1439
|
+
const DType ONE = 1;
|
1440
|
+
|
1441
|
+
if (Order == CblasColMajor) {
|
1442
|
+
if (Uplo == CblasUpper) {
|
1443
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, MyTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1444
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1445
|
+
} else {
|
1446
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1447
|
+
nm::math::trsm<DType>(Order, CblasLeft, CblasLower, MyTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
|
1448
|
+
}
|
1449
|
+
} else {
|
1450
|
+
// There's some kind of scaling operation that normally happens here in ATLAS. Not sure what it does, so we'll only
|
1451
|
+
// worry if something breaks. It probably has to do with their non-templated code and doesn't apply to us.
|
1452
|
+
|
1453
|
+
if (Uplo == CblasUpper) {
|
1454
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1455
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasUpper, MyTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1456
|
+
} else {
|
1457
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasLower, MyTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1458
|
+
nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
|
1459
|
+
}
|
1460
|
+
}
|
1461
|
+
return 0;
|
1462
|
+
}
|
1463
|
+
|
1464
|
+
|
1264
1465
|
|
1265
1466
|
/*
|
1266
1467
|
* From ATLAS 3.8.0:
|
@@ -1317,6 +1518,617 @@ inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, DType*
|
|
1317
1518
|
}
|
1318
1519
|
|
1319
1520
|
|
1521
|
+
/*
|
1522
|
+
* From ATLAS 3.8.0:
|
1523
|
+
*
|
1524
|
+
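* NOTE(review): this header appears to have been copied from getrf (LU factorization); the potrf
* templates that follow are referred to elsewhere in this file as computing a Cholesky factorization
* (see the potrs comment) -- confirm and update this description.
*
* Computes one of two LU factorizations based on the setting of the Order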
|
1525
|
+
* parameter, as follows:
|
1526
|
+
* ----------------------------------------------------------------------------
|
1527
|
+
* Order == CblasColMajor
|
1528
|
+
* Column-major factorization of form
|
1529
|
+
* A = P * L * U
|
1530
|
+
* where P is a row-permutation matrix, L is lower triangular with unit
|
1531
|
+
* diagonal elements (lower trapezoidal if M > N), and U is upper triangular
|
1532
|
+
* (upper trapezoidal if M < N).
|
1533
|
+
*
|
1534
|
+
* ----------------------------------------------------------------------------
|
1535
|
+
* Order == CblasRowMajor
|
1536
|
+
* Row-major factorization of form
|
1537
|
+
* A = P * L * U
|
1538
|
+
* where P is a column-permutation matrix, L is lower triangular (lower
|
1539
|
+
* trapezoidal if M > N), and U is upper triangular with unit diagonals (upper
|
1540
|
+
* trapezoidal if M < N).
|
1541
|
+
*
|
1542
|
+
* ============================================================================
|
1543
|
+
* Let IERR be the return value of the function:
|
1544
|
+
* If IERR == 0, successful exit.
|
1545
|
+
* If (IERR < 0) the -IERR argument had an illegal value
|
1546
|
+
* If (IERR > 0 && Order == CblasColMajor)
|
1547
|
+
* U(i-1,i-1) is exactly zero. The factorization has been completed,
|
1548
|
+
* but the factor U is exactly singular, and division by zero will
|
1549
|
+
* occur if it is used to solve a system of equations.
|
1550
|
+
* If (IERR > 0 && Order == CblasRowMajor)
|
1551
|
+
* L(i-1,i-1) is exactly zero. The factorization has been completed,
|
1552
|
+
* but the factor L is exactly singular, and division by zero will
|
1553
|
+
* occur if it is used to solve a system of equations.
|
1554
|
+
*/
|
1555
|
+
template <typename DType>
|
1556
|
+
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) {
|
1557
|
+
#ifdef HAVE_CLAPACK_H
|
1558
|
+
rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes");
|
1559
|
+
#else
|
1560
|
+
rb_raise(rb_eNotImpError, "only LAPACK version implemented thus far");
|
1561
|
+
#endif
|
1562
|
+
return 0;
|
1563
|
+
}
|
1564
|
+
|
1565
|
+
#ifdef HAVE_CLAPACK_H
|
1566
|
+
template <>
|
1567
|
+
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) {
|
1568
|
+
return clapack_spotrf(order, uplo, N, A, lda);
|
1569
|
+
}
|
1570
|
+
|
1571
|
+
template <>
|
1572
|
+
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) {
|
1573
|
+
return clapack_dpotrf(order, uplo, N, A, lda);
|
1574
|
+
}
|
1575
|
+
|
1576
|
+
template <>
|
1577
|
+
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) {
|
1578
|
+
return clapack_cpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda);
|
1579
|
+
}
|
1580
|
+
|
1581
|
+
template <>
|
1582
|
+
inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) {
|
1583
|
+
return clapack_zpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda);
|
1584
|
+
}
|
1585
|
+
#endif
|
1586
|
+
|
1587
|
+
|
1588
|
+
// This is the old BLAS version of this function. ATLAS has an optimized version, but
|
1589
|
+
// it's going to be tough to translate.
|
1590
|
+
/*
 * Exchanges N elements between the strided vectors X and Y (BLAS-style xSWAP).
 *
 *   N     number of elements to exchange; nothing happens when N <= 0.
 *   X     first vector, modified in place.
 *   incX  distance (in elements) between consecutive entries of X.
 *   Y     second vector, modified in place.
 *   incY  distance (in elements) between consecutive entries of Y.
 *
 * The previous version advanced its stride counters but indexed both vectors with the plain loop
 * counter, so any non-unit increment swapped the wrong elements. Elements are now addressed through
 * the stride counters. Behaviour for incX == incY == 1 is unchanged.
 */
template <typename DType>
static void swap(const int N, DType* X, const int incX, DType* Y, const int incY) {
  if (N <= 0) return;

  // Reference BLAS convention: a negative increment walks the vector backwards, so the first
  // logical element lives at offset (1 - N) * inc.
  int ix = incX < 0 ? (1 - N) * incX : 0;
  int iy = incY < 0 ? (1 - N) * incY : 0;

  for (int i = 0; i < N; ++i) {
    const DType temp = X[ix];
    X[ix] = Y[iy];
    Y[iy] = temp;

    ix += incX;
    iy += incY;
  }
}
|
1604
|
+
|
1605
|
+
|
1606
|
+
// Copies an upper row-major array from U, zeroing U; U is unit, so diagonal is not copied.
|
1607
|
+
//
|
1608
|
+
// From ATLAS 3.8.0.
|
1609
|
+
/*
 * Moves the strictly-upper part of the row-major M x N matrix U into C and zeroes it in U.
 *
 * For each row r in [0, M), every element in columns r+1 .. N-1 is copied from U to the same
 * (row, column) position in C and then set to 0 in U. The diagonal and everything below it are
 * left untouched in both matrices.
 *
 *   ldu, ldc  row strides (in elements) of U and C respectively.
 */
template <typename DType>
static inline void trcpzeroU(const int M, const int N, DType* U, const int ldu, DType* C, const int ldc) {
  DType* src_row = U;
  DType* dst_row = C;

  int row = 0;
  while (row != M) {
    // Strictly above the diagonal: columns row+1 .. N-1.
    for (int col = row + 1; col < N; ++col) {
      dst_row[col] = src_row[col];
      src_row[col] = 0;
    }

    ++row;
    src_row += ldu;
    dst_row += ldc;
  }
}
|
1622
|
+
|
1623
|
+
|
1624
|
+
/*
|
1625
|
+
* Un-comment the following lines when we figure out how to calculate NB for each of the ATLAS-derived
|
1626
|
+
* functions. This is probably really complicated.
|
1627
|
+
*
|
1628
|
+
* Also needed: ATL_MulByNB, ATL_DivByNB (both defined in the build process for ATLAS), and ATL_mmMU.
|
1629
|
+
*
|
1630
|
+
*/
|
1631
|
+
|
1632
|
+
/*
|
1633
|
+
|
1634
|
+
template <bool RowMajor, bool Upper, typename DType>
|
1635
|
+
static int trtri_4(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
|
1636
|
+
|
1637
|
+
if (RowMajor) {
|
1638
|
+
DType *pA0 = A, *pA1 = A+lda, *pA2 = A+2*lda, *pA3 = A+3*lda;
|
1639
|
+
DType tmp;
|
1640
|
+
if (Upper) {
|
1641
|
+
DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
|
1642
|
+
A12 = pA1[2], A13 = pA1[3],
|
1643
|
+
A23 = pA2[3];
|
1644
|
+
|
1645
|
+
if (Diag == CblasNonUnit) {
|
1646
|
+
pA0->inverse();
|
1647
|
+
(pA1+1)->inverse();
|
1648
|
+
(pA2+2)->inverse();
|
1649
|
+
(pA3+3)->inverse();
|
1650
|
+
|
1651
|
+
pA0[1] = -A01 * pA1[1] * pA0[0];
|
1652
|
+
pA1[2] = -A12 * pA2[2] * pA1[1];
|
1653
|
+
pA2[3] = -A23 * pA3[3] * pA2[2];
|
1654
|
+
|
1655
|
+
pA0[2] = -(A01 * pA1[2] + A02 * pA2[2]) * pA0[0];
|
1656
|
+
pA1[3] = -(A12 * pA2[3] + A13 * pA3[3]) * pA1[1];
|
1657
|
+
|
1658
|
+
pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03 * pA3[3]) * pA0[0];
|
1659
|
+
|
1660
|
+
} else {
|
1661
|
+
|
1662
|
+
pA0[1] = -A01;
|
1663
|
+
pA1[2] = -A12;
|
1664
|
+
pA2[3] = -A23;
|
1665
|
+
|
1666
|
+
pA0[2] = -(A01 * pA1[2] + A02);
|
1667
|
+
pA1[3] = -(A12 * pA2[3] + A13);
|
1668
|
+
|
1669
|
+
pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03);
|
1670
|
+
}
|
1671
|
+
|
1672
|
+
} else { // Lower
|
1673
|
+
DType A10 = pA1[0],
|
1674
|
+
A20 = pA2[0], A21 = pA2[1],
|
1675
|
+
A30 = PA3[0], A31 = pA3[1], A32 = pA3[2];
|
1676
|
+
DType *B10 = pA1,
|
1677
|
+
*B20 = pA2,
|
1678
|
+
*B30 = pA3,
|
1679
|
+
*B21 = pA2+1,
|
1680
|
+
*B31 = pA3+1,
|
1681
|
+
*B32 = pA3+2;
|
1682
|
+
|
1683
|
+
|
1684
|
+
if (Diag == CblasNonUnit) {
|
1685
|
+
pA0->inverse();
|
1686
|
+
(pA1+1)->inverse();
|
1687
|
+
(pA2+2)->inverse();
|
1688
|
+
(pA3+3)->inverse();
|
1689
|
+
|
1690
|
+
*B10 = -A10 * pA0[0] * pA1[1];
|
1691
|
+
*B21 = -A21 * pA1[1] * pA2[2];
|
1692
|
+
*B32 = -A32 * pA2[2] * pA3[3];
|
1693
|
+
*B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
|
1694
|
+
*B31 = -(A31 * pA1[1] + A32 * (*B21)) * pA3[3];
|
1695
|
+
*B30 = -(A30 * pA0[0] + A31 * (*B10) + A32 * (*B20)) * pA3;
|
1696
|
+
} else {
|
1697
|
+
*B10 = -A10;
|
1698
|
+
*B21 = -A21;
|
1699
|
+
*B32 = -A32;
|
1700
|
+
*B20 = -(A20 + A21 * (*B10));
|
1701
|
+
*B31 = -(A31 + A32 * (*B21));
|
1702
|
+
*B30 = -(A30 + A31 * (*B10) + A32 * (*B20));
|
1703
|
+
}
|
1704
|
+
}
|
1705
|
+
|
1706
|
+
} else {
|
1707
|
+
rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
1708
|
+
}
|
1709
|
+
|
1710
|
+
return 0;
|
1711
|
+
|
1712
|
+
}
|
1713
|
+
|
1714
|
+
|
1715
|
+
template <bool RowMajor, bool Upper, typename DType>
|
1716
|
+
static int trtri_3(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
|
1717
|
+
|
1718
|
+
if (RowMajor) {
|
1719
|
+
|
1720
|
+
DType tmp;
|
1721
|
+
|
1722
|
+
if (Upper) {
|
1723
|
+
DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
|
1724
|
+
A12 = pA1[2], A13 = pA1[3];
|
1725
|
+
|
1726
|
+
DType *B01 = pA0 + 1,
|
1727
|
+
*B02 = pA0 + 2,
|
1728
|
+
*B12 = pA1 + 2;
|
1729
|
+
|
1730
|
+
if (Diag == CblasNonUnit) {
|
1731
|
+
pA0->inverse();
|
1732
|
+
(pA1+1)->inverse();
|
1733
|
+
(pA2+2)->inverse();
|
1734
|
+
|
1735
|
+
*B01 = -A01 * pA1[1] * pA0[0];
|
1736
|
+
*B12 = -A12 * pA2[2] * pA1[1];
|
1737
|
+
*B02 = -(A01 * (*B12) + A02 * pA2[2]) * pA0[0];
|
1738
|
+
} else {
|
1739
|
+
*B01 = -A01;
|
1740
|
+
*B12 = -A12;
|
1741
|
+
*B02 = -(A01 * (*B12) + A02);
|
1742
|
+
}
|
1743
|
+
|
1744
|
+
} else { // Lower
|
1745
|
+
DType *pA0=A, *pA1=A+lda, *pA2=A+2*lda;
|
1746
|
+
DType A10=pA1[0],
|
1747
|
+
A20=pA2[0], A21=pA2[1];
|
1748
|
+
|
1749
|
+
DType *B10 = pA1,
|
1750
|
+
*B20 = pA2;
|
1751
|
+
*B21 = pA2+1;
|
1752
|
+
|
1753
|
+
if (Diag == CblasNonUnit) {
|
1754
|
+
pA0->inverse();
|
1755
|
+
(pA1+1)->inverse();
|
1756
|
+
(pA2+2)->inverse();
|
1757
|
+
*B10 = -A10 * pA0[0] * pA1[1];
|
1758
|
+
*B21 = -A21 * pA1[1] * pA2[2];
|
1759
|
+
*B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
|
1760
|
+
} else {
|
1761
|
+
*B10 = -A10;
|
1762
|
+
*B21 = -A21;
|
1763
|
+
*B20 = -(A20 + A21 * (*B10));
|
1764
|
+
}
|
1765
|
+
}
|
1766
|
+
|
1767
|
+
|
1768
|
+
} else {
|
1769
|
+
rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
1770
|
+
}
|
1771
|
+
|
1772
|
+
return 0;
|
1773
|
+
|
1774
|
+
}
|
1775
|
+
|
1776
|
+
template <bool RowMajor, bool Upper, bool Real, typename DType>
|
1777
|
+
static void trtri(const enum CBLAS_DIAG Diag, const int N, DType* A, const int lda) {
|
1778
|
+
DType *Age, *Atr;
|
1779
|
+
DType tmp;
|
1780
|
+
int Nleft, Nright;
|
1781
|
+
|
1782
|
+
int ierr = 0;
|
1783
|
+
|
1784
|
+
static const DType ONE = 1;
|
1785
|
+
static const DType MONE -1;
|
1786
|
+
static const DType NONE = -1;
|
1787
|
+
|
1788
|
+
if (RowMajor) {
|
1789
|
+
|
1790
|
+
// FIXME: Use REAL_RECURSE_LIMIT here for float32 and float64 (instead of 1)
|
1791
|
+
if ((Real && N > REAL_RECURSE_LIMIT) || (N > 1)) {
|
1792
|
+
Nleft = N >> 1;
|
1793
|
+
#ifdef NB
|
1794
|
+
if (Nleft > NB) NLeft = ATL_MulByNB(ATL_DivByNB(Nleft));
|
1795
|
+
#endif
|
1796
|
+
|
1797
|
+
Nright = N - Nleft;
|
1798
|
+
|
1799
|
+
if (Upper) {
|
1800
|
+
Age = A + Nleft;
|
1801
|
+
Atr = A + (Nleft * (lda+1));
|
1802
|
+
|
1803
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, Diag,
|
1804
|
+
Nleft, Nright, ONE, Atr, lda, Age, lda);
|
1805
|
+
|
1806
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, Diag,
|
1807
|
+
Nleft, Nright, MONE, A, lda, Age, lda);
|
1808
|
+
|
1809
|
+
} else { // Lower
|
1810
|
+
Age = A + ((Nleft*lda));
|
1811
|
+
Atr = A + (Nleft * (lda+1));
|
1812
|
+
|
1813
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasLower, CblasNoTrans, Diag,
|
1814
|
+
Nright, Nleft, ONE, A, lda, Age, lda);
|
1815
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasLower, CblasNoTrans, Diag,
|
1816
|
+
Nright, Nleft, MONE, Atr, lda, Age, lda);
|
1817
|
+
}
|
1818
|
+
|
1819
|
+
ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nleft, A, lda);
|
1820
|
+
if (ierr) return ierr;
|
1821
|
+
|
1822
|
+
ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nright, Atr, lda);
|
1823
|
+
if (ierr) return ierr + Nleft;
|
1824
|
+
|
1825
|
+
} else {
|
1826
|
+
if (Real) {
|
1827
|
+
if (N == 4) {
|
1828
|
+
return trtri_4<RowMajor,Upper,Real,DType>(Diag, A, lda);
|
1829
|
+
} else if (N == 3) {
|
1830
|
+
return trtri_3<RowMajor,Upper,Real,DType>(Diag, A, lda);
|
1831
|
+
} else if (N == 2) {
|
1832
|
+
if (Diag == CblasNonUnit) {
|
1833
|
+
A->inverse();
|
1834
|
+
(A+(lda+1))->inverse();
|
1835
|
+
|
1836
|
+
if (Upper) {
|
1837
|
+
*(A+1) *= *A; // TRI_MUL
|
1838
|
+
*(A+1) *= *(A+lda+1); // TRI_MUL
|
1839
|
+
} else {
|
1840
|
+
*(A+lda) *= *A; // TRI_MUL
|
1841
|
+
*(A+lda) *= *(A+lda+1); // TRI_MUL
|
1842
|
+
}
|
1843
|
+
}
|
1844
|
+
|
1845
|
+
if (Upper) *(A+1) = -*(A+1); // TRI_NEG
|
1846
|
+
else *(A+lda) = -*(A+lda); // TRI_NEG
|
1847
|
+
} else if (Diag == CblasNonUnit) A->inverse();
|
1848
|
+
} else { // not real
|
1849
|
+
if (Diag == CblasNonUnit) A->inverse();
|
1850
|
+
}
|
1851
|
+
}
|
1852
|
+
|
1853
|
+
} else {
|
1854
|
+
rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
1855
|
+
}
|
1856
|
+
|
1857
|
+
return ierr;
|
1858
|
+
}
|
1859
|
+
|
1860
|
+
|
1861
|
+
template <bool RowMajor, bool Real, typename DType>
|
1862
|
+
int getri(const int N, DType* A, const int lda, const int* ipiv, DType* wrk, const int lwrk) {
|
1863
|
+
|
1864
|
+
if (!RowMajor) rb_raise(rb_eNotImpError, "only row-major implemented at this time");
|
1865
|
+
|
1866
|
+
int jb, nb, I, ndown, iret;
|
1867
|
+
|
1868
|
+
const DType ONE = 1, NONE = -1;
|
1869
|
+
|
1870
|
+
int iret = trtri<RowMajor,false,Real,DType>(CblasNonUnit, N, A, lda);
|
1871
|
+
if (!iret && N > 1) {
|
1872
|
+
jb = lwrk / N;
|
1873
|
+
if (jb >= NB) nb = ATL_MulByNB(ATL_DivByNB(jb));
|
1874
|
+
else if (jb >= ATL_mmMU) nb = (jb/ATL_mmMU)*ATL_mmMU;
|
1875
|
+
else nb = jb;
|
1876
|
+
if (!nb) return -6; // need at least 1 row of workspace
|
1877
|
+
|
1878
|
+
// only first iteration will have partial block, unroll it
|
1879
|
+
|
1880
|
+
jb = N - (N/nb) * nb;
|
1881
|
+
if (!jb) jb = nb;
|
1882
|
+
I = N - jb;
|
1883
|
+
A += lda * I;
|
1884
|
+
trcpzeroU<DType>(jb, jb, A+I, lda, wrk, jb);
|
1885
|
+
nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
|
1886
|
+
jb, N, ONE, wrk, jb, A, lda);
|
1887
|
+
|
1888
|
+
if (I) {
|
1889
|
+
do {
|
1890
|
+
I -= nb;
|
1891
|
+
A -= nb * lda;
|
1892
|
+
ndown = N-I;
|
1893
|
+
trcpzeroU<DType>(nb, ndown, A+I, lda, wrk, ndown);
|
1894
|
+
nm::math::gemm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
|
1895
|
+
nb, N, ONE, wrk, ndown, A, lda);
|
1896
|
+
} while (I);
|
1897
|
+
}
|
1898
|
+
|
1899
|
+
// Apply row interchanges
|
1900
|
+
|
1901
|
+
for (I = N - 2; I >= 0; --I) {
|
1902
|
+
jb = ipiv[I];
|
1903
|
+
if (jb != I) nm::math::swap<DType>(N, A+I*lda, 1, A+jb*lda, 1);
|
1904
|
+
}
|
1905
|
+
}
|
1906
|
+
|
1907
|
+
return iret;
|
1908
|
+
}
|
1909
|
+
*/
|
1910
|
+
|
1911
|
+
|
1912
|
+
// TODO: Test this to see if it works properly on complex. ATLAS has a separate algorithm for complex, which looks like
|
1913
|
+
// TODO: it may actually be the same one.
|
1914
|
+
//
|
1915
|
+
// This function is called ATL_rot in ATLAS 3.8.4.
|
1916
|
+
/*
 * Applies the plane rotation defined by (c, s) to N element pairs of X and Y, in place:
 *
 *   X[i] <- c * X[i] + s * Y[i]
 *   Y[i] <- c * Y[i] - s * X[i]
 *
 * X and Y are walked with strides incX and incY starting from the pointers given. Nothing is
 * touched when c == 1 and s == 0 (identity rotation) or when N <= 0.
 */
template <typename DType>
inline void rot_helper(const int N, DType* X, const int incX, DType* Y, const int incY, const DType c, const DType s) {
  if (c != 1 || s != 0) {
    if (incX == 1 && incY == 1) {
      // Contiguous fast path. Bounded with '<' (not '!=') so that a negative N is a no-op
      // instead of running off the end of the arrays, matching the strided path below.
      for (int i = 0; i < N; ++i) {
        const DType tmp = X[i] * c + Y[i] * s;
        Y[i] = Y[i] * c - X[i] * s;
        X[i] = tmp;
      }
    } else {
      // General strided path; the increments may be negative.
      for (int i = N; i > 0; --i, Y += incY, X += incX) {
        const DType tmp = *X * c + *Y * s;
        *Y = *Y * c - *X * s;
        *X = tmp;
      }
    }
  }
}
|
1934
|
+
|
1935
|
+
|
1936
|
+
/* Givens plane rotation. From ATLAS 3.8.4. */
|
1937
|
+
// FIXME: Need a specialized algorithm for Rationals. BLAS' algorithm simply will not work for most values due to the
|
1938
|
+
// FIXME: sqrt.
|
1939
|
+
/*
 * Generic Givens rotation setup. Given *a and *b, computes c and s such that
 *
 *   [  c  s ] [ a ]   [ r ]
 *   [ -s  c ] [ b ] = [ 0 ]
 *
 * On return *a holds r and *b holds the reconstruction value z:
 *   z = s     if |a| > |b|
 *   z = 1/c   otherwise, when c != 0
 *   z = 1     otherwise
 * If both inputs are zero, c = 1 and s = r = z = 0.
 */
template <typename DType>
inline void rotg(DType* a, DType* b, DType* c, DType* s) {
  DType aa = std::abs(*a), ab = std::abs(*b);
  DType roe = aa > ab ? *a : *b;   // r takes the sign of the larger-magnitude input
  DType scal = aa + ab;

  if (scal == 0) {
    *c = 1;
    *s = *a = *b = 0;
  } else {
    // Scale before squaring to avoid overflow/underflow in the norm.
    DType t0 = aa / scal, t1 = ab / scal;
    DType r = scal * std::sqrt(t0 * t0 + t1 * t1);
    if (roe < 0) r = -r;
    *c = *a / r;
    *s = *b / r;

    // The |a| > |b| case was previously missing, so z was always stored as 1/c.
    DType z = DType(1);
    if (aa > ab)      z = *s;
    else if (*c != 0) z = 1 / *c;

    *a = r;
    *b = z;
  }
}
|
1959
|
+
|
1960
|
+
template <>
|
1961
|
+
inline void rotg(float* a, float* b, float* c, float* s) {
|
1962
|
+
cblas_srotg(a, b, c, s);
|
1963
|
+
}
|
1964
|
+
|
1965
|
+
template <>
|
1966
|
+
inline void rotg(double* a, double* b, double* c, double* s) {
|
1967
|
+
cblas_drotg(a, b, c, s);
|
1968
|
+
}
|
1969
|
+
|
1970
|
+
template <>
|
1971
|
+
inline void rotg(Complex64* a, Complex64* b, Complex64* c, Complex64* s) {
|
1972
|
+
cblas_crotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s));
|
1973
|
+
}
|
1974
|
+
|
1975
|
+
template <>
|
1976
|
+
inline void rotg(Complex128* a, Complex128* b, Complex128* c, Complex128* s) {
|
1977
|
+
cblas_zrotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s));
|
1978
|
+
}
|
1979
|
+
|
1980
|
+
template <typename DType>
|
1981
|
+
inline void cblas_rotg(void* a, void* b, void* c, void* s) {
|
1982
|
+
rotg<DType>(reinterpret_cast<DType*>(a), reinterpret_cast<DType*>(b), reinterpret_cast<DType*>(c), reinterpret_cast<DType*>(s));
|
1983
|
+
}
|
1984
|
+
|
1985
|
+
|
1986
|
+
/* Applies a plane rotation. From ATLAS 3.8.4. */
|
1987
|
+
template <typename DType, typename CSDType>
|
1988
|
+
inline void rot(const int N, DType* X, const int incX, DType* Y, const int incY, const CSDType c, const CSDType s) {
|
1989
|
+
DType *x = X, *y = Y;
|
1990
|
+
int incx = incX, incy = incY;
|
1991
|
+
|
1992
|
+
if (N > 0) {
|
1993
|
+
if (incX < 0) {
|
1994
|
+
if (incY < 0) { incx = -incx; incy = -incy; }
|
1995
|
+
else x += -incX * (N-1);
|
1996
|
+
} else if (incY < 0) {
|
1997
|
+
incy = -incy;
|
1998
|
+
incx = -incx;
|
1999
|
+
x += (N-1) * incX;
|
2000
|
+
}
|
2001
|
+
rot_helper<DType>(N, x, incx, y, incy, c, s);
|
2002
|
+
}
|
2003
|
+
}
|
2004
|
+
|
2005
|
+
template <>
|
2006
|
+
inline void rot(const int N, float* X, const int incX, float* Y, const int incY, const float c, const float s) {
|
2007
|
+
cblas_srot(N, X, incX, Y, incY, (float)c, (float)s);
|
2008
|
+
}
|
2009
|
+
|
2010
|
+
template <>
|
2011
|
+
inline void rot(const int N, double* X, const int incX, double* Y, const int incY, const double c, const double s) {
|
2012
|
+
cblas_drot(N, X, incX, Y, incY, c, s);
|
2013
|
+
}
|
2014
|
+
|
2015
|
+
template <>
|
2016
|
+
inline void rot(const int N, Complex64* X, const int incX, Complex64* Y, const int incY, const float c, const float s) {
|
2017
|
+
cblas_csrot(N, X, incX, Y, incY, c, s);
|
2018
|
+
}
|
2019
|
+
|
2020
|
+
template <>
|
2021
|
+
inline void rot(const int N, Complex128* X, const int incX, Complex128* Y, const int incY, const double c, const double s) {
|
2022
|
+
cblas_zdrot(N, X, incX, Y, incY, c, s);
|
2023
|
+
}
|
2024
|
+
|
2025
|
+
|
2026
|
+
template <typename DType, typename CSDType>
|
2027
|
+
inline void cblas_rot(const int N, void* X, const int incX, void* Y, const int incY, const void* c, const void* s) {
|
2028
|
+
rot<DType,CSDType>(N, reinterpret_cast<DType*>(X), incX, reinterpret_cast<DType*>(Y), incY, *reinterpret_cast<const CSDType*>(c), *reinterpret_cast<const CSDType*>(s));
|
2029
|
+
}
|
2030
|
+
|
2031
|
+
|
2032
|
+
template <bool is_complex, typename DType>
|
2033
|
+
inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) {
|
2034
|
+
|
2035
|
+
int Nleft, Nright;
|
2036
|
+
const DType ONE = 1;
|
2037
|
+
DType *G, *U0 = A, *U1;
|
2038
|
+
|
2039
|
+
if (N > 1) {
|
2040
|
+
Nleft = N >> 1;
|
2041
|
+
#ifdef NB
|
2042
|
+
if (Nleft > NB) Nleft = ATL_MulByNB(ATL_DivByNB(Nleft));
|
2043
|
+
#endif
|
2044
|
+
|
2045
|
+
Nright = N - Nleft;
|
2046
|
+
|
2047
|
+
// FIXME: There's a simpler way to write this next block, but I'm way too tired to work it out right now.
|
2048
|
+
if (uplo == CblasUpper) {
|
2049
|
+
if (order == CblasRowMajor) {
|
2050
|
+
G = A + Nleft;
|
2051
|
+
U1 = G + Nleft * lda;
|
2052
|
+
} else {
|
2053
|
+
G = A + Nleft * lda;
|
2054
|
+
U1 = G + Nleft;
|
2055
|
+
}
|
2056
|
+
} else {
|
2057
|
+
if (order == CblasRowMajor) {
|
2058
|
+
G = A + Nleft * lda;
|
2059
|
+
U1 = G + Nleft;
|
2060
|
+
} else {
|
2061
|
+
G = A + Nleft;
|
2062
|
+
U1 = G + Nleft * lda;
|
2063
|
+
}
|
2064
|
+
}
|
2065
|
+
|
2066
|
+
lauum<is_complex, DType>(order, uplo, Nleft, U0, lda);
|
2067
|
+
|
2068
|
+
if (is_complex) {
|
2069
|
+
|
2070
|
+
nm::math::herk<DType>(order, uplo,
|
2071
|
+
uplo == CblasLower ? CblasConjTrans : CblasNoTrans,
|
2072
|
+
Nleft, Nright, &ONE, G, lda, &ONE, U0, lda);
|
2073
|
+
|
2074
|
+
nm::math::trmm<DType>(order, CblasLeft, uplo, CblasConjTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda);
|
2075
|
+
} else {
|
2076
|
+
nm::math::syrk<DType>(order, uplo,
|
2077
|
+
uplo == CblasLower ? CblasTrans : CblasNoTrans,
|
2078
|
+
Nleft, Nright, &ONE, G, lda, &ONE, U0, lda);
|
2079
|
+
|
2080
|
+
nm::math::trmm<DType>(order, CblasLeft, uplo, CblasTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda);
|
2081
|
+
}
|
2082
|
+
lauum<is_complex, DType>(order, uplo, Nright, U1, lda);
|
2083
|
+
|
2084
|
+
} else {
|
2085
|
+
*A = *A * *A;
|
2086
|
+
}
|
2087
|
+
}
|
2088
|
+
|
2089
|
+
|
2090
|
+
#ifdef HAVE_CLAPACK_H
/*
 * CLAPACK-backed lauum for the BLAS dtypes.
 *
 * These are explicit specializations of lauum<is_complex, DType>. Written as overloads that take
 * only <bool is_complex>, they could never be chosen by a call that spells out both template
 * arguments (lauum<is_complex, DType>(...)), so such calls always ran the generic recursive
 * implementation instead. The return value of the clapack_* call is not propagated because
 * lauum returns void.
 */
template <>
inline void lauum<false, float>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) {
  clapack_slauum(order, uplo, N, A, lda);
}

template <>
inline void lauum<false, double>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) {
  clapack_dlauum(order, uplo, N, A, lda);
}

template <>
inline void lauum<true, Complex64>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) {
  clapack_clauum(order, uplo, N, A, lda);
}

template <>
inline void lauum<true, Complex128>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) {
  clapack_zlauum(order, uplo, N, A, lda);
}
#endif
|
2111
|
+
|
2112
|
+
|
2113
|
+
/*
|
2114
|
+
* Function signature conversion for calling LAPACK's lauum functions as directly as possible.
|
2115
|
+
*
|
2116
|
+
* For documentation: http://www.netlib.org/lapack/double/dlauum.f
|
2117
|
+
*
|
2118
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2119
|
+
*/
|
2120
|
+
template <bool is_complex, typename DType>
|
2121
|
+
inline int clapack_lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
|
2122
|
+
if (n < 0) rb_raise(rb_eArgError, "n cannot be less than zero, is set to %d", n);
|
2123
|
+
if (lda < n || lda < 1) rb_raise(rb_eArgError, "lda must be >= max(n,1); lda=%d, n=%d\n", lda, n);
|
2124
|
+
|
2125
|
+
if (uplo == CblasUpper) lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
|
2126
|
+
else lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
|
2127
|
+
|
2128
|
+
return 0;
|
2129
|
+
}
|
2130
|
+
|
2131
|
+
|
1320
2132
|
|
1321
2133
|
|
1322
2134
|
/*
|
@@ -1357,6 +2169,143 @@ inline int clapack_getrf(const enum CBLAS_ORDER order, const int m, const int n,
|
|
1357
2169
|
}
|
1358
2170
|
|
1359
2171
|
|
2172
|
+
/*
|
2173
|
+
* Function signature conversion for calling LAPACK's potrf functions as directly as possible.
|
2174
|
+
*
|
2175
|
+
* For documentation: http://www.netlib.org/lapack/double/dpotrf.f
|
2176
|
+
*
|
2177
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2178
|
+
*/
|
2179
|
+
template <typename DType>
|
2180
|
+
inline int clapack_potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
|
2181
|
+
return potrf<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
|
2182
|
+
}
|
2183
|
+
|
2184
|
+
|
2185
|
+
/*
|
2186
|
+
* Function signature conversion for calling LAPACK's getrs functions as directly as possible.
|
2187
|
+
*
|
2188
|
+
* For documentation: http://www.netlib.org/lapack/double/dgetrs.f
|
2189
|
+
*
|
2190
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2191
|
+
*/
|
2192
|
+
template <typename DType>
|
2193
|
+
inline int clapack_getrs(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const int n, const int nrhs,
|
2194
|
+
const void* a, const int lda, const int* ipiv, void* b, const int ldb) {
|
2195
|
+
return getrs<DType>(order, trans, n, nrhs, reinterpret_cast<const DType*>(a), lda, ipiv, reinterpret_cast<DType*>(b), ldb);
|
2196
|
+
}
|
2197
|
+
|
2198
|
+
/*
|
2199
|
+
* Function signature conversion for calling LAPACK's potrs functions as directly as possible.
|
2200
|
+
*
|
2201
|
+
* For documentation: http://www.netlib.org/lapack/double/dpotrs.f
|
2202
|
+
*
|
2203
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2204
|
+
*/
|
2205
|
+
template <typename DType, bool is_complex>
|
2206
|
+
inline int clapack_potrs(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, const int nrhs,
|
2207
|
+
const void* a, const int lda, void* b, const int ldb) {
|
2208
|
+
return potrs<DType,is_complex>(order, uplo, n, nrhs, reinterpret_cast<const DType*>(a), lda, reinterpret_cast<DType*>(b), ldb);
|
2209
|
+
}
|
2210
|
+
|
2211
|
+
template <typename DType>
|
2212
|
+
inline int getri(const enum CBLAS_ORDER order, const int n, DType* a, const int lda, const int* ipiv) {
|
2213
|
+
rb_raise(rb_eNotImpError, "getri not yet implemented for non-BLAS dtypes");
|
2214
|
+
return 0;
|
2215
|
+
}
|
2216
|
+
|
2217
|
+
#ifdef HAVE_CLAPACK_H
// CLAPACK-backed getri specializations. Each one forwards its arguments to the matching
// clapack_?getri routine and returns that routine's result; the complex versions pass the
// matrix as void*.
template <>
inline int getri<float>(const enum CBLAS_ORDER order, const int n, float* a, const int lda, const int* ipiv) {
  return clapack_sgetri(order, n, a, lda, ipiv);
}

template <>
inline int getri<double>(const enum CBLAS_ORDER order, const int n, double* a, const int lda, const int* ipiv) {
  return clapack_dgetri(order, n, a, lda, ipiv);
}

template <>
inline int getri<Complex64>(const enum CBLAS_ORDER order, const int n, Complex64* a, const int lda, const int* ipiv) {
  void* const raw_a = reinterpret_cast<void*>(a);
  return clapack_cgetri(order, n, raw_a, lda, ipiv);
}

template <>
inline int getri<Complex128>(const enum CBLAS_ORDER order, const int n, Complex128* a, const int lda, const int* ipiv) {
  void* const raw_a = reinterpret_cast<void*>(a);
  return clapack_zgetri(order, n, raw_a, lda, ipiv);
}
#endif
|
2238
|
+
|
2239
|
+
|
2240
|
+
template <typename DType>
|
2241
|
+
inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, DType* a, const int lda) {
|
2242
|
+
rb_raise(rb_eNotImpError, "potri not yet implemented for non-BLAS dtypes");
|
2243
|
+
return 0;
|
2244
|
+
}
|
2245
|
+
|
2246
|
+
|
2247
|
+
#ifdef HAVE_CLAPACK_H
// CLAPACK-backed potri specializations. Each one forwards its arguments to the matching
// clapack_?potri routine and returns that routine's result; the complex versions pass the
// matrix as void*.
template <>
inline int potri<float>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, float* a, const int lda) {
  return clapack_spotri(order, uplo, n, a, lda);
}

template <>
inline int potri<double>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, double* a, const int lda) {
  return clapack_dpotri(order, uplo, n, a, lda);
}

template <>
inline int potri<Complex64>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex64* a, const int lda) {
  void* const raw_a = reinterpret_cast<void*>(a);
  return clapack_cpotri(order, uplo, n, raw_a, lda);
}

template <>
inline int potri<Complex128>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex128* a, const int lda) {
  void* const raw_a = reinterpret_cast<void*>(a);
  return clapack_zpotri(order, uplo, n, raw_a, lda);
}
#endif
|
2268
|
+
|
2269
|
+
/*
|
2270
|
+
* Function signature conversion for calling LAPACK's getri functions as directly as possible.
|
2271
|
+
*
|
2272
|
+
* For documentation: http://www.netlib.org/lapack/double/dgetri.f
|
2273
|
+
*
|
2274
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2275
|
+
*/
|
2276
|
+
template <typename DType>
|
2277
|
+
inline int clapack_getri(const enum CBLAS_ORDER order, const int n, void* a, const int lda, const int* ipiv) {
|
2278
|
+
return getri<DType>(order, n, reinterpret_cast<DType*>(a), lda, ipiv);
|
2279
|
+
}
|
2280
|
+
|
2281
|
+
|
2282
|
+
/*
|
2283
|
+
* Function signature conversion for calling LAPACK's potri functions as directly as possible.
|
2284
|
+
*
|
2285
|
+
* For documentation: http://www.netlib.org/lapack/double/dpotri.f
|
2286
|
+
*
|
2287
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2288
|
+
*/
|
2289
|
+
template <typename DType>
|
2290
|
+
inline int clapack_potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
|
2291
|
+
return potri<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
|
2292
|
+
}
|
2293
|
+
|
2294
|
+
|
2295
|
+
/*
|
2296
|
+
* Function signature conversion for calling LAPACK's laswp functions as directly as possible.
|
2297
|
+
*
|
2298
|
+
* For documentation: http://www.netlib.org/lapack/double/dlaswp.f
|
2299
|
+
*
|
2300
|
+
* This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
|
2301
|
+
*/
|
2302
|
+
template <typename DType>
|
2303
|
+
inline void clapack_laswp(const int n, void* a, const int lda, const int k1, const int k2, const int* ipiv, const int incx) {
|
2304
|
+
laswp<DType>(n, reinterpret_cast<DType*>(a), lda, k1, k2, ipiv, incx);
|
2305
|
+
}
|
2306
|
+
|
2307
|
+
|
2308
|
+
|
1360
2309
|
}} // end namespace nm::math
|
1361
2310
|
|
1362
2311
|
|