nmatrix 0.0.2 → 0.0.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47) hide show
  1. data/Gemfile +1 -1
  2. data/History.txt +31 -3
  3. data/Manifest.txt +5 -0
  4. data/README.rdoc +29 -27
  5. data/ext/nmatrix/binary_format.txt +53 -0
  6. data/ext/nmatrix/data/data.cpp +18 -18
  7. data/ext/nmatrix/data/data.h +38 -7
  8. data/ext/nmatrix/data/rational.h +13 -0
  9. data/ext/nmatrix/data/ruby_object.h +10 -0
  10. data/ext/nmatrix/extconf.rb +2 -0
  11. data/ext/nmatrix/nmatrix.cpp +655 -103
  12. data/ext/nmatrix/nmatrix.h +26 -14
  13. data/ext/nmatrix/ruby_constants.cpp +4 -0
  14. data/ext/nmatrix/ruby_constants.h +2 -0
  15. data/ext/nmatrix/storage/dense.cpp +99 -41
  16. data/ext/nmatrix/storage/dense.h +3 -3
  17. data/ext/nmatrix/storage/list.cpp +36 -14
  18. data/ext/nmatrix/storage/list.h +4 -4
  19. data/ext/nmatrix/storage/storage.cpp +19 -19
  20. data/ext/nmatrix/storage/storage.h +11 -11
  21. data/ext/nmatrix/storage/yale.cpp +17 -20
  22. data/ext/nmatrix/storage/yale.h +13 -11
  23. data/ext/nmatrix/util/io.cpp +25 -23
  24. data/ext/nmatrix/util/io.h +5 -5
  25. data/ext/nmatrix/util/math.cpp +634 -17
  26. data/ext/nmatrix/util/math.h +958 -9
  27. data/ext/nmatrix/util/sl_list.cpp +7 -7
  28. data/ext/nmatrix/util/sl_list.h +2 -2
  29. data/lib/nmatrix.rb +9 -0
  30. data/lib/nmatrix/blas.rb +4 -4
  31. data/lib/nmatrix/io/market.rb +227 -0
  32. data/lib/nmatrix/io/mat_reader.rb +7 -7
  33. data/lib/nmatrix/lapack.rb +80 -0
  34. data/lib/nmatrix/nmatrix.rb +78 -52
  35. data/lib/nmatrix/shortcuts.rb +486 -0
  36. data/lib/nmatrix/version.rb +1 -1
  37. data/spec/2x2_dense_double.mat +0 -0
  38. data/spec/blas_spec.rb +59 -9
  39. data/spec/elementwise_spec.rb +25 -12
  40. data/spec/io_spec.rb +69 -1
  41. data/spec/lapack_spec.rb +53 -4
  42. data/spec/math_spec.rb +9 -0
  43. data/spec/nmatrix_list_spec.rb +95 -0
  44. data/spec/nmatrix_spec.rb +10 -53
  45. data/spec/nmatrix_yale_spec.rb +17 -15
  46. data/spec/shortcuts_spec.rb +154 -0
  47. metadata +22 -15
@@ -70,7 +70,10 @@
70
70
 
71
71
  extern "C" { // These need to be in an extern "C" block or you'll get all kinds of undefined symbol errors.
72
72
  #include <cblas.h>
73
- //#include <clapack.h>
73
+
74
+ #ifdef HAVE_CLAPACK_H
75
+ #include <clapack.h>
76
+ #endif
74
77
  }
75
78
 
76
79
  #include <algorithm> // std::min, std::max
@@ -85,6 +88,7 @@ extern "C" { // These need to be in an extern "C" block or you'll get all kinds
85
88
  /*
86
89
  * Macros
87
90
  */
91
+ #define REAL_RECURSE_LIMIT 4
88
92
 
89
93
  /*
90
94
  * Data
@@ -95,7 +99,7 @@ extern "C" {
95
99
  /*
96
100
  * C accessors.
97
101
  */
98
- void nm_math_det_exact(const int M, const void* elements, const int lda, dtype_t dtype, void* result);
102
+ void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result);
99
103
  void nm_math_transpose_generic(const size_t M, const size_t N, const void* A, const int lda, void* B, const int ldb, size_t element_size);
100
104
  void nm_math_init_blas(void);
101
105
  }
@@ -143,6 +147,9 @@ template <> inline double numeric_inverse<double>(const double& n) { return 1 /
143
147
  *
144
148
  * For row major, call trsm<DType> instead. That will handle necessary changes-of-variables
145
149
  * and parameter checks.
150
+ *
151
+ * Note that some of the boundary conditions here may be incorrect. Very little has been tested!
152
+ * This was converted directly from dtrsm.f using f2c, and then rewritten more cleanly.
146
153
  */
147
154
  template <typename DType>
148
155
  inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
@@ -150,6 +157,9 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
150
157
  const int m, const int n, const DType alpha, const DType* a,
151
158
  const int lda, DType* b, const int ldb)
152
159
  {
160
+
161
+ // (row-major) trsm: left upper trans nonunit m=3 n=1 1/1 a 3 b 3
162
+
153
163
  if (m == 0 || n == 0) return; /* Quick return if possible. */
154
164
 
155
165
  if (alpha == 0) { // Handle alpha == 0
@@ -210,7 +220,7 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
210
220
  for (int j = 0; j < n; ++j) {
211
221
  for (int i = 0; i < m; ++i) {
212
222
  DType temp = alpha * b[i + j * ldb];
213
- for (int k = 0; k < i-1; ++k) {
223
+ for (int k = 0; k < i; ++k) { // limit was i-1. Lots of similar bugs in this code, probably.
214
224
  temp -= a[k + i * lda] * b[k + j * ldb];
215
225
  }
216
226
  if (diag == CblasNonUnit) {
@@ -339,6 +349,92 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
339
349
  }
340
350
 
341
351
 
352
+ template <typename DType>
353
+ inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
354
+ const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
355
+ rb_raise(rb_eNotImpError, "syrk not yet implemented for non-BLAS dtypes");
356
+ }
357
+
358
+ template <typename DType>
359
+ inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
360
+ const int K, const DType* alpha, const DType* A, const int lda, const DType* beta, DType* C, const int ldc) {
361
+ rb_raise(rb_eNotImpError, "herk not yet implemented for non-BLAS dtypes");
362
+ }
363
+
364
+ template <>
365
+ inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
366
+ const int K, const float* alpha, const float* A, const int lda, const float* beta, float* C, const int ldc) {
367
+ cblas_ssyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
368
+ }
369
+
370
+ template <>
371
+ inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
372
+ const int K, const double* alpha, const double* A, const int lda, const double* beta, double* C, const int ldc) {
373
+ cblas_dsyrk(Order, Uplo, Trans, N, K, *alpha, A, lda, *beta, C, ldc);
374
+ }
375
+
376
+ template <>
377
+ inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
378
+ const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
379
+ cblas_csyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
380
+ }
381
+
382
+ template <>
383
+ inline void syrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
384
+ const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
385
+ cblas_zsyrk(Order, Uplo, Trans, N, K, alpha, A, lda, beta, C, ldc);
386
+ }
387
+
388
+
389
+ template <>
390
+ inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
391
+ const int K, const Complex64* alpha, const Complex64* A, const int lda, const Complex64* beta, Complex64* C, const int ldc) {
392
+ cblas_cherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
393
+ }
394
+
395
+ template <>
396
+ inline void herk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE Trans, const int N,
397
+ const int K, const Complex128* alpha, const Complex128* A, const int lda, const Complex128* beta, Complex128* C, const int ldc) {
398
+ cblas_zherk(Order, Uplo, Trans, N, K, alpha->r, A, lda, beta->r, C, ldc);
399
+ }
400
+
401
+
402
+ template <typename DType>
403
+ inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
404
+ const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const DType* alpha,
405
+ const DType* A, const int lda, DType* B, const int ldb) {
406
+ rb_raise(rb_eNotImpError, "trmm not yet implemented for non-BLAS dtypes");
407
+ }
408
+
409
+ template <>
410
+ inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
411
+ const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const float* alpha,
412
+ const float* A, const int lda, float* B, const int ldb) {
413
+ cblas_strmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
414
+ }
415
+
416
+ template <>
417
+ inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
418
+ const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const double* alpha,
419
+ const double* A, const int lda, double* B, const int ldb) {
420
+ cblas_dtrmm(order, side, uplo, ta, diag, m, n, *alpha, A, lda, B, ldb);
421
+ }
422
+
423
+ template <>
424
+ inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
425
+ const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex64* alpha,
426
+ const Complex64* A, const int lda, Complex64* B, const int ldb) {
427
+ cblas_ctrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
428
+ }
429
+
430
+ template <>
431
+ inline void trmm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
432
+ const enum CBLAS_TRANSPOSE ta, const enum CBLAS_DIAG diag, const int m, const int n, const Complex128* alpha,
433
+ const Complex128* A, const int lda, Complex128* B, const int ldb) {
434
+ cblas_ztrmm(order, side, uplo, ta, diag, m, n, alpha, A, lda, B, ldb);
435
+ }
436
+
437
+
342
438
  /*
343
439
  * BLAS' DTRSM function, generalized.
344
440
  */
@@ -349,6 +445,9 @@ inline void trsm(const enum CBLAS_ORDER order,
349
445
  const int m, const int n, const DType alpha, const DType* a,
350
446
  const int lda, DType* b, const int ldb)
351
447
  {
448
+ /*using std::cerr;
449
+ using std::endl;*/
450
+
352
451
  int num_rows_a = n;
353
452
  if (side == CblasLeft) num_rows_a = m;
354
453
 
@@ -368,6 +467,13 @@ inline void trsm(const enum CBLAS_ORDER order,
368
467
  enum CBLAS_SIDE side_ = side == CblasLeft ? CblasRight : CblasLeft;
369
468
  enum CBLAS_UPLO uplo_ = uplo == CblasUpper ? CblasLower : CblasUpper;
370
469
 
470
+ /*
471
+ cerr << "(row-major) trsm: " << (side_ == CblasLeft ? "left " : "right ")
472
+ << (uplo_ == CblasUpper ? "upper " : "lower ")
473
+ << (trans_a == CblasTrans ? "trans " : "notrans ")
474
+ << (diag == CblasNonUnit ? "nonunit " : "unit ")
475
+ << n << " " << m << " " << alpha << " a " << lda << " b " << ldb << endl;
476
+ */
371
477
  trsm_nothrow<DType>(side_, uplo_, trans_a, diag, n, m, alpha, a, lda, b, ldb);
372
478
 
373
479
  } else { // CblasColMajor
@@ -376,7 +482,13 @@ inline void trsm(const enum CBLAS_ORDER order,
376
482
  fprintf(stderr, "TRSM: M=%d; got ldb=%d\n", m, ldb);
377
483
  rb_raise(rb_eArgError, "TRSM: Expected ldb >= max(1,M)");
378
484
  }
379
-
485
+ /*
486
+ cerr << "(col-major) trsm: " << (side == CblasLeft ? "left " : "right ")
487
+ << (uplo == CblasUpper ? "upper " : "lower ")
488
+ << (trans_a == CblasTrans ? "trans " : "notrans ")
489
+ << (diag == CblasNonUnit ? "nonunit " : "unit ")
490
+ << m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
491
+ */
380
492
  trsm_nothrow<DType>(side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
381
493
 
382
494
  }
@@ -390,7 +502,7 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
390
502
  const int m, const int n, const float alpha, const float* a,
391
503
  const int lda, float* b, const int ldb)
392
504
  {
393
- cblas_strsm(CblasRowMajor, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
505
+ cblas_strsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
394
506
  }
395
507
 
396
508
  template <>
@@ -399,7 +511,15 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
399
511
  const int m, const int n, const double alpha, const double* a,
400
512
  const int lda, double* b, const int ldb)
401
513
  {
402
- cblas_dtrsm(CblasRowMajor, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
514
+ /* using std::cerr;
515
+ using std::endl;
516
+ cerr << "(row-major) dtrsm: " << (side == CblasLeft ? "left " : "right ")
517
+ << (uplo == CblasUpper ? "upper " : "lower ")
518
+ << (trans_a == CblasTrans ? "trans " : "notrans ")
519
+ << (diag == CblasNonUnit ? "nonunit " : "unit ")
520
+ << m << " " << n << " " << alpha << " a " << lda << " b " << ldb << endl;
521
+ */
522
+ cblas_dtrsm(order, side, uplo, trans_a, diag, m, n, alpha, a, lda, b, ldb);
403
523
  }
404
524
 
405
525
 
@@ -409,7 +529,7 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
409
529
  const int m, const int n, const Complex64 alpha, const Complex64* a,
410
530
  const int lda, Complex64* b, const int ldb)
411
531
  {
412
- cblas_ctrsm(CblasRowMajor, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
532
+ cblas_ctrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
413
533
  }
414
534
 
415
535
  template <>
@@ -418,7 +538,7 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
418
538
  const int m, const int n, const Complex128 alpha, const Complex128* a,
419
539
  const int lda, Complex128* b, const int ldb)
420
540
  {
421
- cblas_ztrsm(CblasRowMajor, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
541
+ cblas_ztrsm(order, side, uplo, trans_a, diag, m, n, (const void*)(&alpha), (const void*)(a), lda, (void*)(b), ldb);
422
542
  }
423
543
 
424
544
 
@@ -429,7 +549,7 @@ inline void trsm(const enum CBLAS_ORDER order, const enum CBLAS_SIDE side, const
429
549
  */
430
550
  template <typename DType>
431
551
  inline void laswp(const int N, DType* A, const int lda, const int K1, const int K2, const int *piv, const int inci) {
432
- const int n = K2 - K1;
552
+ //const int n = K2 - K1; // not sure why this is declared. commented it out because it's unused.
433
553
 
434
554
  int nb = N >> 5;
435
555
 
@@ -1261,6 +1381,87 @@ inline int getrf_nothrow(const int M, const int N, DType* A, const int lda, int*
1261
1381
  return(ierr);
1262
1382
  }
1263
1383
 
1384
+ /*
1385
+ * Solves a system of linear equations A*X = B with a general NxN matrix A using the LU factorization computed by GETRF.
1386
+ *
1387
+ * From ATLAS 3.8.0.
1388
+ */
1389
+ template <typename DType>
1390
+ int getrs(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE Trans, const int N, const int NRHS, const DType* A,
1391
+ const int lda, const int* ipiv, DType* B, const int ldb)
1392
+ {
1393
+ // enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src.
1394
+
1395
+ if (!N || !NRHS) return 0;
1396
+
1397
+ const DType ONE = 1;
1398
+
1399
+ if (Order == CblasColMajor) {
1400
+ if (Trans == CblasNoTrans) {
1401
+ nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
1402
+ nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
1403
+ nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
1404
+ } else {
1405
+ nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, Trans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
1406
+ nm::math::trsm<DType>(Order, CblasLeft, CblasLower, Trans, CblasUnit, N, NRHS, ONE, A, lda, B, ldb);
1407
+ nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
1408
+ }
1409
+ } else {
1410
+ if (Trans == CblasNoTrans) {
1411
+ nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
1412
+ nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
1413
+ nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, -1);
1414
+ } else {
1415
+ nm::math::laswp<DType>(NRHS, B, ldb, 0, N, ipiv, 1);
1416
+ nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasUnit, NRHS, N, ONE, A, lda, B, ldb);
1417
+ nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
1418
+ }
1419
+ }
1420
+ return 0;
1421
+ }
1422
+
1423
+
1424
+ /*
1425
+ * Solves a system of linear equations A*X = B with a symmetric positive definite matrix A using the Cholesky factorization computed by POTRF.
1426
+ *
1427
+ * From ATLAS 3.8.0.
1428
+ */
1429
+ template <typename DType, bool is_complex>
1430
+ int potrs(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, const int NRHS, const DType* A,
1431
+ const int lda, DType* B, const int ldb)
1432
+ {
1433
+ // enum CBLAS_DIAG Lunit, Uunit; // These aren't used. Not sure why they're declared in ATLAS' src.
1434
+
1435
+ CBLAS_TRANSPOSE MyTrans = is_complex ? CblasConjTrans : CblasTrans;
1436
+
1437
+ if (!N || !NRHS) return 0;
1438
+
1439
+ const DType ONE = 1;
1440
+
1441
+ if (Order == CblasColMajor) {
1442
+ if (Uplo == CblasUpper) {
1443
+ nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, MyTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
1444
+ nm::math::trsm<DType>(Order, CblasLeft, CblasUpper, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
1445
+ } else {
1446
+ nm::math::trsm<DType>(Order, CblasLeft, CblasLower, CblasNoTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
1447
+ nm::math::trsm<DType>(Order, CblasLeft, CblasLower, MyTrans, CblasNonUnit, N, NRHS, ONE, A, lda, B, ldb);
1448
+ }
1449
+ } else {
1450
+ // There's some kind of scaling operation that normally happens here in ATLAS. Not sure what it does, so we'll only
1451
+ // worry if something breaks. It probably has to do with their non-templated code and doesn't apply to us.
1452
+
1453
+ if (Uplo == CblasUpper) {
1454
+ nm::math::trsm<DType>(Order, CblasRight, CblasUpper, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
1455
+ nm::math::trsm<DType>(Order, CblasRight, CblasUpper, MyTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
1456
+ } else {
1457
+ nm::math::trsm<DType>(Order, CblasRight, CblasLower, MyTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
1458
+ nm::math::trsm<DType>(Order, CblasRight, CblasLower, CblasNoTrans, CblasNonUnit, NRHS, N, ONE, A, lda, B, ldb);
1459
+ }
1460
+ }
1461
+ return 0;
1462
+ }
1463
+
1464
+
1264
1465
 
1265
1466
  /*
1266
1467
  * From ATLAS 3.8.0:
@@ -1317,6 +1518,617 @@ inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, DType*
1317
1518
  }
1318
1519
 
1319
1520
 
1521
+ /*
1522
+ * From ATLAS 3.8.0:
1523
+ *
1524
+ * Computes one of two LU factorizations based on the setting of the Order
1525
+ * parameter, as follows:
1526
+ * ----------------------------------------------------------------------------
1527
+ * Order == CblasColMajor
1528
+ * Column-major factorization of form
1529
+ * A = P * L * U
1530
+ * where P is a row-permutation matrix, L is lower triangular with unit
1531
+ * diagonal elements (lower trapazoidal if M > N), and U is upper triangular
1532
+ * (upper trapazoidal if M < N).
1533
+ *
1534
+ * ----------------------------------------------------------------------------
1535
+ * Order == CblasRowMajor
1536
+ * Row-major factorization of form
1537
+ * A = P * L * U
1538
+ * where P is a column-permutation matrix, L is lower triangular (lower
1539
+ * trapazoidal if M > N), and U is upper triangular with unit diagonals (upper
1540
+ * trapazoidal if M < N).
1541
+ *
1542
+ * ============================================================================
1543
+ * Let IERR be the return value of the function:
1544
+ * If IERR == 0, successful exit.
1545
+ * If (IERR < 0) the -IERR argument had an illegal value
1546
+ * If (IERR > 0 && Order == CblasColMajor)
1547
+ * U(i-1,i-1) is exactly zero. The factorization has been completed,
1548
+ * but the factor U is exactly singular, and division by zero will
1549
+ * occur if it is used to solve a system of equations.
1550
+ * If (IERR > 0 && Order == CblasRowMajor)
1551
+ * L(i-1,i-1) is exactly zero. The factorization has been completed,
1552
+ * but the factor L is exactly singular, and division by zero will
1553
+ * occur if it is used to solve a system of equations.
1554
+ */
1555
+ template <typename DType>
1556
+ inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) {
1557
+ #ifdef HAVE_CLAPACK_H
1558
+ rb_raise(rb_eNotImpError, "not yet implemented for non-BLAS dtypes");
1559
+ #else
1560
+ rb_raise(rb_eNotImpError, "only LAPACK version implemented thus far");
1561
+ #endif
1562
+ return 0;
1563
+ }
1564
+
1565
+ #ifdef HAVE_CLAPACK_H
1566
+ template <>
1567
+ inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) {
1568
+ return clapack_spotrf(order, uplo, N, A, lda);
1569
+ }
1570
+
1571
+ template <>
1572
+ inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) {
1573
+ return clapack_dpotrf(order, uplo, N, A, lda);
1574
+ }
1575
+
1576
+ template <>
1577
+ inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) {
1578
+ return clapack_cpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda);
1579
+ }
1580
+
1581
+ template <>
1582
+ inline int potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) {
1583
+ return clapack_zpotrf(order, uplo, N, reinterpret_cast<void*>(A), lda);
1584
+ }
1585
+ #endif
1586
+
1587
+
1588
/*
 * Interchanges N elements of the strided vectors X and Y (reference BLAS ?swap algorithm).
 *
 * Element k of X lives at X[ix] and element k of Y at Y[iy], where ix and iy advance by incX and
 * incY per element. As in reference BLAS, a negative increment means the vector is traversed
 * from its far end, so the first element touched is at offset (1 - N) * inc.
 * Does nothing when N <= 0.
 *
 * This is the straightforward reference-BLAS algorithm. ATLAS has an optimized version that has
 * not been translated.
 */
template <typename DType>
static void swap(const int N, DType* X, const int incX, DType* Y, const int incY) {
  if (N <= 0) return;

  // Negative strides start at the far end of the vector (reference BLAS convention).
  int ix = incX < 0 ? (1 - N) * incX : 0;
  int iy = incY < 0 ? (1 - N) * incY : 0;

  for (int i = 0; i < N; ++i, ix += incX, iy += incY) {
    DType temp = X[ix];
    X[ix] = Y[iy];
    Y[iy] = temp;
  }
}
1604
+
1605
+
1606
/*
 * trcpzeroU: moves the strictly-upper entries of the leading M rows of the row-major array U
 * into C. Each copied entry of U is set to zero.
 *
 * Row r copies columns r+1 .. N-1. The diagonal is skipped because U is assumed to have a unit
 * diagonal. ldu and ldc are the row strides of U and C.
 *
 * Ported from ATLAS 3.8.0 (ATL_trcpzeroU).
 */
template <typename DType>
static inline void trcpzeroU(const int M, const int N, DType* U, const int ldu, DType* C, const int ldc) {
  for (int row = 0; row != M; ++row) {
    DType* u_row = U + row * ldu;
    DType* c_row = C + row * ldc;

    for (int col = row + 1; col < N; ++col) {
      c_row[col] = u_row[col];
      u_row[col] = 0;
    }
  }
}
1622
+
1623
+
1624
+ /*
1625
+ * Un-comment the following lines when we figure out how to calculate NB for each of the ATLAS-derived
1626
+ * functions. This is probably really complicated.
1627
+ *
1628
+ * Also needed: ATL_MulByNB, ATL_DivByNB (both defined in the build process for ATLAS), and ATL_mmMU.
1629
+ *
1630
+ */
1631
+
1632
+ /*
1633
+
1634
+ template <bool RowMajor, bool Upper, typename DType>
1635
+ static int trtri_4(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
1636
+
1637
+ if (RowMajor) {
1638
+ DType *pA0 = A, *pA1 = A+lda, *pA2 = A+2*lda, *pA3 = A+3*lda;
1639
+ DType tmp;
1640
+ if (Upper) {
1641
+ DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
1642
+ A12 = pA1[2], A13 = pA1[3],
1643
+ A23 = pA2[3];
1644
+
1645
+ if (Diag == CblasNonUnit) {
1646
+ pA0->inverse();
1647
+ (pA1+1)->inverse();
1648
+ (pA2+2)->inverse();
1649
+ (pA3+3)->inverse();
1650
+
1651
+ pA0[1] = -A01 * pA1[1] * pA0[0];
1652
+ pA1[2] = -A12 * pA2[2] * pA1[1];
1653
+ pA2[3] = -A23 * pA3[3] * pA2[2];
1654
+
1655
+ pA0[2] = -(A01 * pA1[2] + A02 * pA2[2]) * pA0[0];
1656
+ pA1[3] = -(A12 * pA2[3] + A13 * pA3[3]) * pA1[1];
1657
+
1658
+ pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03 * pA3[3]) * pA0[0];
1659
+
1660
+ } else {
1661
+
1662
+ pA0[1] = -A01;
1663
+ pA1[2] = -A12;
1664
+ pA2[3] = -A23;
1665
+
1666
+ pA0[2] = -(A01 * pA1[2] + A02);
1667
+ pA1[3] = -(A12 * pA2[3] + A13);
1668
+
1669
+ pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03);
1670
+ }
1671
+
1672
+ } else { // Lower
1673
+ DType A10 = pA1[0],
1674
+ A20 = pA2[0], A21 = pA2[1],
1675
+ A30 = PA3[0], A31 = pA3[1], A32 = pA3[2];
1676
+ DType *B10 = pA1,
1677
+ *B20 = pA2,
1678
+ *B30 = pA3,
1679
+ *B21 = pA2+1,
1680
+ *B31 = pA3+1,
1681
+ *B32 = pA3+2;
1682
+
1683
+
1684
+ if (Diag == CblasNonUnit) {
1685
+ pA0->inverse();
1686
+ (pA1+1)->inverse();
1687
+ (pA2+2)->inverse();
1688
+ (pA3+3)->inverse();
1689
+
1690
+ *B10 = -A10 * pA0[0] * pA1[1];
1691
+ *B21 = -A21 * pA1[1] * pA2[2];
1692
+ *B32 = -A32 * pA2[2] * pA3[3];
1693
+ *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
1694
+ *B31 = -(A31 * pA1[1] + A32 * (*B21)) * pA3[3];
1695
+ *B30 = -(A30 * pA0[0] + A31 * (*B10) + A32 * (*B20)) * pA3;
1696
+ } else {
1697
+ *B10 = -A10;
1698
+ *B21 = -A21;
1699
+ *B32 = -A32;
1700
+ *B20 = -(A20 + A21 * (*B10));
1701
+ *B31 = -(A31 + A32 * (*B21));
1702
+ *B30 = -(A30 + A31 * (*B10) + A32 * (*B20));
1703
+ }
1704
+ }
1705
+
1706
+ } else {
1707
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
1708
+ }
1709
+
1710
+ return 0;
1711
+
1712
+ }
1713
+
1714
+
1715
+ template <bool RowMajor, bool Upper, typename DType>
1716
+ static int trtri_3(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
1717
+
1718
+ if (RowMajor) {
1719
+
1720
+ DType tmp;
1721
+
1722
+ if (Upper) {
1723
+ DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
1724
+ A12 = pA1[2], A13 = pA1[3];
1725
+
1726
+ DType *B01 = pA0 + 1,
1727
+ *B02 = pA0 + 2,
1728
+ *B12 = pA1 + 2;
1729
+
1730
+ if (Diag == CblasNonUnit) {
1731
+ pA0->inverse();
1732
+ (pA1+1)->inverse();
1733
+ (pA2+2)->inverse();
1734
+
1735
+ *B01 = -A01 * pA1[1] * pA0[0];
1736
+ *B12 = -A12 * pA2[2] * pA1[1];
1737
+ *B02 = -(A01 * (*B12) + A02 * pA2[2]) * pA0[0];
1738
+ } else {
1739
+ *B01 = -A01;
1740
+ *B12 = -A12;
1741
+ *B02 = -(A01 * (*B12) + A02);
1742
+ }
1743
+
1744
+ } else { // Lower
1745
+ DType *pA0=A, *pA1=A+lda, *pA2=A+2*lda;
1746
+ DType A10=pA1[0],
1747
+ A20=pA2[0], A21=pA2[1];
1748
+
1749
+ DType *B10 = pA1,
1750
+ *B20 = pA2;
1751
+ *B21 = pA2+1;
1752
+
1753
+ if (Diag == CblasNonUnit) {
1754
+ pA0->inverse();
1755
+ (pA1+1)->inverse();
1756
+ (pA2+2)->inverse();
1757
+ *B10 = -A10 * pA0[0] * pA1[1];
1758
+ *B21 = -A21 * pA1[1] * pA2[2];
1759
+ *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
1760
+ } else {
1761
+ *B10 = -A10;
1762
+ *B21 = -A21;
1763
+ *B20 = -(A20 + A21 * (*B10));
1764
+ }
1765
+ }
1766
+
1767
+
1768
+ } else {
1769
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
1770
+ }
1771
+
1772
+ return 0;
1773
+
1774
+ }
1775
+
1776
+ template <bool RowMajor, bool Upper, bool Real, typename DType>
1777
+ static void trtri(const enum CBLAS_DIAG Diag, const int N, DType* A, const int lda) {
1778
+ DType *Age, *Atr;
1779
+ DType tmp;
1780
+ int Nleft, Nright;
1781
+
1782
+ int ierr = 0;
1783
+
1784
+ static const DType ONE = 1;
1785
+ static const DType MONE -1;
1786
+ static const DType NONE = -1;
1787
+
1788
+ if (RowMajor) {
1789
+
1790
+ // FIXME: Use REAL_RECURSE_LIMIT here for float32 and float64 (instead of 1)
1791
+ if ((Real && N > REAL_RECURSE_LIMIT) || (N > 1)) {
1792
+ Nleft = N >> 1;
1793
+ #ifdef NB
1794
+ if (Nleft > NB) NLeft = ATL_MulByNB(ATL_DivByNB(Nleft));
1795
+ #endif
1796
+
1797
+ Nright = N - Nleft;
1798
+
1799
+ if (Upper) {
1800
+ Age = A + Nleft;
1801
+ Atr = A + (Nleft * (lda+1));
1802
+
1803
+ nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, Diag,
1804
+ Nleft, Nright, ONE, Atr, lda, Age, lda);
1805
+
1806
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, Diag,
1807
+ Nleft, Nright, MONE, A, lda, Age, lda);
1808
+
1809
+ } else { // Lower
1810
+ Age = A + ((Nleft*lda));
1811
+ Atr = A + (Nleft * (lda+1));
1812
+
1813
+ nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasLower, CblasNoTrans, Diag,
1814
+ Nright, Nleft, ONE, A, lda, Age, lda);
1815
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasLower, CblasNoTrans, Diag,
1816
+ Nright, Nleft, MONE, Atr, lda, Age, lda);
1817
+ }
1818
+
1819
+ ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nleft, A, lda);
1820
+ if (ierr) return ierr;
1821
+
1822
+ ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nright, Atr, lda);
1823
+ if (ierr) return ierr + Nleft;
1824
+
1825
+ } else {
1826
+ if (Real) {
1827
+ if (N == 4) {
1828
+ return trtri_4<RowMajor,Upper,Real,DType>(Diag, A, lda);
1829
+ } else if (N == 3) {
1830
+ return trtri_3<RowMajor,Upper,Real,DType>(Diag, A, lda);
1831
+ } else if (N == 2) {
1832
+ if (Diag == CblasNonUnit) {
1833
+ A->inverse();
1834
+ (A+(lda+1))->inverse();
1835
+
1836
+ if (Upper) {
1837
+ *(A+1) *= *A; // TRI_MUL
1838
+ *(A+1) *= *(A+lda+1); // TRI_MUL
1839
+ } else {
1840
+ *(A+lda) *= *A; // TRI_MUL
1841
+ *(A+lda) *= *(A+lda+1); // TRI_MUL
1842
+ }
1843
+ }
1844
+
1845
+ if (Upper) *(A+1) = -*(A+1); // TRI_NEG
1846
+ else *(A+lda) = -*(A+lda); // TRI_NEG
1847
+ } else if (Diag == CblasNonUnit) A->inverse();
1848
+ } else { // not real
1849
+ if (Diag == CblasNonUnit) A->inverse();
1850
+ }
1851
+ }
1852
+
1853
+ } else {
1854
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
1855
+ }
1856
+
1857
+ return ierr;
1858
+ }
1859
+
1860
+
1861
+ template <bool RowMajor, bool Real, typename DType>
1862
+ int getri(const int N, DType* A, const int lda, const int* ipiv, DType* wrk, const int lwrk) {
1863
+
1864
+ if (!RowMajor) rb_raise(rb_eNotImpError, "only row-major implemented at this time");
1865
+
1866
+ int jb, nb, I, ndown, iret;
1867
+
1868
+ const DType ONE = 1, NONE = -1;
1869
+
1870
+ int iret = trtri<RowMajor,false,Real,DType>(CblasNonUnit, N, A, lda);
1871
+ if (!iret && N > 1) {
1872
+ jb = lwrk / N;
1873
+ if (jb >= NB) nb = ATL_MulByNB(ATL_DivByNB(jb));
1874
+ else if (jb >= ATL_mmMU) nb = (jb/ATL_mmMU)*ATL_mmMU;
1875
+ else nb = jb;
1876
+ if (!nb) return -6; // need at least 1 row of workspace
1877
+
1878
+ // only first iteration will have partial block, unroll it
1879
+
1880
+ jb = N - (N/nb) * nb;
1881
+ if (!jb) jb = nb;
1882
+ I = N - jb;
1883
+ A += lda * I;
1884
+ trcpzeroU<DType>(jb, jb, A+I, lda, wrk, jb);
1885
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
1886
+ jb, N, ONE, wrk, jb, A, lda);
1887
+
1888
+ if (I) {
1889
+ do {
1890
+ I -= nb;
1891
+ A -= nb * lda;
1892
+ ndown = N-I;
1893
+ trcpzeroU<DType>(nb, ndown, A+I, lda, wrk, ndown);
1894
+ nm::math::gemm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
1895
+ nb, N, ONE, wrk, ndown, A, lda);
1896
+ } while (I);
1897
+ }
1898
+
1899
+ // Apply row interchanges
1900
+
1901
+ for (I = N - 2; I >= 0; --I) {
1902
+ jb = ipiv[I];
1903
+ if (jb != I) nm::math::swap<DType>(N, A+I*lda, 1, A+jb*lda, 1);
1904
+ }
1905
+ }
1906
+
1907
+ return iret;
1908
+ }
1909
+ */
1910
+
1911
+
1912
/*
 * rot_helper: applies a plane rotation to the N pairs (x, y) taken from X and Y:
 *   x <- x*c + y*s
 *   y <- y*c - x*s
 * When c == 1 and s == 0 (the identity rotation) nothing is touched.
 *
 * Contiguous data (both increments 1) is indexed directly. Otherwise the pointers are advanced
 * by incX/incY after each pair. Negative increments are not repositioned here; presumably the
 * caller has already adjusted X and Y for them -- confirm at the call site.
 *
 * This corresponds to ATL_rot in ATLAS 3.8.4.
 * TODO: verify on complex dtypes. ATLAS has a separate algorithm for complex, which may turn
 * out to be equivalent to this one.
 */
template <typename DType>
inline void rot_helper(const int N, DType* X, const int incX, DType* Y, const int incY, const DType c, const DType s) {
  // Identity rotation: nothing to do.
  if (!(c != 1 || s != 0)) return;

  if (incX == 1 && incY == 1) {
    for (int k = 0; k != N; ++k) {
      DType rotated_x = X[k] * c + Y[k] * s;
      Y[k] = Y[k] * c - X[k] * s;
      X[k] = rotated_x;
    }
    return;
  }

  DType* px = X;
  DType* py = Y;
  for (int remaining = N; remaining > 0; --remaining) {
    DType rotated_x = *px * c + *py * s;
    *py = *py * c - *px * s;
    *px = rotated_x;
    px += incX;
    py += incY;
  }
}
1934
+
1935
+
1936
/*
 * rotg: constructs a Givens plane rotation (generic version, for dtypes without a CBLAS ?rotg).
 * Follows the reference BLAS ?rotg algorithm (ported via ATLAS 3.8.4).
 *
 * On entry *a and *b hold the pair (a, b). On return:
 *   *c = a / r, *s = b / r
 *   *a = r, where |r| = sqrt(a^2 + b^2) (computed with scaling to avoid overflow), and r takes
 *        the sign of whichever of a, b has the larger magnitude (b on ties)
 *   *b = z, the reference-BLAS compact encoding of the rotation:
 *          z = s    if |a| >  |b|
 *          z = 1/c  if |a| <= |b| and c != 0
 *          z = 1    otherwise (c == 0)
 * If a == b == 0, the result is c = 1 and s = r = z = 0.
 *
 * FIXME: Rationals need a specialized algorithm; BLAS' approach relies on sqrt, which will not
 * give exact results for most values.
 */
template <typename DType>
inline void rotg(DType* a, DType* b, DType* c, DType* s) {
  DType aa = std::abs(*a), ab = std::abs(*b);
  DType roe = aa > ab ? *a : *b;
  DType scal = aa + ab;

  if (scal == 0) {
    *c = 1;
    *s = *a = *b = 0;
    return;
  }

  DType t0 = aa / scal, t1 = ab / scal;
  DType r = scal * std::sqrt(t0 * t0 + t1 * t1);
  if (roe < 0) r = -r;

  *c = *a / r;
  *s = *b / r;

  // Reference-BLAS encoding of (c, s) in z. It must be s when |a| > |b|, so that the rotation
  // can be reconstructed from z alone.
  DType z = DType(1);
  if (aa > ab)      z = *s;
  else if (*c != 0) z = 1 / *c;

  *a = r;
  *b = z;
}
1959
+
1960
+ template <>
1961
+ inline void rotg(float* a, float* b, float* c, float* s) {
1962
+ cblas_srotg(a, b, c, s);
1963
+ }
1964
+
1965
+ template <>
1966
+ inline void rotg(double* a, double* b, double* c, double* s) {
1967
+ cblas_drotg(a, b, c, s);
1968
+ }
1969
+
1970
+ template <>
1971
+ inline void rotg(Complex64* a, Complex64* b, Complex64* c, Complex64* s) {
1972
+ cblas_crotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s));
1973
+ }
1974
+
1975
+ template <>
1976
+ inline void rotg(Complex128* a, Complex128* b, Complex128* c, Complex128* s) {
1977
+ cblas_zrotg(reinterpret_cast<void*>(a), reinterpret_cast<void*>(b), reinterpret_cast<void*>(c), reinterpret_cast<void*>(s));
1978
+ }
1979
+
1980
+ template <typename DType>
1981
+ inline void cblas_rotg(void* a, void* b, void* c, void* s) {
1982
+ rotg<DType>(reinterpret_cast<DType*>(a), reinterpret_cast<DType*>(b), reinterpret_cast<DType*>(c), reinterpret_cast<DType*>(s));
1983
+ }
1984
+
1985
+
1986
+ /* Applies a plane rotation. From ATLAS 3.8.4. */
1987
+ template <typename DType, typename CSDType>
1988
+ inline void rot(const int N, DType* X, const int incX, DType* Y, const int incY, const CSDType c, const CSDType s) {
1989
+ DType *x = X, *y = Y;
1990
+ int incx = incX, incy = incY;
1991
+
1992
+ if (N > 0) {
1993
+ if (incX < 0) {
1994
+ if (incY < 0) { incx = -incx; incy = -incy; }
1995
+ else x += -incX * (N-1);
1996
+ } else if (incY < 0) {
1997
+ incy = -incy;
1998
+ incx = -incx;
1999
+ x += (N-1) * incX;
2000
+ }
2001
+ rot_helper<DType>(N, x, incx, y, incy, c, s);
2002
+ }
2003
+ }
2004
+
2005
+ template <>
2006
+ inline void rot(const int N, float* X, const int incX, float* Y, const int incY, const float c, const float s) {
2007
+ cblas_srot(N, X, incX, Y, incY, (float)c, (float)s);
2008
+ }
2009
+
2010
+ template <>
2011
+ inline void rot(const int N, double* X, const int incX, double* Y, const int incY, const double c, const double s) {
2012
+ cblas_drot(N, X, incX, Y, incY, c, s);
2013
+ }
2014
+
2015
+ template <>
2016
+ inline void rot(const int N, Complex64* X, const int incX, Complex64* Y, const int incY, const float c, const float s) {
2017
+ cblas_csrot(N, X, incX, Y, incY, c, s);
2018
+ }
2019
+
2020
+ template <>
2021
+ inline void rot(const int N, Complex128* X, const int incX, Complex128* Y, const int incY, const double c, const double s) {
2022
+ cblas_zdrot(N, X, incX, Y, incY, c, s);
2023
+ }
2024
+
2025
+
2026
+ template <typename DType, typename CSDType>
2027
+ inline void cblas_rot(const int N, void* X, const int incX, void* Y, const int incY, const void* c, const void* s) {
2028
+ rot<DType,CSDType>(N, reinterpret_cast<DType*>(X), incX, reinterpret_cast<DType*>(Y), incY, *reinterpret_cast<const CSDType*>(c), *reinterpret_cast<const CSDType*>(s));
2029
+ }
2030
+
2031
+
2032
+ template <bool is_complex, typename DType>
2033
+ inline void lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, DType* A, const int lda) {
2034
+
2035
+ int Nleft, Nright;
2036
+ const DType ONE = 1;
2037
+ DType *G, *U0 = A, *U1;
2038
+
2039
+ if (N > 1) {
2040
+ Nleft = N >> 1;
2041
+ #ifdef NB
2042
+ if (Nleft > NB) Nleft = ATL_MulByNB(ATL_DivByNB(Nleft));
2043
+ #endif
2044
+
2045
+ Nright = N - Nleft;
2046
+
2047
+ // FIXME: There's a simpler way to write this next block, but I'm way too tired to work it out right now.
2048
+ if (uplo == CblasUpper) {
2049
+ if (order == CblasRowMajor) {
2050
+ G = A + Nleft;
2051
+ U1 = G + Nleft * lda;
2052
+ } else {
2053
+ G = A + Nleft * lda;
2054
+ U1 = G + Nleft;
2055
+ }
2056
+ } else {
2057
+ if (order == CblasRowMajor) {
2058
+ G = A + Nleft * lda;
2059
+ U1 = G + Nleft;
2060
+ } else {
2061
+ G = A + Nleft;
2062
+ U1 = G + Nleft * lda;
2063
+ }
2064
+ }
2065
+
2066
+ lauum<is_complex, DType>(order, uplo, Nleft, U0, lda);
2067
+
2068
+ if (is_complex) {
2069
+
2070
+ nm::math::herk<DType>(order, uplo,
2071
+ uplo == CblasLower ? CblasConjTrans : CblasNoTrans,
2072
+ Nleft, Nright, &ONE, G, lda, &ONE, U0, lda);
2073
+
2074
+ nm::math::trmm<DType>(order, CblasLeft, uplo, CblasConjTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda);
2075
+ } else {
2076
+ nm::math::syrk<DType>(order, uplo,
2077
+ uplo == CblasLower ? CblasTrans : CblasNoTrans,
2078
+ Nleft, Nright, &ONE, G, lda, &ONE, U0, lda);
2079
+
2080
+ nm::math::trmm<DType>(order, CblasLeft, uplo, CblasTrans, CblasNonUnit, Nright, Nleft, &ONE, U1, lda, G, lda);
2081
+ }
2082
+ lauum<is_complex, DType>(order, uplo, Nright, U1, lda);
2083
+
2084
+ } else {
2085
+ *A = *A * *A;
2086
+ }
2087
+ }
2088
+
2089
+
2090
#ifdef HAVE_CLAPACK_H
/*
 * CLAPACK-backed lauum.
 *
 * These are explicit specializations of lauum<is_complex, DType>, not separate overloads. Callers in this file invoke
 * lauum<is_complex, DType>(...) with both template arguments explicit. A single-parameter template such as
 * `template <bool is_complex> void lauum(..., float*, ...)` is never viable for such a call, so it would silently go
 * unused.
 *
 * Real dtypes are specialized for is_complex == false and complex dtypes for is_complex == true. Any other
 * combination still uses the generic implementation.
 */
template <>
inline void lauum<false, float>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, float* A, const int lda) {
  clapack_slauum(order, uplo, N, A, lda);
}

template <>
inline void lauum<false, double>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, double* A, const int lda) {
  clapack_dlauum(order, uplo, N, A, lda);
}

template <>
inline void lauum<true, Complex64>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex64* A, const int lda) {
  clapack_clauum(order, uplo, N, A, lda);
}

template <>
inline void lauum<true, Complex128>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int N, Complex128* A, const int lda) {
  clapack_zlauum(order, uplo, N, A, lda);
}
#endif
2111
+
2112
+
2113
+ /*
2114
+ * Function signature conversion for calling LAPACK's lauum functions as directly as possible.
2115
+ *
2116
+ * For documentation: http://www.netlib.org/lapack/double/dlauum.f
2117
+ *
2118
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2119
+ */
2120
+ template <bool is_complex, typename DType>
2121
+ inline int clapack_lauum(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
2122
+ if (n < 0) rb_raise(rb_eArgError, "n cannot be less than zero, is set to %d", n);
2123
+ if (lda < n || lda < 1) rb_raise(rb_eArgError, "lda must be >= max(n,1); lda=%d, n=%d\n", lda, n);
2124
+
2125
+ if (uplo == CblasUpper) lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
2126
+ else lauum<is_complex, DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
2127
+
2128
+ return 0;
2129
+ }
2130
+
2131
+
1320
2132
 
1321
2133
 
1322
2134
  /*
@@ -1357,6 +2169,143 @@ inline int clapack_getrf(const enum CBLAS_ORDER order, const int m, const int n,
1357
2169
  }
1358
2170
 
1359
2171
 
2172
+ /*
2173
+ * Function signature conversion for calling LAPACK's potrf functions as directly as possible.
2174
+ *
2175
+ * For documentation: http://www.netlib.org/lapack/double/dpotrf.f
2176
+ *
2177
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2178
+ */
2179
+ template <typename DType>
2180
+ inline int clapack_potrf(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
2181
+ return potrf<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
2182
+ }
2183
+
2184
+
2185
+ /*
2186
+ * Function signature conversion for calling LAPACK's getrs functions as directly as possible.
2187
+ *
2188
+ * For documentation: http://www.netlib.org/lapack/double/dgetrs.f
2189
+ *
2190
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2191
+ */
2192
+ template <typename DType>
2193
+ inline int clapack_getrs(const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans, const int n, const int nrhs,
2194
+ const void* a, const int lda, const int* ipiv, void* b, const int ldb) {
2195
+ return getrs<DType>(order, trans, n, nrhs, reinterpret_cast<const DType*>(a), lda, ipiv, reinterpret_cast<DType*>(b), ldb);
2196
+ }
2197
+
2198
+ /*
2199
+ * Function signature conversion for calling LAPACK's potrs functions as directly as possible.
2200
+ *
2201
+ * For documentation: http://www.netlib.org/lapack/double/dpotrs.f
2202
+ *
2203
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2204
+ */
2205
+ template <typename DType, bool is_complex>
2206
+ inline int clapack_potrs(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, const int nrhs,
2207
+ const void* a, const int lda, void* b, const int ldb) {
2208
+ return potrs<DType,is_complex>(order, uplo, n, nrhs, reinterpret_cast<const DType*>(a), lda, reinterpret_cast<DType*>(b), ldb);
2209
+ }
2210
+
2211
+ template <typename DType>
2212
+ inline int getri(const enum CBLAS_ORDER order, const int n, DType* a, const int lda, const int* ipiv) {
2213
+ rb_raise(rb_eNotImpError, "getri not yet implemented for non-BLAS dtypes");
2214
+ return 0;
2215
+ }
2216
+
2217
#ifdef HAVE_CLAPACK_H
// CLAPACK-backed getri specializations (single/double precision, real and complex). The status code from CLAPACK is
// returned unchanged; the complex variants hand the matrix over as void*.
template <>
inline int getri<float>(const enum CBLAS_ORDER order, const int n, float* a, const int lda, const int* ipiv) {
  return clapack_sgetri(order, n, a, lda, ipiv);
}

template <>
inline int getri<double>(const enum CBLAS_ORDER order, const int n, double* a, const int lda, const int* ipiv) {
  return clapack_dgetri(order, n, a, lda, ipiv);
}

template <>
inline int getri<Complex64>(const enum CBLAS_ORDER order, const int n, Complex64* a, const int lda, const int* ipiv) {
  return clapack_cgetri(order, n, static_cast<void*>(a), lda, ipiv);
}

template <>
inline int getri<Complex128>(const enum CBLAS_ORDER order, const int n, Complex128* a, const int lda, const int* ipiv) {
  return clapack_zgetri(order, n, static_cast<void*>(a), lda, ipiv);
}
#endif
2238
+
2239
+
2240
+ template <typename DType>
2241
+ inline int potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, DType* a, const int lda) {
2242
+ rb_raise(rb_eNotImpError, "potri not yet implemented for non-BLAS dtypes");
2243
+ return 0;
2244
+ }
2245
+
2246
+
2247
#ifdef HAVE_CLAPACK_H
// CLAPACK-backed potri specializations (single/double precision, real and complex). The status code from CLAPACK is
// returned unchanged; the complex variants hand the matrix over as void*.
template <>
inline int potri<float>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, float* a, const int lda) {
  return clapack_spotri(order, uplo, n, a, lda);
}

template <>
inline int potri<double>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, double* a, const int lda) {
  return clapack_dpotri(order, uplo, n, a, lda);
}

template <>
inline int potri<Complex64>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex64* a, const int lda) {
  return clapack_cpotri(order, uplo, n, static_cast<void*>(a), lda);
}

template <>
inline int potri<Complex128>(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, Complex128* a, const int lda) {
  return clapack_zpotri(order, uplo, n, static_cast<void*>(a), lda);
}
#endif
2268
+
2269
+ /*
2270
+ * Function signature conversion for calling LAPACK's getri functions as directly as possible.
2271
+ *
2272
+ * For documentation: http://www.netlib.org/lapack/double/dgetri.f
2273
+ *
2274
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2275
+ */
2276
+ template <typename DType>
2277
+ inline int clapack_getri(const enum CBLAS_ORDER order, const int n, void* a, const int lda, const int* ipiv) {
2278
+ return getri<DType>(order, n, reinterpret_cast<DType*>(a), lda, ipiv);
2279
+ }
2280
+
2281
+
2282
+ /*
2283
+ * Function signature conversion for calling LAPACK's potri functions as directly as possible.
2284
+ *
2285
+ * For documentation: http://www.netlib.org/lapack/double/dpotri.f
2286
+ *
2287
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2288
+ */
2289
+ template <typename DType>
2290
+ inline int clapack_potri(const enum CBLAS_ORDER order, const enum CBLAS_UPLO uplo, const int n, void* a, const int lda) {
2291
+ return potri<DType>(order, uplo, n, reinterpret_cast<DType*>(a), lda);
2292
+ }
2293
+
2294
+
2295
+ /*
2296
+ * Function signature conversion for calling LAPACK's laswp functions as directly as possible.
2297
+ *
2298
+ * For documentation: http://www.netlib.org/lapack/double/dlaswp.f
2299
+ *
2300
+ * This function should normally go in math.cpp, but we need it to be available to nmatrix.cpp.
2301
+ */
2302
+ template <typename DType>
2303
+ inline void clapack_laswp(const int n, void* a, const int lda, const int k1, const int k2, const int* ipiv, const int incx) {
2304
+ laswp<DType>(n, reinterpret_cast<DType*>(a), lda, k1, k2, ipiv, incx);
2305
+ }
2306
+
2307
+
2308
+
1360
2309
  }} // end namespace nm::math
1361
2310
 
1362
2311