nmatrix 0.2.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/ext/nmatrix/data/complex.h +183 -159
- data/ext/nmatrix/data/data.cpp +113 -112
- data/ext/nmatrix/data/data.h +306 -292
- data/ext/nmatrix/data/ruby_object.h +193 -193
- data/ext/nmatrix/extconf.rb +11 -9
- data/ext/nmatrix/math.cpp +9 -11
- data/ext/nmatrix/math/math.h +3 -2
- data/ext/nmatrix/math/trsm.h +152 -152
- data/ext/nmatrix/nmatrix.h +30 -0
- data/ext/nmatrix/ruby_constants.cpp +67 -67
- data/ext/nmatrix/ruby_constants.h +35 -35
- data/ext/nmatrix/ruby_nmatrix.c +168 -183
- data/ext/nmatrix/storage/common.h +4 -3
- data/ext/nmatrix/storage/dense/dense.cpp +50 -50
- data/ext/nmatrix/storage/dense/dense.h +8 -7
- data/ext/nmatrix/storage/list/list.cpp +16 -16
- data/ext/nmatrix/storage/list/list.h +7 -6
- data/ext/nmatrix/storage/storage.cpp +32 -32
- data/ext/nmatrix/storage/storage.h +12 -11
- data/ext/nmatrix/storage/yale/class.h +2 -2
- data/ext/nmatrix/storage/yale/iterators/base.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/iterator.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/row.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/row_stored.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +1 -0
- data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +2 -1
- data/ext/nmatrix/storage/yale/yale.cpp +27 -27
- data/ext/nmatrix/storage/yale/yale.h +7 -6
- data/ext/nmatrix/ttable_helper.rb +10 -10
- data/ext/nmatrix/types.h +3 -2
- data/ext/nmatrix/util/io.cpp +7 -7
- data/ext/nmatrix/util/sl_list.cpp +26 -26
- data/ext/nmatrix/util/sl_list.h +19 -18
- data/lib/nmatrix/blas.rb +7 -7
- data/lib/nmatrix/io/mat5_reader.rb +30 -30
- data/lib/nmatrix/math.rb +73 -17
- data/lib/nmatrix/nmatrix.rb +10 -8
- data/lib/nmatrix/shortcuts.rb +3 -3
- data/lib/nmatrix/version.rb +3 -3
- data/spec/00_nmatrix_spec.rb +6 -0
- data/spec/math_spec.rb +77 -0
- data/spec/spec_helper.rb +9 -0
- metadata +2 -2
@@ -119,10 +119,10 @@ void mark(LIST* list, size_t recursions) {
|
|
119
119
|
next = curr->next;
|
120
120
|
|
121
121
|
if (recursions == 0) {
|
122
|
-
|
123
|
-
|
122
|
+
rb_gc_mark(*((VALUE*)(curr->val)));
|
123
|
+
|
124
124
|
} else {
|
125
|
-
|
125
|
+
mark((LIST*)curr->val, recursions - 1);
|
126
126
|
}
|
127
127
|
|
128
128
|
curr = next;
|
@@ -174,8 +174,8 @@ NODE* insert(LIST* list, bool replace, size_t key, void* val) {
|
|
174
174
|
NODE *ins;
|
175
175
|
|
176
176
|
if (list->first == NULL) {
|
177
|
-
|
178
|
-
|
177
|
+
// List is empty
|
178
|
+
|
179
179
|
//if (!(ins = malloc(sizeof(NODE)))) return NULL;
|
180
180
|
ins = NM_ALLOC(NODE);
|
181
181
|
ins->next = NULL;
|
@@ -186,8 +186,8 @@ NODE* insert(LIST* list, bool replace, size_t key, void* val) {
|
|
186
186
|
return ins;
|
187
187
|
|
188
188
|
} else if (key < list->first->key) {
|
189
|
-
|
190
|
-
|
189
|
+
// Goes at the beginning of the list
|
190
|
+
|
191
191
|
//if (!(ins = malloc(sizeof(NODE)))) return NULL;
|
192
192
|
ins = NM_ALLOC(NODE);
|
193
193
|
ins->next = list->first;
|
@@ -214,7 +214,7 @@ NODE* insert(LIST* list, bool replace, size_t key, void* val) {
|
|
214
214
|
return ins;
|
215
215
|
|
216
216
|
} else {
|
217
|
-
|
217
|
+
return insert_after(ins, key, val);
|
218
218
|
}
|
219
219
|
}
|
220
220
|
|
@@ -305,7 +305,7 @@ void* remove_by_key(LIST* list, size_t key) {
|
|
305
305
|
void* val;
|
306
306
|
|
307
307
|
if (!list->first || list->first->key > key) { // empty list or def. not present
|
308
|
-
|
308
|
+
return NULL;
|
309
309
|
}
|
310
310
|
|
311
311
|
if (list->first->key == key) {
|
@@ -320,7 +320,7 @@ void* remove_by_key(LIST* list, size_t key) {
|
|
320
320
|
|
321
321
|
f = find_preceding_from_node(list->first, key);
|
322
322
|
if (!f || !f->next) { // not found, end of list
|
323
|
-
|
323
|
+
return NULL;
|
324
324
|
}
|
325
325
|
|
326
326
|
if (f->next->key == key) {
|
@@ -411,15 +411,15 @@ bool remove_recursive(LIST* list, const size_t* coords, const size_t* offsets, c
|
|
411
411
|
NODE* find(LIST* list, size_t key) {
|
412
412
|
NODE* f;
|
413
413
|
if (!list->first) {
|
414
|
-
|
415
|
-
|
414
|
+
// empty list -- does not exist
|
415
|
+
return NULL;
|
416
416
|
}
|
417
417
|
|
418
418
|
// see if we can find it.
|
419
419
|
f = find_nearest_from(list->first, key);
|
420
420
|
|
421
421
|
if (!f || f->key == key) {
|
422
|
-
|
422
|
+
return f;
|
423
423
|
}
|
424
424
|
|
425
425
|
return NULL;
|
@@ -458,10 +458,10 @@ NODE* find_preceding_from_node(NODE* prev, size_t key) {
|
|
458
458
|
NODE* curr = prev->next;
|
459
459
|
|
460
460
|
if (!curr || key <= curr->key) {
|
461
|
-
|
462
|
-
|
461
|
+
return prev;
|
462
|
+
|
463
463
|
} else {
|
464
|
-
|
464
|
+
return find_preceding_from_node(curr, key);
|
465
465
|
}
|
466
466
|
}
|
467
467
|
|
@@ -491,19 +491,19 @@ NODE* find_nearest_from(NODE* prev, size_t key) {
|
|
491
491
|
NODE* f;
|
492
492
|
|
493
493
|
if (prev && prev->key == key) {
|
494
|
-
|
494
|
+
return prev;
|
495
495
|
}
|
496
496
|
|
497
497
|
f = find_preceding_from_node(prev, key);
|
498
498
|
|
499
499
|
if (!f->next) { // key exceeds final node; return final node.
|
500
|
-
|
501
|
-
|
500
|
+
return f;
|
501
|
+
|
502
502
|
} else if (key == f->next->key) { // node already present; return location
|
503
|
-
|
503
|
+
return f->next;
|
504
504
|
|
505
505
|
} else {
|
506
|
-
|
506
|
+
return f;
|
507
507
|
}
|
508
508
|
}
|
509
509
|
|
@@ -528,14 +528,14 @@ void cast_copy_contents(LIST* lhs, const LIST* rhs, size_t recursions) {
|
|
528
528
|
lcurr->key = rcurr->key;
|
529
529
|
|
530
530
|
if (recursions == 0) {
|
531
|
-
|
531
|
+
// contents is some kind of value
|
532
532
|
|
533
533
|
lcurr->val = NM_ALLOC( LDType );
|
534
534
|
|
535
535
|
*reinterpret_cast<LDType*>(lcurr->val) = *reinterpret_cast<RDType*>( rcurr->val );
|
536
536
|
|
537
537
|
} else {
|
538
|
-
|
538
|
+
// contents is a list
|
539
539
|
|
540
540
|
lcurr->val = NM_ALLOC( LIST );
|
541
541
|
|
@@ -547,10 +547,10 @@ void cast_copy_contents(LIST* lhs, const LIST* rhs, size_t recursions) {
|
|
547
547
|
}
|
548
548
|
|
549
549
|
if (rcurr->next) {
|
550
|
-
|
550
|
+
lcurr->next = NM_ALLOC( NODE );
|
551
551
|
|
552
552
|
} else {
|
553
|
-
|
553
|
+
lcurr->next = NULL;
|
554
554
|
}
|
555
555
|
|
556
556
|
lcurr = lcurr->next;
|
@@ -608,7 +608,7 @@ extern "C" {
|
|
608
608
|
size_t key = curr->key;
|
609
609
|
|
610
610
|
if (recursions == 0) { // content is some kind of value
|
611
|
-
rb_hash_aset(h, INT2FIX(key), rubyobj_from_cval(curr->val, dtype).rval);
|
611
|
+
rb_hash_aset(h, INT2FIX(key), nm::rubyobj_from_cval(curr->val, dtype).rval);
|
612
612
|
} else { // content is a list
|
613
613
|
rb_hash_aset(h, INT2FIX(key), nm_list_copy_to_hash(reinterpret_cast<const LIST*>(curr->val), dtype, recursions-1, default_value));
|
614
614
|
}
|
data/ext/nmatrix/util/sl_list.h
CHANGED
@@ -33,6 +33,7 @@
|
|
33
33
|
* Standard Includes
|
34
34
|
*/
|
35
35
|
|
36
|
+
#include <ruby.h>
|
36
37
|
#include <type_traits>
|
37
38
|
#include <cstdlib>
|
38
39
|
|
@@ -69,9 +70,9 @@ namespace nm { namespace list {
|
|
69
70
|
// Lifecycle //
|
70
71
|
///////////////
|
71
72
|
|
72
|
-
LIST*
|
73
|
-
void
|
74
|
-
void
|
73
|
+
LIST* create(void);
|
74
|
+
void del(LIST* list, size_t recursions);
|
75
|
+
void mark(LIST* list, size_t recursions);
|
75
76
|
|
76
77
|
///////////////
|
77
78
|
// Accessors //
|
@@ -90,25 +91,25 @@ bool node_is_within_slice(NODE* n, size_t coord, size_t len);
|
|
90
91
|
|
91
92
|
template <typename Type>
|
92
93
|
inline NODE* insert_helper(LIST* list, NODE* node, size_t key, Type val) {
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
94
|
+
Type* val_mem = NM_ALLOC(Type);
|
95
|
+
*val_mem = val;
|
96
|
+
|
97
|
+
if (node == NULL) {
|
98
|
+
return insert(list, false, key, val_mem);
|
99
|
+
|
100
|
+
} else {
|
101
|
+
return insert_after(node, key, val_mem);
|
102
|
+
}
|
102
103
|
}
|
103
104
|
|
104
105
|
template <typename Type>
|
105
106
|
inline NODE* insert_helper(LIST* list, NODE* node, size_t key, Type* ptr) {
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
107
|
+
if (node == NULL) {
|
108
|
+
return insert(list, false, key, ptr);
|
109
|
+
|
110
|
+
} else {
|
111
|
+
return insert_after(node, key, ptr);
|
112
|
+
}
|
112
113
|
}
|
113
114
|
|
114
115
|
///////////
|
data/lib/nmatrix/blas.rb
CHANGED
@@ -71,7 +71,7 @@ module NMatrix::BLAS
|
|
71
71
|
def gemm(a, b, c = nil, alpha = 1.0, beta = 0.0, transpose_a = false, transpose_b = false, m = nil, n = nil, k = nil, lda = nil, ldb = nil, ldc = nil)
|
72
72
|
raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless a.is_a?(NMatrix) and b.is_a?(NMatrix) and a.stype == :dense and b.stype == :dense
|
73
73
|
raise(ArgumentError, 'Expected nil or dense NMatrix as third argument.') unless c.nil? or (c.is_a?(NMatrix) and c.stype == :dense)
|
74
|
-
raise(ArgumentError, 'NMatrix dtype mismatch.')
|
74
|
+
raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == b.dtype and (c ? a.dtype == c.dtype : true)
|
75
75
|
|
76
76
|
# First, set m, n, and k, which depend on whether we're taking the
|
77
77
|
# transpose of a and b.
|
@@ -93,7 +93,7 @@ module NMatrix::BLAS
|
|
93
93
|
end
|
94
94
|
|
95
95
|
n ||= transpose_b ? b.shape[0] : b.shape[1]
|
96
|
-
c
|
96
|
+
c = NMatrix.new([m, n], dtype: a.dtype)
|
97
97
|
end
|
98
98
|
|
99
99
|
# I think these are independent of whether or not a transpose occurs.
|
@@ -143,7 +143,7 @@ module NMatrix::BLAS
|
|
143
143
|
def gemv(a, x, y = nil, alpha = 1.0, beta = 0.0, transpose_a = false, m = nil, n = nil, lda = nil, incx = nil, incy = nil)
|
144
144
|
raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless a.is_a?(NMatrix) and x.is_a?(NMatrix) and a.stype == :dense and x.stype == :dense
|
145
145
|
raise(ArgumentError, 'Expected nil or dense NMatrix as third argument.') unless y.nil? or (y.is_a?(NMatrix) and y.stype == :dense)
|
146
|
-
raise(ArgumentError, 'NMatrix dtype mismatch.')
|
146
|
+
raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == x.dtype and (y ? a.dtype == y.dtype : true)
|
147
147
|
|
148
148
|
m ||= transpose_a == :transpose ? a.shape[1] : a.shape[0]
|
149
149
|
n ||= transpose_a == :transpose ? a.shape[0] : a.shape[1]
|
@@ -155,9 +155,9 @@ module NMatrix::BLAS
|
|
155
155
|
y = NMatrix.new([m,1], dtype: a.dtype)
|
156
156
|
end
|
157
157
|
|
158
|
-
lda
|
159
|
-
incx
|
160
|
-
incy
|
158
|
+
lda ||= a.shape[1]
|
159
|
+
incx ||= 1
|
160
|
+
incy ||= 1
|
161
161
|
|
162
162
|
::NMatrix::BLAS.cblas_gemv(transpose_a, m, n, alpha, a, lda, x, incx, beta, y, incy)
|
163
163
|
|
@@ -188,7 +188,7 @@ module NMatrix::BLAS
|
|
188
188
|
#
|
189
189
|
def rot(x, y, c, s, incx = 1, incy = 1, n = nil, in_place=false)
|
190
190
|
raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless x.is_a?(NMatrix) and y.is_a?(NMatrix) and x.stype == :dense and y.stype == :dense
|
191
|
-
raise(ArgumentError, 'NMatrix dtype mismatch.')
|
191
|
+
raise(ArgumentError, 'NMatrix dtype mismatch.') unless x.dtype == y.dtype
|
192
192
|
raise(ArgumentError, 'Need to supply n for non-standard incx, incy values') if n.nil? && incx != 1 && incx != -1 && incy != 1 && incy != -1
|
193
193
|
|
194
194
|
n ||= [x.size/incx.abs, y.size/incy.abs].min
|
@@ -41,8 +41,8 @@ module NMatrix::IO::Matlab
|
|
41
41
|
attr_reader :byte_order
|
42
42
|
|
43
43
|
def initialize(stream = nil, byte_order = nil, content_or_bytes = nil)
|
44
|
-
@stream
|
45
|
-
@byte_order
|
44
|
+
@stream = stream
|
45
|
+
@byte_order = byte_order
|
46
46
|
|
47
47
|
if content_or_bytes.is_a?(String)
|
48
48
|
@content = content_or_bytes
|
@@ -118,9 +118,9 @@ module NMatrix::IO::Matlab
|
|
118
118
|
# See also to_nm, which is responsible for NMatrix instantiation.
|
119
119
|
def to_ruby
|
120
120
|
case matlab_class
|
121
|
-
when :mxSPARSE
|
122
|
-
when :mxCELL
|
123
|
-
else
|
121
|
+
when :mxSPARSE then return to_nm
|
122
|
+
when :mxCELL then return self.cells.collect { |c| c.to_ruby }
|
123
|
+
else return to_nm
|
124
124
|
end
|
125
125
|
end
|
126
126
|
|
@@ -262,8 +262,8 @@ module NMatrix::IO::Matlab
|
|
262
262
|
#
|
263
263
|
def to_nm(dtype = nil)
|
264
264
|
# Hardest part is figuring out from_dtype, from_index_dtype, and dtype.
|
265
|
-
dtype
|
266
|
-
from_dtype
|
265
|
+
dtype ||= guess_dtype_from_mdtype
|
266
|
+
from_dtype = MatReader::MDTYPE_TO_DTYPE[self.real_part.tag.data_type]
|
267
267
|
|
268
268
|
# Create the same kind of matrix that MATLAB saved.
|
269
269
|
case matlab_class
|
@@ -297,8 +297,8 @@ module NMatrix::IO::Matlab
|
|
297
297
|
self.dimensions = dimensions_tag_data.data
|
298
298
|
|
299
299
|
begin
|
300
|
-
name_tag_data
|
301
|
-
self.matlab_name
|
300
|
+
name_tag_data = packedio.read([Element, options])
|
301
|
+
self.matlab_name = name_tag_data.data.is_a?(Array) ? name_tag_data.data.collect { |i| i.chr }.join('') : name_tag_data.data.chr
|
302
302
|
|
303
303
|
rescue ElementDataIOError => e
|
304
304
|
STDERR.puts "ERROR: Failure while trying to read Matlab variable name: #{name_tag_data.inspect}"
|
@@ -325,8 +325,8 @@ module NMatrix::IO::Matlab
|
|
325
325
|
self.row_index = packedio.read(read_opts)
|
326
326
|
end
|
327
327
|
|
328
|
-
self.real_part
|
329
|
-
self.imaginary_part
|
328
|
+
self.real_part = packedio.read(read_opts)
|
329
|
+
self.imaginary_part = packedio.read(read_opts) if self.complex
|
330
330
|
end
|
331
331
|
end
|
332
332
|
|
@@ -338,8 +338,8 @@ module NMatrix::IO::Matlab
|
|
338
338
|
|
339
339
|
MDTYPE_UNPACK_ARGS =
|
340
340
|
MatReader::MDTYPE_UNPACK_ARGS.merge({
|
341
|
-
:miCOMPRESSED
|
342
|
-
:miMATRIX
|
341
|
+
:miCOMPRESSED => [Compressed, {}],
|
342
|
+
:miMATRIX => [MatrixData, {}]
|
343
343
|
})
|
344
344
|
|
345
345
|
FIRST_TAG_FIELD_POS = 128
|
@@ -407,26 +407,26 @@ module NMatrix::IO::Matlab
|
|
407
407
|
|
408
408
|
include Packable
|
409
409
|
|
410
|
-
BYTE_ORDER_LENGTH
|
411
|
-
DESC_LENGTH
|
412
|
-
DATA_OFFSET_LENGTH
|
413
|
-
VERSION_LENGTH
|
414
|
-
BYTE_ORDER_POS
|
410
|
+
BYTE_ORDER_LENGTH = 2
|
411
|
+
DESC_LENGTH = 116
|
412
|
+
DATA_OFFSET_LENGTH = 8
|
413
|
+
VERSION_LENGTH = 2
|
414
|
+
BYTE_ORDER_POS = 126
|
415
415
|
|
416
416
|
# TODO: TEST WRITE.
|
417
417
|
def write_packed(packedio, options)
|
418
|
-
packedio <<
|
419
|
-
[data_offset,
|
420
|
-
[version,
|
421
|
-
[byte_order,
|
418
|
+
packedio << [desc, {:bytes => DESC_LENGTH }] <<
|
419
|
+
[data_offset, {:bytes => DATA_OFFSET_LENGTH }] <<
|
420
|
+
[version, {:bytes => VERSION_LENGTH }] <<
|
421
|
+
[byte_order, {:bytes => BYTE_ORDER_LENGTH }]
|
422
422
|
end
|
423
423
|
|
424
424
|
def read_packed(packedio, options)
|
425
425
|
self.desc, self.data_offset, self.version, self.endian = packedio >>
|
426
|
-
[String,
|
427
|
-
[String,
|
428
|
-
[Integer,
|
429
|
-
[String,
|
426
|
+
[String, {:bytes => DESC_LENGTH }] >>
|
427
|
+
[String, {:bytes => DATA_OFFSET_LENGTH }] >>
|
428
|
+
[Integer, {:bytes => VERSION_LENGTH, :endian => options[:endian] }] >>
|
429
|
+
[String, {:bytes => 2 }]
|
430
430
|
|
431
431
|
self.desc.strip!
|
432
432
|
self.data_offset.strip!
|
@@ -466,8 +466,8 @@ module NMatrix::IO::Matlab
|
|
466
466
|
# Small data element format
|
467
467
|
raise IOError, 'Small data element format indicated, but length is more than 4 bytes!' if upper > 4
|
468
468
|
|
469
|
-
self.bytes
|
470
|
-
self.raw_data_type
|
469
|
+
self.bytes = upper
|
470
|
+
self.raw_data_type = lower
|
471
471
|
|
472
472
|
else
|
473
473
|
self.bytes = packedio.read([Integer, BYTES_OPTS.merge(options)])
|
@@ -560,8 +560,8 @@ module NMatrix::IO::Matlab
|
|
560
560
|
def read_packed(packedio, options)
|
561
561
|
raise(ArgumentError, 'Missing mandatory option :endian.') unless options.has_key?(:endian)
|
562
562
|
|
563
|
-
self.tag
|
564
|
-
self.data
|
563
|
+
self.tag = packedio.read([Tag, {:endian => options[:endian] }])
|
564
|
+
self.data = packedio.read([String, {:endian => options[:endian], :bytes => tag.bytes }])
|
565
565
|
|
566
566
|
begin
|
567
567
|
ignore_padding(packedio, (tag.bytes + tag.size) % 8) unless [:miMATRIX, :miCOMPRESSED].include?(tag.data_type)
|
data/lib/nmatrix/math.rb
CHANGED
@@ -258,36 +258,92 @@ class NMatrix
|
|
258
258
|
|
259
259
|
# Solve the matrix equation AX = B, where A is +self+, B is the first
|
260
260
|
# argument, and X is returned. A must be a nxn square matrix, while B must be
|
261
|
-
# nxm. Only works with dense
|
262
|
-
# matrices and non-integer, non-object data types.
|
261
|
+
# nxm. Only works with dense matrices and non-integer, non-object data types.
|
263
262
|
#
|
263
|
+
# == Arguments
|
264
|
+
#
|
265
|
+
# * +b+ - the right hand side
|
266
|
+
#
|
267
|
+
# == Options
|
268
|
+
#
|
269
|
+
# * +form+ - Signifies the form of the matrix A in the linear system AX=B.
|
270
|
+
# If not set then it defaults to +:general+, which uses an LU solver.
|
271
|
+
# Other possible values are +:lower_tri+, +:upper_tri+ and +:pos_def+ (alternatively,
|
272
|
+
# non-abbreviated symbols +:lower_triangular+, +:upper_triangular+,
|
273
|
+
# and +:positive_definite+ can be used.
|
274
|
+
# If +:lower_tri+ or +:upper_tri+ is set, then a specialized linear solver for linear
|
275
|
+
# systems AX=B with a lower or upper triangular matrix A is used. If +:pos_def+ is chosen,
|
276
|
+
# then the linear system is solved via the Cholesky factorization.
|
277
|
+
# Note that when +:lower_tri+ or +:upper_tri+ is used, then the algorithm just assumes that
|
278
|
+
# all entries in the lower/upper triangle of the matrix are zeros without checking (which
|
279
|
+
# can be useful in certain applications).
|
280
|
+
#
|
281
|
+
#
|
264
282
|
# == Usage
|
265
283
|
#
|
266
284
|
# a = NMatrix.new [2,2], [3,1,1,2], dtype: dtype
|
267
285
|
# b = NMatrix.new [2,1], [9,8], dtype: dtype
|
268
286
|
# a.solve(b)
|
269
|
-
|
287
|
+
#
|
288
|
+
# # solve an upper triangular linear system more efficiently:
|
289
|
+
# require 'benchmark'
|
290
|
+
# require 'nmatrix/lapacke'
|
291
|
+
# rand_mat = NMatrix.random([10000, 10000], dtype: :float64)
|
292
|
+
# a = rand_mat.triu
|
293
|
+
# b = NMatrix.random([10000, 10], dtype: :float64)
|
294
|
+
# Benchmark.bm(10) do |bm|
|
295
|
+
# bm.report('general') { a.solve(b) }
|
296
|
+
# bm.report('upper_tri') { a.solve(b, form: :upper_tri) }
|
297
|
+
# end
|
298
|
+
# # user system total real
|
299
|
+
# # general 73.170000 0.670000 73.840000 ( 73.810086)
|
300
|
+
# # upper_tri 0.180000 0.000000 0.180000 ( 0.182491)
|
301
|
+
#
|
302
|
+
def solve(b, opts = {})
|
270
303
|
raise(ShapeError, "Must be called on square matrix") unless self.dim == 2 && self.shape[0] == self.shape[1]
|
271
304
|
raise(ShapeError, "number of rows of b must equal number of cols of self") if
|
272
305
|
self.shape[1] != b.shape[0]
|
273
|
-
raise
|
274
|
-
raise
|
306
|
+
raise(ArgumentError, "only works with dense matrices") if self.stype != :dense
|
307
|
+
raise(ArgumentError, "only works for non-integer, non-object dtypes") if
|
275
308
|
integer_dtype? or object_dtype? or b.integer_dtype? or b.object_dtype?
|
276
309
|
|
277
|
-
|
278
|
-
|
279
|
-
n
|
310
|
+
opts = { form: :general }.merge(opts)
|
311
|
+
x = b.clone
|
312
|
+
n = self.shape[0]
|
280
313
|
nrhs = b.shape[1]
|
281
314
|
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
315
|
+
case opts[:form]
|
316
|
+
when :general
|
317
|
+
clone = self.clone
|
318
|
+
ipiv = NMatrix::LAPACK.clapack_getrf(:row, n, n, clone, n)
|
319
|
+
# When we call clapack_getrs with :row, actually only the first matrix
|
320
|
+
# (i.e. clone) is interpreted as row-major, while the other matrix (x)
|
321
|
+
# is interpreted as column-major. See here: http://math-atlas.sourceforge.net/faq.html#RowSolve
|
322
|
+
# So we must transpose x before and after
|
323
|
+
# calling it.
|
324
|
+
x = x.transpose
|
325
|
+
NMatrix::LAPACK.clapack_getrs(:row, :no_transpose, n, nrhs, clone, n, ipiv, x, n)
|
326
|
+
x.transpose
|
327
|
+
when :upper_tri, :upper_triangular
|
328
|
+
raise(ArgumentError, "upper triangular solver does not work with complex dtypes") if
|
329
|
+
complex_dtype? or b.complex_dtype?
|
330
|
+
# this is the correct function call; see https://github.com/SciRuby/nmatrix/issues/374
|
331
|
+
NMatrix::BLAS::cblas_trsm(:row, :left, :upper, false, :nounit, n, nrhs, 1.0, self, n, x, nrhs)
|
332
|
+
x
|
333
|
+
when :lower_tri, :lower_triangular
|
334
|
+
raise(ArgumentError, "lower triangular solver does not work with complex dtypes") if
|
335
|
+
complex_dtype? or b.complex_dtype?
|
336
|
+
# this is a workaround; see https://github.com/SciRuby/nmatrix/issues/422
|
337
|
+
x = x.transpose
|
338
|
+
NMatrix::BLAS::cblas_trsm(:row, :right, :lower, :transpose, :nounit, nrhs, n, 1.0, self, n, x, n)
|
339
|
+
x.transpose
|
340
|
+
when :pos_def, :positive_definite
|
341
|
+
u, l = self.factorize_cholesky
|
342
|
+
z = l.solve(b, form: :lower_tri)
|
343
|
+
u.solve(z, form: :upper_tri)
|
344
|
+
else
|
345
|
+
raise(ArgumentError, "#{opts[:form]} is not a valid form option")
|
346
|
+
end
|
291
347
|
end
|
292
348
|
|
293
349
|
#
|