pnmatrix 1.2.4

Sign up to get free protection for your applications and to get access to all the features.
Files changed (111) hide show
  1. checksums.yaml +7 -0
  2. data/ext/nmatrix/binary_format.txt +53 -0
  3. data/ext/nmatrix/data/complex.h +388 -0
  4. data/ext/nmatrix/data/data.cpp +274 -0
  5. data/ext/nmatrix/data/data.h +651 -0
  6. data/ext/nmatrix/data/meta.h +64 -0
  7. data/ext/nmatrix/data/ruby_object.h +386 -0
  8. data/ext/nmatrix/extconf.rb +70 -0
  9. data/ext/nmatrix/math/asum.h +99 -0
  10. data/ext/nmatrix/math/cblas_enums.h +36 -0
  11. data/ext/nmatrix/math/cblas_templates_core.h +507 -0
  12. data/ext/nmatrix/math/gemm.h +241 -0
  13. data/ext/nmatrix/math/gemv.h +178 -0
  14. data/ext/nmatrix/math/getrf.h +255 -0
  15. data/ext/nmatrix/math/getrs.h +121 -0
  16. data/ext/nmatrix/math/imax.h +82 -0
  17. data/ext/nmatrix/math/laswp.h +165 -0
  18. data/ext/nmatrix/math/long_dtype.h +62 -0
  19. data/ext/nmatrix/math/magnitude.h +54 -0
  20. data/ext/nmatrix/math/math.h +751 -0
  21. data/ext/nmatrix/math/nrm2.h +165 -0
  22. data/ext/nmatrix/math/rot.h +117 -0
  23. data/ext/nmatrix/math/rotg.h +106 -0
  24. data/ext/nmatrix/math/scal.h +71 -0
  25. data/ext/nmatrix/math/trsm.h +336 -0
  26. data/ext/nmatrix/math/util.h +162 -0
  27. data/ext/nmatrix/math.cpp +1368 -0
  28. data/ext/nmatrix/nm_memory.h +60 -0
  29. data/ext/nmatrix/nmatrix.cpp +285 -0
  30. data/ext/nmatrix/nmatrix.h +476 -0
  31. data/ext/nmatrix/ruby_constants.cpp +151 -0
  32. data/ext/nmatrix/ruby_constants.h +106 -0
  33. data/ext/nmatrix/ruby_nmatrix.c +3130 -0
  34. data/ext/nmatrix/storage/common.cpp +77 -0
  35. data/ext/nmatrix/storage/common.h +183 -0
  36. data/ext/nmatrix/storage/dense/dense.cpp +1096 -0
  37. data/ext/nmatrix/storage/dense/dense.h +129 -0
  38. data/ext/nmatrix/storage/list/list.cpp +1628 -0
  39. data/ext/nmatrix/storage/list/list.h +138 -0
  40. data/ext/nmatrix/storage/storage.cpp +730 -0
  41. data/ext/nmatrix/storage/storage.h +99 -0
  42. data/ext/nmatrix/storage/yale/class.h +1139 -0
  43. data/ext/nmatrix/storage/yale/iterators/base.h +143 -0
  44. data/ext/nmatrix/storage/yale/iterators/iterator.h +131 -0
  45. data/ext/nmatrix/storage/yale/iterators/row.h +450 -0
  46. data/ext/nmatrix/storage/yale/iterators/row_stored.h +140 -0
  47. data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +169 -0
  48. data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +124 -0
  49. data/ext/nmatrix/storage/yale/math/transpose.h +110 -0
  50. data/ext/nmatrix/storage/yale/yale.cpp +2074 -0
  51. data/ext/nmatrix/storage/yale/yale.h +203 -0
  52. data/ext/nmatrix/types.h +55 -0
  53. data/ext/nmatrix/util/io.cpp +279 -0
  54. data/ext/nmatrix/util/io.h +115 -0
  55. data/ext/nmatrix/util/sl_list.cpp +627 -0
  56. data/ext/nmatrix/util/sl_list.h +144 -0
  57. data/ext/nmatrix/util/util.h +78 -0
  58. data/lib/nmatrix/blas.rb +378 -0
  59. data/lib/nmatrix/cruby/math.rb +744 -0
  60. data/lib/nmatrix/enumerate.rb +253 -0
  61. data/lib/nmatrix/homogeneous.rb +241 -0
  62. data/lib/nmatrix/io/fortran_format.rb +138 -0
  63. data/lib/nmatrix/io/harwell_boeing.rb +221 -0
  64. data/lib/nmatrix/io/market.rb +263 -0
  65. data/lib/nmatrix/io/point_cloud.rb +189 -0
  66. data/lib/nmatrix/jruby/decomposition.rb +24 -0
  67. data/lib/nmatrix/jruby/enumerable.rb +13 -0
  68. data/lib/nmatrix/jruby/error.rb +4 -0
  69. data/lib/nmatrix/jruby/math.rb +501 -0
  70. data/lib/nmatrix/jruby/nmatrix_java.rb +840 -0
  71. data/lib/nmatrix/jruby/operators.rb +283 -0
  72. data/lib/nmatrix/jruby/slice.rb +264 -0
  73. data/lib/nmatrix/lapack_core.rb +181 -0
  74. data/lib/nmatrix/lapack_plugin.rb +44 -0
  75. data/lib/nmatrix/math.rb +953 -0
  76. data/lib/nmatrix/mkmf.rb +100 -0
  77. data/lib/nmatrix/monkeys.rb +137 -0
  78. data/lib/nmatrix/nmatrix.rb +1172 -0
  79. data/lib/nmatrix/rspec.rb +75 -0
  80. data/lib/nmatrix/shortcuts.rb +1163 -0
  81. data/lib/nmatrix/version.rb +39 -0
  82. data/lib/nmatrix/yale_functions.rb +118 -0
  83. data/lib/nmatrix.rb +28 -0
  84. data/spec/00_nmatrix_spec.rb +892 -0
  85. data/spec/01_enum_spec.rb +196 -0
  86. data/spec/02_slice_spec.rb +407 -0
  87. data/spec/03_nmatrix_monkeys_spec.rb +80 -0
  88. data/spec/2x2_dense_double.mat +0 -0
  89. data/spec/4x4_sparse.mat +0 -0
  90. data/spec/4x5_dense.mat +0 -0
  91. data/spec/blas_spec.rb +215 -0
  92. data/spec/elementwise_spec.rb +311 -0
  93. data/spec/homogeneous_spec.rb +100 -0
  94. data/spec/io/fortran_format_spec.rb +88 -0
  95. data/spec/io/harwell_boeing_spec.rb +98 -0
  96. data/spec/io/test.rua +9 -0
  97. data/spec/io_spec.rb +159 -0
  98. data/spec/lapack_core_spec.rb +482 -0
  99. data/spec/leakcheck.rb +16 -0
  100. data/spec/math_spec.rb +1363 -0
  101. data/spec/nmatrix_yale_resize_test_associations.yaml +2802 -0
  102. data/spec/nmatrix_yale_spec.rb +286 -0
  103. data/spec/rspec_monkeys.rb +56 -0
  104. data/spec/rspec_spec.rb +35 -0
  105. data/spec/shortcuts_spec.rb +474 -0
  106. data/spec/slice_set_spec.rb +162 -0
  107. data/spec/spec_helper.rb +172 -0
  108. data/spec/stat_spec.rb +214 -0
  109. data/spec/test.pcd +20 -0
  110. data/spec/utm5940.mtx +83844 -0
  111. metadata +295 -0
@@ -0,0 +1,751 @@
1
+ /////////////////////////////////////////////////////////////////////
2
+ // = NMatrix
3
+ //
4
+ // A linear algebra library for scientific computation in Ruby.
5
+ // NMatrix is part of SciRuby.
6
+ //
7
+ // NMatrix was originally inspired by and derived from NArray, by
8
+ // Masahiro Tanaka: http://narray.rubyforge.org
9
+ //
10
+ // == Copyright Information
11
+ //
12
+ // SciRuby is Copyright (c) 2010 - 2014, Ruby Science Foundation
13
+ // NMatrix is Copyright (c) 2012 - 2014, John Woods and the Ruby Science Foundation
14
+ //
15
+ // Please see LICENSE.txt for additional copyright notices.
16
+ //
17
+ // == Contributing
18
+ //
19
+ // By contributing source code to SciRuby, you agree to be bound by
20
+ // our Contributor Agreement:
21
+ //
22
+ // * https://github.com/SciRuby/sciruby/wiki/Contributor-Agreement
23
+ //
24
+ // == math.h
25
+ //
26
+ // Header file for math functions, interfacing with BLAS, etc.
27
+ //
28
+ // For instructions on adding CBLAS and CLAPACK functions, see the
29
+ // beginning of math.cpp.
30
+ //
31
+ // Some of these functions are from ATLAS. Here is the license for
32
+ // ATLAS:
33
+ //
34
+ /*
35
+ * Automatically Tuned Linear Algebra Software v3.8.4
36
+ * (C) Copyright 1999 R. Clint Whaley
37
+ *
38
+ * Redistribution and use in source and binary forms, with or without
39
+ * modification, are permitted provided that the following conditions
40
+ * are met:
41
+ * 1. Redistributions of source code must retain the above copyright
42
+ * notice, this list of conditions and the following disclaimer.
43
+ * 2. Redistributions in binary form must reproduce the above copyright
44
+ * notice, this list of conditions, and the following disclaimer in the
45
+ * documentation and/or other materials provided with the distribution.
46
+ * 3. The name of the ATLAS group or the names of its contributers may
47
+ * not be used to endorse or promote products derived from this
48
+ * software without specific written permission.
49
+ *
50
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
51
+ * ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
52
+ * TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
53
+ * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE ATLAS GROUP OR ITS CONTRIBUTORS
54
+ * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
55
+ * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
56
+ * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
57
+ * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
58
+ * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
59
+ * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
60
+ * POSSIBILITY OF SUCH DAMAGE.
61
+ *
62
+ */
63
+
64
+ #ifndef MATH_H
65
+ #define MATH_H
66
+
67
+ /*
68
+ * Standard Includes
69
+ */
70
+
71
#include "cblas_enums.h"

#include <algorithm> // std::min, std::max, std::swap
#include <iostream>  // std::cerr (debug printing in smmp_sort::print_array)
#include <limits>    // std::numeric_limits
#include <memory>    // std::unique_ptr
76
+
77
+ /*
78
+ * Project Includes
79
+ */
80
+
81
+ /*
82
+ * Macros
83
+ */
84
+ #define REAL_RECURSE_LIMIT 4
85
+
86
+ /*
87
+ * Data
88
+ */
89
+
90
+
91
extern "C" {
  /*
   * C accessors (exported with C linkage; VALUE comes from the Ruby C API,
   * nm::dtype_t and YALE_STORAGE from the NMatrix headers — declared elsewhere).
   */

  // Out-of-place transpose of a generic M x N matrix of element_size-byte
  // elements: A is read with leading dimension lda, B is written with ldb.
  void nm_math_transpose_generic(const size_t M, const size_t N, const void* A, const int lda, void* B, const int ldb, size_t element_size);
  // One-time BLAS initialization hook.
  void nm_math_init_blas(void);

  /*
   * Pure math implementations.
   * NOTE(review): semantics inferred from names only — implementations live in
   * math.cpp, not visible here; confirm there before relying on details.
   */
  void nm_math_solve(VALUE lu, VALUE b, VALUE x, VALUE ipiv);
  void nm_math_inverse(const int M, void* A_elements, nm::dtype_t dtype);
  void nm_math_hessenberg(VALUE a);
  void nm_math_det_exact_from_dense(const int M, const void* elements,
                                    const int lda, nm::dtype_t dtype, void* result);
  void nm_math_det_exact_from_yale(const int M, const YALE_STORAGE* storage,
                                   const int lda, nm::dtype_t dtype, void* result);
  void nm_math_inverse_exact_from_dense(const int M, const void* A_elements,
                                        const int lda, void* B_elements, const int ldb, nm::dtype_t dtype);
  void nm_math_inverse_exact_from_yale(const int M, const YALE_STORAGE* storage,
                                       const int lda, YALE_STORAGE* inverse, const int ldb, nm::dtype_t dtype);
}
114
+
115
+
116
+ namespace nm {
117
+ namespace math {
118
+
119
+ /*
120
+ * Types
121
+ */
122
+
123
+
124
+ /*
125
+ * Functions
126
+ */
127
+
128
+ // Yale: numeric matrix multiply c=a*b
129
+ template <typename DType>
130
+ inline void numbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const DType* a, const bool diaga,
131
+ const IType* ib, const IType* jb, const DType* b, const bool diagb, IType* ic, IType* jc, DType* c, const bool diagc) {
132
+ const unsigned int max_lmn = std::max(std::max(m, n), l);
133
+ std::unique_ptr<IType[]> next(new IType[max_lmn]);
134
+ std::unique_ptr<DType[]> sums(new DType[max_lmn]);
135
+
136
+ DType v;
137
+
138
+ IType head, length, temp, ndnz = 0;
139
+ IType minmn = std::min(m,n);
140
+ IType minlm = std::min(l,m);
141
+
142
+ for (IType idx = 0; idx < max_lmn; ++idx) { // initialize scratch arrays
143
+ next[idx] = std::numeric_limits<IType>::max();
144
+ sums[idx] = 0;
145
+ }
146
+
147
+ for (IType i = 0; i < n; ++i) { // walk down the rows
148
+ head = std::numeric_limits<IType>::max()-1; // head gets assigned as whichever column of B's row j we last visited
149
+ length = 0;
150
+
151
+ for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // walk through entries in each row
152
+ IType j;
153
+
154
+ if (jj == ia[i+1]) { // if we're in the last entry for this row:
155
+ if (!diaga || i >= minmn) continue;
156
+ j = i; // if it's a new Yale matrix, and last entry, get the diagonal position (j) and entry (ajj)
157
+ v = a[i];
158
+ } else {
159
+ j = ja[jj]; // if it's not the last entry for this row, get the column (j) and entry (ajj)
160
+ v = a[jj];
161
+ }
162
+
163
+ for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) {
164
+
165
+ IType k;
166
+
167
+ if (kk == ib[j+1]) { // Get the column id for that entry
168
+ if (!diagb || j >= minlm) continue;
169
+ k = j;
170
+ sums[k] += v*b[k];
171
+ } else {
172
+ k = jb[kk];
173
+ sums[k] += v*b[kk];
174
+ }
175
+
176
+ if (next[k] == std::numeric_limits<IType>::max()) {
177
+ next[k] = head;
178
+ head = k;
179
+ ++length;
180
+ }
181
+ } // end of kk loop
182
+ } // end of jj loop
183
+
184
+ for (IType jj = 0; jj < length; ++jj) {
185
+ if (sums[head] != 0) {
186
+ if (diagc && head == i) {
187
+ c[head] = sums[head];
188
+ } else {
189
+ jc[n+1+ndnz] = head;
190
+ c[n+1+ndnz] = sums[head];
191
+ ++ndnz;
192
+ }
193
+ }
194
+
195
+ temp = head;
196
+ head = next[head];
197
+
198
+ next[temp] = std::numeric_limits<IType>::max();
199
+ sums[temp] = 0;
200
+ }
201
+
202
+ ic[i+1] = n+1+ndnz;
203
+ }
204
+ } /* numbmm_ */
205
+
206
+
207
+ /*
208
+ template <typename DType, typename IType>
209
+ inline void new_yale_matrix_multiply(const unsigned int m, const IType* ija, const DType* a, const IType* ijb, const DType* b, YALE_STORAGE* c_storage) {
210
+ unsigned int n = c_storage->shape[0],
211
+ l = c_storage->shape[1];
212
+
213
+ // Create a working vector of dimension max(m,l,n) and initial value IType::max():
214
+ std::vector<IType> mask(std::max(std::max(m,l),n), std::numeric_limits<IType>::max());
215
+
216
+ for (IType i = 0; i < n; ++i) { // A.rows.each_index do |i|
217
+
218
+ IType j, k;
219
+ size_t ndnz;
220
+
221
+ for (IType jj = ija[i]; jj <= ija[i+1]; ++jj) { // walk through column pointers for row i of A
222
+ j = (jj == ija[i+1]) ? i : ija[jj]; // Get the current column index (handle diagonals last)
223
+
224
+ if (j >= m) {
225
+ if (j == ija[jj]) rb_raise(rb_eIndexError, "ija array for left-hand matrix contains an out-of-bounds column index %u at position %u", jj, j);
226
+ else break;
227
+ }
228
+
229
+ for (IType kk = ijb[j]; kk <= ijb[j+1]; ++kk) { // walk through column pointers for row j of B
230
+ if (j >= m) continue; // first of all, does B *have* a row j?
231
+ k = (kk == ijb[j+1]) ? j : ijb[kk]; // Get the current column index (handle diagonals last)
232
+
233
+ if (k >= l) {
234
+ if (k == ijb[kk]) rb_raise(rb_eIndexError, "ija array for right-hand matrix contains an out-of-bounds column index %u at position %u", kk, k);
235
+ else break;
236
+ }
237
+
238
+ if (mask[k] == )
239
+ }
240
+
241
+ }
242
+ }
243
+ }
244
+ */
245
+
246
+ // Yale: Symbolic matrix multiply c=a*b
247
+ inline size_t symbmm(const unsigned int n, const unsigned int m, const unsigned int l, const IType* ia, const IType* ja, const bool diaga,
248
+ const IType* ib, const IType* jb, const bool diagb, IType* ic, const bool diagc) {
249
+ unsigned int max_lmn = std::max(std::max(m,n), l);
250
+ IType mask[max_lmn]; // INDEX in the SMMP paper.
251
+ IType j, k; /* Local variables */
252
+ size_t ndnz = n;
253
+
254
+ for (IType idx = 0; idx < max_lmn; ++idx)
255
+ mask[idx] = std::numeric_limits<IType>::max();
256
+
257
+ if (ic) { // Only write to ic if it's supplied; otherwise, we're just counting.
258
+ if (diagc) ic[0] = n+1;
259
+ else ic[0] = 0;
260
+ }
261
+
262
+ IType minmn = std::min(m,n);
263
+ IType minlm = std::min(l,m);
264
+
265
+ for (IType i = 0; i < n; ++i) { // MAIN LOOP: through rows
266
+
267
+ for (IType jj = ia[i]; jj <= ia[i+1]; ++jj) { // merge row lists, walking through columns in each row
268
+
269
+ // j <- column index given by JA[jj], or handle diagonal.
270
+ if (jj == ia[i+1]) { // Don't really do it the last time -- just handle diagonals in a new yale matrix.
271
+ if (!diaga || i >= minmn) continue;
272
+ j = i;
273
+ } else j = ja[jj];
274
+
275
+ for (IType kk = ib[j]; kk <= ib[j+1]; ++kk) { // Now walk through columns K of row J in matrix B.
276
+ if (kk == ib[j+1]) {
277
+ if (!diagb || j >= minlm) continue;
278
+ k = j;
279
+ } else k = jb[kk];
280
+
281
+ if (mask[k] != i) {
282
+ mask[k] = i;
283
+ ++ndnz;
284
+ }
285
+ }
286
+ }
287
+
288
+ if (diagc && mask[i] == std::numeric_limits<IType>::max()) --ndnz;
289
+
290
+ if (ic) ic[i+1] = ndnz;
291
+ }
292
+
293
+ return ndnz;
294
+ } /* symbmm_ */
295
+
296
+
297
+ // In-place quicksort (from Wikipedia) -- called by smmp_sort_columns, below. All functions are inclusive of left, right.
298
+ namespace smmp_sort {
299
+ const size_t THRESHOLD = 4; // switch to insertion sort for 4 elements or fewer
300
+
301
+ template <typename DType>
302
+ void print_array(DType* vals, IType* array, IType left, IType right) {
303
+ for (IType i = left; i <= right; ++i) {
304
+ std::cerr << array[i] << ":" << vals[i] << " ";
305
+ }
306
+ std::cerr << std::endl;
307
+ }
308
+
309
+ template <typename DType>
310
+ IType partition(DType* vals, IType* array, IType left, IType right, IType pivot) {
311
+ IType pivotJ = array[pivot];
312
+ DType pivotV = vals[pivot];
313
+
314
+ // Swap pivot and right
315
+ array[pivot] = array[right];
316
+ vals[pivot] = vals[right];
317
+ array[right] = pivotJ;
318
+ vals[right] = pivotV;
319
+
320
+ IType store = left;
321
+ for (IType idx = left; idx < right; ++idx) {
322
+ if (array[idx] <= pivotJ) {
323
+ // Swap i and store
324
+ std::swap(array[idx], array[store]);
325
+ std::swap(vals[idx], vals[store]);
326
+ ++store;
327
+ }
328
+ }
329
+
330
+ std::swap(array[store], array[right]);
331
+ std::swap(vals[store], vals[right]);
332
+
333
+ return store;
334
+ }
335
+
336
+ // Recommended to use the median of left, right, and mid for the pivot.
337
+ template <typename I>
338
+ inline I median(I a, I b, I c) {
339
+ if (a < b) {
340
+ if (b < c) return b; // a b c
341
+ if (a < c) return c; // a c b
342
+ return a; // c a b
343
+
344
+ } else { // a > b
345
+ if (a < c) return a; // b a c
346
+ if (b < c) return c; // b c a
347
+ return b; // c b a
348
+ }
349
+ }
350
+
351
+
352
+ // Insertion sort is more efficient than quicksort for small N
353
+ template <typename DType>
354
+ void insertion_sort(DType* vals, IType* array, IType left, IType right) {
355
+ for (IType idx = left; idx <= right; ++idx) {
356
+ IType col_to_insert = array[idx];
357
+ DType val_to_insert = vals[idx];
358
+
359
+ IType hole_pos = idx;
360
+ for (; hole_pos > left && col_to_insert < array[hole_pos-1]; --hole_pos) {
361
+ array[hole_pos] = array[hole_pos - 1]; // shift the larger column index up
362
+ vals[hole_pos] = vals[hole_pos - 1]; // value goes along with it
363
+ }
364
+
365
+ array[hole_pos] = col_to_insert;
366
+ vals[hole_pos] = val_to_insert;
367
+ }
368
+ }
369
+
370
+
371
+ template <typename DType>
372
+ void quicksort(DType* vals, IType* array, IType left, IType right) {
373
+
374
+ if (left < right) {
375
+ if (right - left < THRESHOLD) {
376
+ insertion_sort(vals, array, left, right);
377
+ } else {
378
+ // choose any pivot such that left < pivot < right
379
+ IType pivot = median<IType>(left, right, (IType)(((unsigned long)left + (unsigned long)right) / 2));
380
+ pivot = partition(vals, array, left, right, pivot);
381
+
382
+ // recursively sort elements smaller than the pivot
383
+ quicksort<DType>(vals, array, left, pivot-1);
384
+
385
+ // recursively sort elements at least as big as the pivot
386
+ quicksort<DType>(vals, array, pivot+1, right);
387
+ }
388
+ }
389
+ }
390
+
391
+
392
+ }; // end of namespace smmp_sort
393
+
394
+
395
+ /*
396
+ * For use following symbmm and numbmm. Sorts the matrix entries in each row according to the column index.
397
+ * This utilizes quicksort, which is an in-place unstable sort (since there are no duplicate entries, we don't care
398
+ * about stability).
399
+ *
400
+ * TODO: It might be worthwhile to do a test for free memory, and if available, use an unstable sort that isn't in-place.
401
+ *
402
+ * TODO: It's actually probably possible to write an even faster sort, since symbmm/numbmm are not producing a random
403
+ * ordering. If someone is doing a lot of Yale matrix multiplication, it might benefit them to consider even insertion
404
+ * sort.
405
+ */
406
+ template <typename DType>
407
+ inline void smmp_sort_columns(const size_t n, const IType* ia, IType* ja, DType* a) {
408
+ for (size_t i = 0; i < n; ++i) {
409
+ if (ia[i+1] - ia[i] < 2) continue; // no need to sort rows containing only one or two elements.
410
+ else if (ia[i+1] - ia[i] <= smmp_sort::THRESHOLD) {
411
+ smmp_sort::insertion_sort<DType>(a, ja, ia[i], ia[i+1]-1); // faster for small rows
412
+ } else {
413
+ smmp_sort::quicksort<DType>(a, ja, ia[i], ia[i+1]-1); // faster for large rows (and may call insertion_sort as well)
414
+ }
415
+ }
416
+ }
417
+
418
+
419
// Copies the strict upper triangle of the row-major M x N array U into C,
// zeroing each copied entry of U as it goes. U is unit-triangular, so the
// diagonal itself is neither copied nor cleared. lda-style strides ldu and
// ldc give the row pitch of U and C respectively.
//
// From ATLAS 3.8.0.
template <typename DType>
static inline void trcpzeroU(const int M, const int N, DType* U, const int ldu, DType* C, const int ldc) {
  for (int row = 0; row != M; ++row, U += ldu, C += ldc) {
    for (int col = row + 1; col < N; ++col) {
      C[col] = U[col];
      U[col] = 0;
    }
  }
}
435
+
436
+
437
+ /*
438
+ * Un-comment the following lines when we figure out how to calculate NB for each of the ATLAS-derived
439
+ * functions. This is probably really complicated.
440
+ *
441
+ * Also needed: ATL_MulByNB, ATL_DivByNB (both defined in the build process for ATLAS), and ATL_mmMU.
442
+ *
443
+ */
444
+
445
+ /*
446
+
447
+ template <bool RowMajor, bool Upper, typename DType>
448
+ static int trtri_4(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
449
+
450
+ if (RowMajor) {
451
+ DType *pA0 = A, *pA1 = A+lda, *pA2 = A+2*lda, *pA3 = A+3*lda;
452
+ DType tmp;
453
+ if (Upper) {
454
+ DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
455
+ A12 = pA1[2], A13 = pA1[3],
456
+ A23 = pA2[3];
457
+
458
+ if (Diag == CblasNonUnit) {
459
+ pA0->inverse();
460
+ (pA1+1)->inverse();
461
+ (pA2+2)->inverse();
462
+ (pA3+3)->inverse();
463
+
464
+ pA0[1] = -A01 * pA1[1] * pA0[0];
465
+ pA1[2] = -A12 * pA2[2] * pA1[1];
466
+ pA2[3] = -A23 * pA3[3] * pA2[2];
467
+
468
+ pA0[2] = -(A01 * pA1[2] + A02 * pA2[2]) * pA0[0];
469
+ pA1[3] = -(A12 * pA2[3] + A13 * pA3[3]) * pA1[1];
470
+
471
+ pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03 * pA3[3]) * pA0[0];
472
+
473
+ } else {
474
+
475
+ pA0[1] = -A01;
476
+ pA1[2] = -A12;
477
+ pA2[3] = -A23;
478
+
479
+ pA0[2] = -(A01 * pA1[2] + A02);
480
+ pA1[3] = -(A12 * pA2[3] + A13);
481
+
482
+ pA0[3] = -(A01 * pA1[3] + A02 * pA2[3] + A03);
483
+ }
484
+
485
+ } else { // Lower
486
+ DType A10 = pA1[0],
487
+ A20 = pA2[0], A21 = pA2[1],
488
+ A30 = PA3[0], A31 = pA3[1], A32 = pA3[2];
489
+ DType *B10 = pA1,
490
+ *B20 = pA2,
491
+ *B30 = pA3,
492
+ *B21 = pA2+1,
493
+ *B31 = pA3+1,
494
+ *B32 = pA3+2;
495
+
496
+
497
+ if (Diag == CblasNonUnit) {
498
+ pA0->inverse();
499
+ (pA1+1)->inverse();
500
+ (pA2+2)->inverse();
501
+ (pA3+3)->inverse();
502
+
503
+ *B10 = -A10 * pA0[0] * pA1[1];
504
+ *B21 = -A21 * pA1[1] * pA2[2];
505
+ *B32 = -A32 * pA2[2] * pA3[3];
506
+ *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
507
+ *B31 = -(A31 * pA1[1] + A32 * (*B21)) * pA3[3];
508
+ *B30 = -(A30 * pA0[0] + A31 * (*B10) + A32 * (*B20)) * pA3;
509
+ } else {
510
+ *B10 = -A10;
511
+ *B21 = -A21;
512
+ *B32 = -A32;
513
+ *B20 = -(A20 + A21 * (*B10));
514
+ *B31 = -(A31 + A32 * (*B21));
515
+ *B30 = -(A30 + A31 * (*B10) + A32 * (*B20));
516
+ }
517
+ }
518
+
519
+ } else {
520
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
521
+ }
522
+
523
+ return 0;
524
+
525
+ }
526
+
527
+
528
+ template <bool RowMajor, bool Upper, typename DType>
529
+ static int trtri_3(const enum CBLAS_DIAG Diag, DType* A, const int lda) {
530
+
531
+ if (RowMajor) {
532
+
533
+ DType tmp;
534
+
535
+ if (Upper) {
536
+ DType A01 = pA0[1], A02 = pA0[2], A03 = pA0[3],
537
+ A12 = pA1[2], A13 = pA1[3];
538
+
539
+ DType *B01 = pA0 + 1,
540
+ *B02 = pA0 + 2,
541
+ *B12 = pA1 + 2;
542
+
543
+ if (Diag == CblasNonUnit) {
544
+ pA0->inverse();
545
+ (pA1+1)->inverse();
546
+ (pA2+2)->inverse();
547
+
548
+ *B01 = -A01 * pA1[1] * pA0[0];
549
+ *B12 = -A12 * pA2[2] * pA1[1];
550
+ *B02 = -(A01 * (*B12) + A02 * pA2[2]) * pA0[0];
551
+ } else {
552
+ *B01 = -A01;
553
+ *B12 = -A12;
554
+ *B02 = -(A01 * (*B12) + A02);
555
+ }
556
+
557
+ } else { // Lower
558
+ DType *pA0=A, *pA1=A+lda, *pA2=A+2*lda;
559
+ DType A10=pA1[0],
560
+ A20=pA2[0], A21=pA2[1];
561
+
562
+ DType *B10 = pA1,
563
+ *B20 = pA2;
564
+ *B21 = pA2+1;
565
+
566
+ if (Diag == CblasNonUnit) {
567
+ pA0->inverse();
568
+ (pA1+1)->inverse();
569
+ (pA2+2)->inverse();
570
+ *B10 = -A10 * pA0[0] * pA1[1];
571
+ *B21 = -A21 * pA1[1] * pA2[2];
572
+ *B20 = -(A20 * pA0[0] + A21 * (*B10)) * pA2[2];
573
+ } else {
574
+ *B10 = -A10;
575
+ *B21 = -A21;
576
+ *B20 = -(A20 + A21 * (*B10));
577
+ }
578
+ }
579
+
580
+
581
+ } else {
582
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
583
+ }
584
+
585
+ return 0;
586
+
587
+ }
588
+
589
+ template <bool RowMajor, bool Upper, bool Real, typename DType>
590
+ static void trtri(const enum CBLAS_DIAG Diag, const int N, DType* A, const int lda) {
591
+ DType *Age, *Atr;
592
+ DType tmp;
593
+ int Nleft, Nright;
594
+
595
+ int ierr = 0;
596
+
597
+ static const DType ONE = 1;
598
+ static const DType MONE -1;
599
+ static const DType NONE = -1;
600
+
601
+ if (RowMajor) {
602
+
603
+ // FIXME: Use REAL_RECURSE_LIMIT here for float32 and float64 (instead of 1)
604
+ if ((Real && N > REAL_RECURSE_LIMIT) || (N > 1)) {
605
+ Nleft = N >> 1;
606
+ #ifdef NB
607
+ if (Nleft > NB) NLeft = ATL_MulByNB(ATL_DivByNB(Nleft));
608
+ #endif
609
+
610
+ Nright = N - Nleft;
611
+
612
+ if (Upper) {
613
+ Age = A + Nleft;
614
+ Atr = A + (Nleft * (lda+1));
615
+
616
+ nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasUpper, CblasNoTrans, Diag,
617
+ Nleft, Nright, ONE, Atr, lda, Age, lda);
618
+
619
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, Diag,
620
+ Nleft, Nright, MONE, A, lda, Age, lda);
621
+
622
+ } else { // Lower
623
+ Age = A + ((Nleft*lda));
624
+ Atr = A + (Nleft * (lda+1));
625
+
626
+ nm::math::trsm<DType>(CblasRowMajor, CblasRight, CblasLower, CblasNoTrans, Diag,
627
+ Nright, Nleft, ONE, A, lda, Age, lda);
628
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasLower, CblasNoTrans, Diag,
629
+ Nright, Nleft, MONE, Atr, lda, Age, lda);
630
+ }
631
+
632
+ ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nleft, A, lda);
633
+ if (ierr) return ierr;
634
+
635
+ ierr = trtri<RowMajor,Upper,Real,DType>(Diag, Nright, Atr, lda);
636
+ if (ierr) return ierr + Nleft;
637
+
638
+ } else {
639
+ if (Real) {
640
+ if (N == 4) {
641
+ return trtri_4<RowMajor,Upper,Real,DType>(Diag, A, lda);
642
+ } else if (N == 3) {
643
+ return trtri_3<RowMajor,Upper,Real,DType>(Diag, A, lda);
644
+ } else if (N == 2) {
645
+ if (Diag == CblasNonUnit) {
646
+ A->inverse();
647
+ (A+(lda+1))->inverse();
648
+
649
+ if (Upper) {
650
+ *(A+1) *= *A; // TRI_MUL
651
+ *(A+1) *= *(A+lda+1); // TRI_MUL
652
+ } else {
653
+ *(A+lda) *= *A; // TRI_MUL
654
+ *(A+lda) *= *(A+lda+1); // TRI_MUL
655
+ }
656
+ }
657
+
658
+ if (Upper) *(A+1) = -*(A+1); // TRI_NEG
659
+ else *(A+lda) = -*(A+lda); // TRI_NEG
660
+ } else if (Diag == CblasNonUnit) A->inverse();
661
+ } else { // not real
662
+ if (Diag == CblasNonUnit) A->inverse();
663
+ }
664
+ }
665
+
666
+ } else {
667
+ rb_raise(rb_eNotImpError, "only row-major implemented at this time");
668
+ }
669
+
670
+ return ierr;
671
+ }
672
+
673
+
674
+ template <bool RowMajor, bool Real, typename DType>
675
+ int getri(const int N, DType* A, const int lda, const int* ipiv, DType* wrk, const int lwrk) {
676
+
677
+ if (!RowMajor) rb_raise(rb_eNotImpError, "only row-major implemented at this time");
678
+
679
+ int jb, nb, I, ndown, iret;
680
+
681
+ const DType ONE = 1, NONE = -1;
682
+
683
+ int iret = trtri<RowMajor,false,Real,DType>(CblasNonUnit, N, A, lda);
684
+ if (!iret && N > 1) {
685
+ jb = lwrk / N;
686
+ if (jb >= NB) nb = ATL_MulByNB(ATL_DivByNB(jb));
687
+ else if (jb >= ATL_mmMU) nb = (jb/ATL_mmMU)*ATL_mmMU;
688
+ else nb = jb;
689
+ if (!nb) return -6; // need at least 1 row of workspace
690
+
691
+ // only first iteration will have partial block, unroll it
692
+
693
+ jb = N - (N/nb) * nb;
694
+ if (!jb) jb = nb;
695
+ I = N - jb;
696
+ A += lda * I;
697
+ trcpzeroU<DType>(jb, jb, A+I, lda, wrk, jb);
698
+ nm::math::trsm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
699
+ jb, N, ONE, wrk, jb, A, lda);
700
+
701
+ if (I) {
702
+ do {
703
+ I -= nb;
704
+ A -= nb * lda;
705
+ ndown = N-I;
706
+ trcpzeroU<DType>(nb, ndown, A+I, lda, wrk, ndown);
707
+ nm::math::gemm<DType>(CblasRowMajor, CblasLeft, CblasUpper, CblasNoTrans, CblasUnit,
708
+ nb, N, ONE, wrk, ndown, A, lda);
709
+ } while (I);
710
+ }
711
+
712
+ // Apply row interchanges
713
+
714
+ for (I = N - 2; I >= 0; --I) {
715
+ jb = ipiv[I];
716
+ if (jb != I) nm::math::swap<DType>(N, A+I*lda, 1, A+jb*lda, 1);
717
+ }
718
+ }
719
+
720
+ return iret;
721
+ }
722
+ */
723
+
724
/*
 * Macro for declaring LAPACK specializations of the getrf function.
 *
 * type is the DType; call is the specific function to call; cast_as is what the DType* should be
 * cast to in order to pass it to LAPACK.
 *
 * NOTE(review): `call` is expected to follow the CLAPACK calling convention
 * (clapack_*getrf(Order, M, N, A, lda, ipiv), returning 0 on success). On a
 * nonzero info the macro raises a Ruby ArgumentError via rb_raise — which
 * presumably does not return, so the trailing `return info` only satisfies
 * the compiler; confirm against the Ruby C API.
 */
#define LAPACK_GETRF(type, call, cast_as) \
template <> \
inline int getrf(const enum CBLAS_ORDER Order, const int M, const int N, type * A, const int lda, int* ipiv) { \
  int info = call(Order, M, N, reinterpret_cast<cast_as *>(A), lda, ipiv); \
  if (!info) return info; \
  else { \
    rb_raise(rb_eArgError, "getrf: problem with argument %d\n", info); \
    return info; \
  } \
}
740
+
741
+ /* Specialize for ATLAS types */
742
+ /*LAPACK_GETRF(float, clapack_sgetrf, float)
743
+ LAPACK_GETRF(double, clapack_dgetrf, double)
744
+ LAPACK_GETRF(Complex64, clapack_cgetrf, void)
745
+ LAPACK_GETRF(Complex128, clapack_zgetrf, void)
746
+ */
747
+
748
+ }} // end namespace nm::math
749
+
750
+
751
+ #endif // MATH_H