nmatrix 0.2.0 → 0.2.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/ext/nmatrix/data/complex.h +183 -159
- data/ext/nmatrix/data/data.cpp +113 -112
- data/ext/nmatrix/data/data.h +306 -292
- data/ext/nmatrix/data/ruby_object.h +193 -193
- data/ext/nmatrix/extconf.rb +11 -9
- data/ext/nmatrix/math.cpp +9 -11
- data/ext/nmatrix/math/math.h +3 -2
- data/ext/nmatrix/math/trsm.h +152 -152
- data/ext/nmatrix/nmatrix.h +30 -0
- data/ext/nmatrix/ruby_constants.cpp +67 -67
- data/ext/nmatrix/ruby_constants.h +35 -35
- data/ext/nmatrix/ruby_nmatrix.c +168 -183
- data/ext/nmatrix/storage/common.h +4 -3
- data/ext/nmatrix/storage/dense/dense.cpp +50 -50
- data/ext/nmatrix/storage/dense/dense.h +8 -7
- data/ext/nmatrix/storage/list/list.cpp +16 -16
- data/ext/nmatrix/storage/list/list.h +7 -6
- data/ext/nmatrix/storage/storage.cpp +32 -32
- data/ext/nmatrix/storage/storage.h +12 -11
- data/ext/nmatrix/storage/yale/class.h +2 -2
- data/ext/nmatrix/storage/yale/iterators/base.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/iterator.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/row.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/row_stored.h +2 -1
- data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +1 -0
- data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +2 -1
- data/ext/nmatrix/storage/yale/yale.cpp +27 -27
- data/ext/nmatrix/storage/yale/yale.h +7 -6
- data/ext/nmatrix/ttable_helper.rb +10 -10
- data/ext/nmatrix/types.h +3 -2
- data/ext/nmatrix/util/io.cpp +7 -7
- data/ext/nmatrix/util/sl_list.cpp +26 -26
- data/ext/nmatrix/util/sl_list.h +19 -18
- data/lib/nmatrix/blas.rb +7 -7
- data/lib/nmatrix/io/mat5_reader.rb +30 -30
- data/lib/nmatrix/math.rb +73 -17
- data/lib/nmatrix/nmatrix.rb +10 -8
- data/lib/nmatrix/shortcuts.rb +3 -3
- data/lib/nmatrix/version.rb +3 -3
- data/spec/00_nmatrix_spec.rb +6 -0
- data/spec/math_spec.rb +77 -0
- data/spec/spec_helper.rb +9 -0
- metadata +2 -2
@@ -119,10 +119,10 @@ void mark(LIST* list, size_t recursions) {
|
|
119
119
|
next = curr->next;
|
120
120
|
|
121
121
|
if (recursions == 0) {
|
122
|
-
|
123
|
-
|
122
|
+
rb_gc_mark(*((VALUE*)(curr->val)));
|
123
|
+
|
124
124
|
} else {
|
125
|
-
|
125
|
+
mark((LIST*)curr->val, recursions - 1);
|
126
126
|
}
|
127
127
|
|
128
128
|
curr = next;
|
@@ -174,8 +174,8 @@ NODE* insert(LIST* list, bool replace, size_t key, void* val) {
|
|
174
174
|
NODE *ins;
|
175
175
|
|
176
176
|
if (list->first == NULL) {
|
177
|
-
|
178
|
-
|
177
|
+
// List is empty
|
178
|
+
|
179
179
|
//if (!(ins = malloc(sizeof(NODE)))) return NULL;
|
180
180
|
ins = NM_ALLOC(NODE);
|
181
181
|
ins->next = NULL;
|
@@ -186,8 +186,8 @@ NODE* insert(LIST* list, bool replace, size_t key, void* val) {
|
|
186
186
|
return ins;
|
187
187
|
|
188
188
|
} else if (key < list->first->key) {
|
189
|
-
|
190
|
-
|
189
|
+
// Goes at the beginning of the list
|
190
|
+
|
191
191
|
//if (!(ins = malloc(sizeof(NODE)))) return NULL;
|
192
192
|
ins = NM_ALLOC(NODE);
|
193
193
|
ins->next = list->first;
|
@@ -214,7 +214,7 @@ NODE* insert(LIST* list, bool replace, size_t key, void* val) {
|
|
214
214
|
return ins;
|
215
215
|
|
216
216
|
} else {
|
217
|
-
|
217
|
+
return insert_after(ins, key, val);
|
218
218
|
}
|
219
219
|
}
|
220
220
|
|
@@ -305,7 +305,7 @@ void* remove_by_key(LIST* list, size_t key) {
|
|
305
305
|
void* val;
|
306
306
|
|
307
307
|
if (!list->first || list->first->key > key) { // empty list or def. not present
|
308
|
-
|
308
|
+
return NULL;
|
309
309
|
}
|
310
310
|
|
311
311
|
if (list->first->key == key) {
|
@@ -320,7 +320,7 @@ void* remove_by_key(LIST* list, size_t key) {
|
|
320
320
|
|
321
321
|
f = find_preceding_from_node(list->first, key);
|
322
322
|
if (!f || !f->next) { // not found, end of list
|
323
|
-
|
323
|
+
return NULL;
|
324
324
|
}
|
325
325
|
|
326
326
|
if (f->next->key == key) {
|
@@ -411,15 +411,15 @@ bool remove_recursive(LIST* list, const size_t* coords, const size_t* offsets, c
|
|
411
411
|
NODE* find(LIST* list, size_t key) {
|
412
412
|
NODE* f;
|
413
413
|
if (!list->first) {
|
414
|
-
|
415
|
-
|
414
|
+
// empty list -- does not exist
|
415
|
+
return NULL;
|
416
416
|
}
|
417
417
|
|
418
418
|
// see if we can find it.
|
419
419
|
f = find_nearest_from(list->first, key);
|
420
420
|
|
421
421
|
if (!f || f->key == key) {
|
422
|
-
|
422
|
+
return f;
|
423
423
|
}
|
424
424
|
|
425
425
|
return NULL;
|
@@ -458,10 +458,10 @@ NODE* find_preceding_from_node(NODE* prev, size_t key) {
|
|
458
458
|
NODE* curr = prev->next;
|
459
459
|
|
460
460
|
if (!curr || key <= curr->key) {
|
461
|
-
|
462
|
-
|
461
|
+
return prev;
|
462
|
+
|
463
463
|
} else {
|
464
|
-
|
464
|
+
return find_preceding_from_node(curr, key);
|
465
465
|
}
|
466
466
|
}
|
467
467
|
|
@@ -491,19 +491,19 @@ NODE* find_nearest_from(NODE* prev, size_t key) {
|
|
491
491
|
NODE* f;
|
492
492
|
|
493
493
|
if (prev && prev->key == key) {
|
494
|
-
|
494
|
+
return prev;
|
495
495
|
}
|
496
496
|
|
497
497
|
f = find_preceding_from_node(prev, key);
|
498
498
|
|
499
499
|
if (!f->next) { // key exceeds final node; return final node.
|
500
|
-
|
501
|
-
|
500
|
+
return f;
|
501
|
+
|
502
502
|
} else if (key == f->next->key) { // node already present; return location
|
503
|
-
|
503
|
+
return f->next;
|
504
504
|
|
505
505
|
} else {
|
506
|
-
|
506
|
+
return f;
|
507
507
|
}
|
508
508
|
}
|
509
509
|
|
@@ -528,14 +528,14 @@ void cast_copy_contents(LIST* lhs, const LIST* rhs, size_t recursions) {
|
|
528
528
|
lcurr->key = rcurr->key;
|
529
529
|
|
530
530
|
if (recursions == 0) {
|
531
|
-
|
531
|
+
// contents is some kind of value
|
532
532
|
|
533
533
|
lcurr->val = NM_ALLOC( LDType );
|
534
534
|
|
535
535
|
*reinterpret_cast<LDType*>(lcurr->val) = *reinterpret_cast<RDType*>( rcurr->val );
|
536
536
|
|
537
537
|
} else {
|
538
|
-
|
538
|
+
// contents is a list
|
539
539
|
|
540
540
|
lcurr->val = NM_ALLOC( LIST );
|
541
541
|
|
@@ -547,10 +547,10 @@ void cast_copy_contents(LIST* lhs, const LIST* rhs, size_t recursions) {
|
|
547
547
|
}
|
548
548
|
|
549
549
|
if (rcurr->next) {
|
550
|
-
|
550
|
+
lcurr->next = NM_ALLOC( NODE );
|
551
551
|
|
552
552
|
} else {
|
553
|
-
|
553
|
+
lcurr->next = NULL;
|
554
554
|
}
|
555
555
|
|
556
556
|
lcurr = lcurr->next;
|
@@ -608,7 +608,7 @@ extern "C" {
|
|
608
608
|
size_t key = curr->key;
|
609
609
|
|
610
610
|
if (recursions == 0) { // content is some kind of value
|
611
|
-
rb_hash_aset(h, INT2FIX(key), rubyobj_from_cval(curr->val, dtype).rval);
|
611
|
+
rb_hash_aset(h, INT2FIX(key), nm::rubyobj_from_cval(curr->val, dtype).rval);
|
612
612
|
} else { // content is a list
|
613
613
|
rb_hash_aset(h, INT2FIX(key), nm_list_copy_to_hash(reinterpret_cast<const LIST*>(curr->val), dtype, recursions-1, default_value));
|
614
614
|
}
|
data/ext/nmatrix/util/sl_list.h
CHANGED
@@ -33,6 +33,7 @@
|
|
33
33
|
* Standard Includes
|
34
34
|
*/
|
35
35
|
|
36
|
+
#include <ruby.h>
|
36
37
|
#include <type_traits>
|
37
38
|
#include <cstdlib>
|
38
39
|
|
@@ -69,9 +70,9 @@ namespace nm { namespace list {
|
|
69
70
|
// Lifecycle //
|
70
71
|
///////////////
|
71
72
|
|
72
|
-
LIST*
|
73
|
-
void
|
74
|
-
void
|
73
|
+
LIST* create(void);
|
74
|
+
void del(LIST* list, size_t recursions);
|
75
|
+
void mark(LIST* list, size_t recursions);
|
75
76
|
|
76
77
|
///////////////
|
77
78
|
// Accessors //
|
@@ -90,25 +91,25 @@ bool node_is_within_slice(NODE* n, size_t coord, size_t len);
|
|
90
91
|
|
91
92
|
template <typename Type>
|
92
93
|
inline NODE* insert_helper(LIST* list, NODE* node, size_t key, Type val) {
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
94
|
+
Type* val_mem = NM_ALLOC(Type);
|
95
|
+
*val_mem = val;
|
96
|
+
|
97
|
+
if (node == NULL) {
|
98
|
+
return insert(list, false, key, val_mem);
|
99
|
+
|
100
|
+
} else {
|
101
|
+
return insert_after(node, key, val_mem);
|
102
|
+
}
|
102
103
|
}
|
103
104
|
|
104
105
|
template <typename Type>
|
105
106
|
inline NODE* insert_helper(LIST* list, NODE* node, size_t key, Type* ptr) {
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
107
|
+
if (node == NULL) {
|
108
|
+
return insert(list, false, key, ptr);
|
109
|
+
|
110
|
+
} else {
|
111
|
+
return insert_after(node, key, ptr);
|
112
|
+
}
|
112
113
|
}
|
113
114
|
|
114
115
|
///////////
|
data/lib/nmatrix/blas.rb
CHANGED
@@ -71,7 +71,7 @@ module NMatrix::BLAS
|
|
71
71
|
def gemm(a, b, c = nil, alpha = 1.0, beta = 0.0, transpose_a = false, transpose_b = false, m = nil, n = nil, k = nil, lda = nil, ldb = nil, ldc = nil)
|
72
72
|
raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless a.is_a?(NMatrix) and b.is_a?(NMatrix) and a.stype == :dense and b.stype == :dense
|
73
73
|
raise(ArgumentError, 'Expected nil or dense NMatrix as third argument.') unless c.nil? or (c.is_a?(NMatrix) and c.stype == :dense)
|
74
|
-
raise(ArgumentError, 'NMatrix dtype mismatch.')
|
74
|
+
raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == b.dtype and (c ? a.dtype == c.dtype : true)
|
75
75
|
|
76
76
|
# First, set m, n, and k, which depend on whether we're taking the
|
77
77
|
# transpose of a and b.
|
@@ -93,7 +93,7 @@ module NMatrix::BLAS
|
|
93
93
|
end
|
94
94
|
|
95
95
|
n ||= transpose_b ? b.shape[0] : b.shape[1]
|
96
|
-
c
|
96
|
+
c = NMatrix.new([m, n], dtype: a.dtype)
|
97
97
|
end
|
98
98
|
|
99
99
|
# I think these are independent of whether or not a transpose occurs.
|
@@ -143,7 +143,7 @@ module NMatrix::BLAS
|
|
143
143
|
def gemv(a, x, y = nil, alpha = 1.0, beta = 0.0, transpose_a = false, m = nil, n = nil, lda = nil, incx = nil, incy = nil)
|
144
144
|
raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless a.is_a?(NMatrix) and x.is_a?(NMatrix) and a.stype == :dense and x.stype == :dense
|
145
145
|
raise(ArgumentError, 'Expected nil or dense NMatrix as third argument.') unless y.nil? or (y.is_a?(NMatrix) and y.stype == :dense)
|
146
|
-
raise(ArgumentError, 'NMatrix dtype mismatch.')
|
146
|
+
raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == x.dtype and (y ? a.dtype == y.dtype : true)
|
147
147
|
|
148
148
|
m ||= transpose_a == :transpose ? a.shape[1] : a.shape[0]
|
149
149
|
n ||= transpose_a == :transpose ? a.shape[0] : a.shape[1]
|
@@ -155,9 +155,9 @@ module NMatrix::BLAS
|
|
155
155
|
y = NMatrix.new([m,1], dtype: a.dtype)
|
156
156
|
end
|
157
157
|
|
158
|
-
lda
|
159
|
-
incx
|
160
|
-
incy
|
158
|
+
lda ||= a.shape[1]
|
159
|
+
incx ||= 1
|
160
|
+
incy ||= 1
|
161
161
|
|
162
162
|
::NMatrix::BLAS.cblas_gemv(transpose_a, m, n, alpha, a, lda, x, incx, beta, y, incy)
|
163
163
|
|
@@ -188,7 +188,7 @@ module NMatrix::BLAS
|
|
188
188
|
#
|
189
189
|
def rot(x, y, c, s, incx = 1, incy = 1, n = nil, in_place=false)
|
190
190
|
raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless x.is_a?(NMatrix) and y.is_a?(NMatrix) and x.stype == :dense and y.stype == :dense
|
191
|
-
raise(ArgumentError, 'NMatrix dtype mismatch.')
|
191
|
+
raise(ArgumentError, 'NMatrix dtype mismatch.') unless x.dtype == y.dtype
|
192
192
|
raise(ArgumentError, 'Need to supply n for non-standard incx, incy values') if n.nil? && incx != 1 && incx != -1 && incy != 1 && incy != -1
|
193
193
|
|
194
194
|
n ||= [x.size/incx.abs, y.size/incy.abs].min
|
@@ -41,8 +41,8 @@ module NMatrix::IO::Matlab
|
|
41
41
|
attr_reader :byte_order
|
42
42
|
|
43
43
|
def initialize(stream = nil, byte_order = nil, content_or_bytes = nil)
|
44
|
-
@stream
|
45
|
-
@byte_order
|
44
|
+
@stream = stream
|
45
|
+
@byte_order = byte_order
|
46
46
|
|
47
47
|
if content_or_bytes.is_a?(String)
|
48
48
|
@content = content_or_bytes
|
@@ -118,9 +118,9 @@ module NMatrix::IO::Matlab
|
|
118
118
|
# See also to_nm, which is responsible for NMatrix instantiation.
|
119
119
|
def to_ruby
|
120
120
|
case matlab_class
|
121
|
-
when :mxSPARSE
|
122
|
-
when :mxCELL
|
123
|
-
else
|
121
|
+
when :mxSPARSE then return to_nm
|
122
|
+
when :mxCELL then return self.cells.collect { |c| c.to_ruby }
|
123
|
+
else return to_nm
|
124
124
|
end
|
125
125
|
end
|
126
126
|
|
@@ -262,8 +262,8 @@ module NMatrix::IO::Matlab
|
|
262
262
|
#
|
263
263
|
def to_nm(dtype = nil)
|
264
264
|
# Hardest part is figuring out from_dtype, from_index_dtype, and dtype.
|
265
|
-
dtype
|
266
|
-
from_dtype
|
265
|
+
dtype ||= guess_dtype_from_mdtype
|
266
|
+
from_dtype = MatReader::MDTYPE_TO_DTYPE[self.real_part.tag.data_type]
|
267
267
|
|
268
268
|
# Create the same kind of matrix that MATLAB saved.
|
269
269
|
case matlab_class
|
@@ -297,8 +297,8 @@ module NMatrix::IO::Matlab
|
|
297
297
|
self.dimensions = dimensions_tag_data.data
|
298
298
|
|
299
299
|
begin
|
300
|
-
name_tag_data
|
301
|
-
self.matlab_name
|
300
|
+
name_tag_data = packedio.read([Element, options])
|
301
|
+
self.matlab_name = name_tag_data.data.is_a?(Array) ? name_tag_data.data.collect { |i| i.chr }.join('') : name_tag_data.data.chr
|
302
302
|
|
303
303
|
rescue ElementDataIOError => e
|
304
304
|
STDERR.puts "ERROR: Failure while trying to read Matlab variable name: #{name_tag_data.inspect}"
|
@@ -325,8 +325,8 @@ module NMatrix::IO::Matlab
|
|
325
325
|
self.row_index = packedio.read(read_opts)
|
326
326
|
end
|
327
327
|
|
328
|
-
self.real_part
|
329
|
-
self.imaginary_part
|
328
|
+
self.real_part = packedio.read(read_opts)
|
329
|
+
self.imaginary_part = packedio.read(read_opts) if self.complex
|
330
330
|
end
|
331
331
|
end
|
332
332
|
|
@@ -338,8 +338,8 @@ module NMatrix::IO::Matlab
|
|
338
338
|
|
339
339
|
MDTYPE_UNPACK_ARGS =
|
340
340
|
MatReader::MDTYPE_UNPACK_ARGS.merge({
|
341
|
-
:miCOMPRESSED
|
342
|
-
:miMATRIX
|
341
|
+
:miCOMPRESSED => [Compressed, {}],
|
342
|
+
:miMATRIX => [MatrixData, {}]
|
343
343
|
})
|
344
344
|
|
345
345
|
FIRST_TAG_FIELD_POS = 128
|
@@ -407,26 +407,26 @@ module NMatrix::IO::Matlab
|
|
407
407
|
|
408
408
|
include Packable
|
409
409
|
|
410
|
-
BYTE_ORDER_LENGTH
|
411
|
-
DESC_LENGTH
|
412
|
-
DATA_OFFSET_LENGTH
|
413
|
-
VERSION_LENGTH
|
414
|
-
BYTE_ORDER_POS
|
410
|
+
BYTE_ORDER_LENGTH = 2
|
411
|
+
DESC_LENGTH = 116
|
412
|
+
DATA_OFFSET_LENGTH = 8
|
413
|
+
VERSION_LENGTH = 2
|
414
|
+
BYTE_ORDER_POS = 126
|
415
415
|
|
416
416
|
# TODO: TEST WRITE.
|
417
417
|
def write_packed(packedio, options)
|
418
|
-
packedio <<
|
419
|
-
[data_offset,
|
420
|
-
[version,
|
421
|
-
[byte_order,
|
418
|
+
packedio << [desc, {:bytes => DESC_LENGTH }] <<
|
419
|
+
[data_offset, {:bytes => DATA_OFFSET_LENGTH }] <<
|
420
|
+
[version, {:bytes => VERSION_LENGTH }] <<
|
421
|
+
[byte_order, {:bytes => BYTE_ORDER_LENGTH }]
|
422
422
|
end
|
423
423
|
|
424
424
|
def read_packed(packedio, options)
|
425
425
|
self.desc, self.data_offset, self.version, self.endian = packedio >>
|
426
|
-
[String,
|
427
|
-
[String,
|
428
|
-
[Integer,
|
429
|
-
[String,
|
426
|
+
[String, {:bytes => DESC_LENGTH }] >>
|
427
|
+
[String, {:bytes => DATA_OFFSET_LENGTH }] >>
|
428
|
+
[Integer, {:bytes => VERSION_LENGTH, :endian => options[:endian] }] >>
|
429
|
+
[String, {:bytes => 2 }]
|
430
430
|
|
431
431
|
self.desc.strip!
|
432
432
|
self.data_offset.strip!
|
@@ -466,8 +466,8 @@ module NMatrix::IO::Matlab
|
|
466
466
|
# Small data element format
|
467
467
|
raise IOError, 'Small data element format indicated, but length is more than 4 bytes!' if upper > 4
|
468
468
|
|
469
|
-
self.bytes
|
470
|
-
self.raw_data_type
|
469
|
+
self.bytes = upper
|
470
|
+
self.raw_data_type = lower
|
471
471
|
|
472
472
|
else
|
473
473
|
self.bytes = packedio.read([Integer, BYTES_OPTS.merge(options)])
|
@@ -560,8 +560,8 @@ module NMatrix::IO::Matlab
|
|
560
560
|
def read_packed(packedio, options)
|
561
561
|
raise(ArgumentError, 'Missing mandatory option :endian.') unless options.has_key?(:endian)
|
562
562
|
|
563
|
-
self.tag
|
564
|
-
self.data
|
563
|
+
self.tag = packedio.read([Tag, {:endian => options[:endian] }])
|
564
|
+
self.data = packedio.read([String, {:endian => options[:endian], :bytes => tag.bytes }])
|
565
565
|
|
566
566
|
begin
|
567
567
|
ignore_padding(packedio, (tag.bytes + tag.size) % 8) unless [:miMATRIX, :miCOMPRESSED].include?(tag.data_type)
|
data/lib/nmatrix/math.rb
CHANGED
@@ -258,36 +258,92 @@ class NMatrix
|
|
258
258
|
|
259
259
|
# Solve the matrix equation AX = B, where A is +self+, B is the first
|
260
260
|
# argument, and X is returned. A must be a nxn square matrix, while B must be
|
261
|
-
# nxm. Only works with dense
|
262
|
-
# matrices and non-integer, non-object data types.
|
261
|
+
# nxm. Only works with dense matrices and non-integer, non-object data types.
|
263
262
|
#
|
263
|
+
# == Arguments
|
264
|
+
#
|
265
|
+
# * +b+ - the right hand side
|
266
|
+
#
|
267
|
+
# == Options
|
268
|
+
#
|
269
|
+
# * +form+ - Signifies the form of the matrix A in the linear system AX=B.
|
270
|
+
# If not set then it defaults to +:general+, which uses an LU solver.
|
271
|
+
# Other possible values are +:lower_tri+, +:upper_tri+ and +:pos_def+ (alternatively,
|
272
|
+
# non-abbreviated symbols +:lower_triangular+, +:upper_triangular+,
|
273
|
+
# and +:positive_definite+ can be used).
|
274
|
+
# If +:lower_tri+ or +:upper_tri+ is set, then a specialized linear solver for linear
|
275
|
+
# systems AX=B with a lower or upper triangular matrix A is used. If +:pos_def+ is chosen,
|
276
|
+
# then the linear system is solved via the Cholesky factorization.
|
277
|
+
# Note that when +:lower_tri+ or +:upper_tri+ is used, then the algorithm just assumes that
|
278
|
+
# all entries in the upper/lower triangle of the matrix, respectively, are zeros without checking (which
|
279
|
+
# can be useful in certain applications).
|
280
|
+
#
|
281
|
+
#
|
264
282
|
# == Usage
|
265
283
|
#
|
266
284
|
# a = NMatrix.new [2,2], [3,1,1,2], dtype: dtype
|
267
285
|
# b = NMatrix.new [2,1], [9,8], dtype: dtype
|
268
286
|
# a.solve(b)
|
269
|
-
|
287
|
+
#
|
288
|
+
# # solve an upper triangular linear system more efficiently:
|
289
|
+
# require 'benchmark'
|
290
|
+
# require 'nmatrix/lapacke'
|
291
|
+
# rand_mat = NMatrix.random([10000, 10000], dtype: :float64)
|
292
|
+
# a = rand_mat.triu
|
293
|
+
# b = NMatrix.random([10000, 10], dtype: :float64)
|
294
|
+
# Benchmark.bm(10) do |bm|
|
295
|
+
# bm.report('general') { a.solve(b) }
|
296
|
+
# bm.report('upper_tri') { a.solve(b, form: :upper_tri) }
|
297
|
+
# end
|
298
|
+
# # user system total real
|
299
|
+
# # general 73.170000 0.670000 73.840000 ( 73.810086)
|
300
|
+
# # upper_tri 0.180000 0.000000 0.180000 ( 0.182491)
|
301
|
+
#
|
302
|
+
def solve(b, opts = {})
|
270
303
|
raise(ShapeError, "Must be called on square matrix") unless self.dim == 2 && self.shape[0] == self.shape[1]
|
271
304
|
raise(ShapeError, "number of rows of b must equal number of cols of self") if
|
272
305
|
self.shape[1] != b.shape[0]
|
273
|
-
raise
|
274
|
-
raise
|
306
|
+
raise(ArgumentError, "only works with dense matrices") if self.stype != :dense
|
307
|
+
raise(ArgumentError, "only works for non-integer, non-object dtypes") if
|
275
308
|
integer_dtype? or object_dtype? or b.integer_dtype? or b.object_dtype?
|
276
309
|
|
277
|
-
|
278
|
-
|
279
|
-
n
|
310
|
+
opts = { form: :general }.merge(opts)
|
311
|
+
x = b.clone
|
312
|
+
n = self.shape[0]
|
280
313
|
nrhs = b.shape[1]
|
281
314
|
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
315
|
+
case opts[:form]
|
316
|
+
when :general
|
317
|
+
clone = self.clone
|
318
|
+
ipiv = NMatrix::LAPACK.clapack_getrf(:row, n, n, clone, n)
|
319
|
+
# When we call clapack_getrs with :row, actually only the first matrix
|
320
|
+
# (i.e. clone) is interpreted as row-major, while the other matrix (x)
|
321
|
+
# is interpreted as column-major. See here: http://math-atlas.sourceforge.net/faq.html#RowSolve
|
322
|
+
# So we must transpose x before and after
|
323
|
+
# calling it.
|
324
|
+
x = x.transpose
|
325
|
+
NMatrix::LAPACK.clapack_getrs(:row, :no_transpose, n, nrhs, clone, n, ipiv, x, n)
|
326
|
+
x.transpose
|
327
|
+
when :upper_tri, :upper_triangular
|
328
|
+
raise(ArgumentError, "upper triangular solver does not work with complex dtypes") if
|
329
|
+
complex_dtype? or b.complex_dtype?
|
330
|
+
# this is the correct function call; see https://github.com/SciRuby/nmatrix/issues/374
|
331
|
+
NMatrix::BLAS::cblas_trsm(:row, :left, :upper, false, :nounit, n, nrhs, 1.0, self, n, x, nrhs)
|
332
|
+
x
|
333
|
+
when :lower_tri, :lower_triangular
|
334
|
+
raise(ArgumentError, "lower triangular solver does not work with complex dtypes") if
|
335
|
+
complex_dtype? or b.complex_dtype?
|
336
|
+
# this is a workaround; see https://github.com/SciRuby/nmatrix/issues/422
|
337
|
+
x = x.transpose
|
338
|
+
NMatrix::BLAS::cblas_trsm(:row, :right, :lower, :transpose, :nounit, nrhs, n, 1.0, self, n, x, n)
|
339
|
+
x.transpose
|
340
|
+
when :pos_def, :positive_definite
|
341
|
+
u, l = self.factorize_cholesky
|
342
|
+
z = l.solve(b, form: :lower_tri)
|
343
|
+
u.solve(z, form: :upper_tri)
|
344
|
+
else
|
345
|
+
raise(ArgumentError, "#{opts[:form]} is not a valid form option")
|
346
|
+
end
|
291
347
|
end
|
292
348
|
|
293
349
|
#
|