nmatrix 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (44) hide show
  1. checksums.yaml +4 -4
  2. data/ext/nmatrix/data/complex.h +183 -159
  3. data/ext/nmatrix/data/data.cpp +113 -112
  4. data/ext/nmatrix/data/data.h +306 -292
  5. data/ext/nmatrix/data/ruby_object.h +193 -193
  6. data/ext/nmatrix/extconf.rb +11 -9
  7. data/ext/nmatrix/math.cpp +9 -11
  8. data/ext/nmatrix/math/math.h +3 -2
  9. data/ext/nmatrix/math/trsm.h +152 -152
  10. data/ext/nmatrix/nmatrix.h +30 -0
  11. data/ext/nmatrix/ruby_constants.cpp +67 -67
  12. data/ext/nmatrix/ruby_constants.h +35 -35
  13. data/ext/nmatrix/ruby_nmatrix.c +168 -183
  14. data/ext/nmatrix/storage/common.h +4 -3
  15. data/ext/nmatrix/storage/dense/dense.cpp +50 -50
  16. data/ext/nmatrix/storage/dense/dense.h +8 -7
  17. data/ext/nmatrix/storage/list/list.cpp +16 -16
  18. data/ext/nmatrix/storage/list/list.h +7 -6
  19. data/ext/nmatrix/storage/storage.cpp +32 -32
  20. data/ext/nmatrix/storage/storage.h +12 -11
  21. data/ext/nmatrix/storage/yale/class.h +2 -2
  22. data/ext/nmatrix/storage/yale/iterators/base.h +2 -1
  23. data/ext/nmatrix/storage/yale/iterators/iterator.h +2 -1
  24. data/ext/nmatrix/storage/yale/iterators/row.h +2 -1
  25. data/ext/nmatrix/storage/yale/iterators/row_stored.h +2 -1
  26. data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +1 -0
  27. data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +2 -1
  28. data/ext/nmatrix/storage/yale/yale.cpp +27 -27
  29. data/ext/nmatrix/storage/yale/yale.h +7 -6
  30. data/ext/nmatrix/ttable_helper.rb +10 -10
  31. data/ext/nmatrix/types.h +3 -2
  32. data/ext/nmatrix/util/io.cpp +7 -7
  33. data/ext/nmatrix/util/sl_list.cpp +26 -26
  34. data/ext/nmatrix/util/sl_list.h +19 -18
  35. data/lib/nmatrix/blas.rb +7 -7
  36. data/lib/nmatrix/io/mat5_reader.rb +30 -30
  37. data/lib/nmatrix/math.rb +73 -17
  38. data/lib/nmatrix/nmatrix.rb +10 -8
  39. data/lib/nmatrix/shortcuts.rb +3 -3
  40. data/lib/nmatrix/version.rb +3 -3
  41. data/spec/00_nmatrix_spec.rb +6 -0
  42. data/spec/math_spec.rb +77 -0
  43. data/spec/spec_helper.rb +9 -0
  44. metadata +2 -2
@@ -119,10 +119,10 @@ void mark(LIST* list, size_t recursions) {
119
119
  next = curr->next;
120
120
 
121
121
  if (recursions == 0) {
122
- rb_gc_mark(*((VALUE*)(curr->val)));
123
-
122
+ rb_gc_mark(*((VALUE*)(curr->val)));
123
+
124
124
  } else {
125
- mark((LIST*)curr->val, recursions - 1);
125
+ mark((LIST*)curr->val, recursions - 1);
126
126
  }
127
127
 
128
128
  curr = next;
@@ -174,8 +174,8 @@ NODE* insert(LIST* list, bool replace, size_t key, void* val) {
174
174
  NODE *ins;
175
175
 
176
176
  if (list->first == NULL) {
177
- // List is empty
178
-
177
+ // List is empty
178
+
179
179
  //if (!(ins = malloc(sizeof(NODE)))) return NULL;
180
180
  ins = NM_ALLOC(NODE);
181
181
  ins->next = NULL;
@@ -186,8 +186,8 @@ NODE* insert(LIST* list, bool replace, size_t key, void* val) {
186
186
  return ins;
187
187
 
188
188
  } else if (key < list->first->key) {
189
- // Goes at the beginning of the list
190
-
189
+ // Goes at the beginning of the list
190
+
191
191
  //if (!(ins = malloc(sizeof(NODE)))) return NULL;
192
192
  ins = NM_ALLOC(NODE);
193
193
  ins->next = list->first;
@@ -214,7 +214,7 @@ NODE* insert(LIST* list, bool replace, size_t key, void* val) {
214
214
  return ins;
215
215
 
216
216
  } else {
217
- return insert_after(ins, key, val);
217
+ return insert_after(ins, key, val);
218
218
  }
219
219
  }
220
220
 
@@ -305,7 +305,7 @@ void* remove_by_key(LIST* list, size_t key) {
305
305
  void* val;
306
306
 
307
307
  if (!list->first || list->first->key > key) { // empty list or def. not present
308
- return NULL;
308
+ return NULL;
309
309
  }
310
310
 
311
311
  if (list->first->key == key) {
@@ -320,7 +320,7 @@ void* remove_by_key(LIST* list, size_t key) {
320
320
 
321
321
  f = find_preceding_from_node(list->first, key);
322
322
  if (!f || !f->next) { // not found, end of list
323
- return NULL;
323
+ return NULL;
324
324
  }
325
325
 
326
326
  if (f->next->key == key) {
@@ -411,15 +411,15 @@ bool remove_recursive(LIST* list, const size_t* coords, const size_t* offsets, c
411
411
  NODE* find(LIST* list, size_t key) {
412
412
  NODE* f;
413
413
  if (!list->first) {
414
- // empty list -- does not exist
415
- return NULL;
414
+ // empty list -- does not exist
415
+ return NULL;
416
416
  }
417
417
 
418
418
  // see if we can find it.
419
419
  f = find_nearest_from(list->first, key);
420
420
 
421
421
  if (!f || f->key == key) {
422
- return f;
422
+ return f;
423
423
  }
424
424
 
425
425
  return NULL;
@@ -458,10 +458,10 @@ NODE* find_preceding_from_node(NODE* prev, size_t key) {
458
458
  NODE* curr = prev->next;
459
459
 
460
460
  if (!curr || key <= curr->key) {
461
- return prev;
462
-
461
+ return prev;
462
+
463
463
  } else {
464
- return find_preceding_from_node(curr, key);
464
+ return find_preceding_from_node(curr, key);
465
465
  }
466
466
  }
467
467
 
@@ -491,19 +491,19 @@ NODE* find_nearest_from(NODE* prev, size_t key) {
491
491
  NODE* f;
492
492
 
493
493
  if (prev && prev->key == key) {
494
- return prev;
494
+ return prev;
495
495
  }
496
496
 
497
497
  f = find_preceding_from_node(prev, key);
498
498
 
499
499
  if (!f->next) { // key exceeds final node; return final node.
500
- return f;
501
-
500
+ return f;
501
+
502
502
  } else if (key == f->next->key) { // node already present; return location
503
- return f->next;
503
+ return f->next;
504
504
 
505
505
  } else {
506
- return f;
506
+ return f;
507
507
  }
508
508
  }
509
509
 
@@ -528,14 +528,14 @@ void cast_copy_contents(LIST* lhs, const LIST* rhs, size_t recursions) {
528
528
  lcurr->key = rcurr->key;
529
529
 
530
530
  if (recursions == 0) {
531
- // contents is some kind of value
531
+ // contents is some kind of value
532
532
 
533
533
  lcurr->val = NM_ALLOC( LDType );
534
534
 
535
535
  *reinterpret_cast<LDType*>(lcurr->val) = *reinterpret_cast<RDType*>( rcurr->val );
536
536
 
537
537
  } else {
538
- // contents is a list
538
+ // contents is a list
539
539
 
540
540
  lcurr->val = NM_ALLOC( LIST );
541
541
 
@@ -547,10 +547,10 @@ void cast_copy_contents(LIST* lhs, const LIST* rhs, size_t recursions) {
547
547
  }
548
548
 
549
549
  if (rcurr->next) {
550
- lcurr->next = NM_ALLOC( NODE );
550
+ lcurr->next = NM_ALLOC( NODE );
551
551
 
552
552
  } else {
553
- lcurr->next = NULL;
553
+ lcurr->next = NULL;
554
554
  }
555
555
 
556
556
  lcurr = lcurr->next;
@@ -608,7 +608,7 @@ extern "C" {
608
608
  size_t key = curr->key;
609
609
 
610
610
  if (recursions == 0) { // content is some kind of value
611
- rb_hash_aset(h, INT2FIX(key), rubyobj_from_cval(curr->val, dtype).rval);
611
+ rb_hash_aset(h, INT2FIX(key), nm::rubyobj_from_cval(curr->val, dtype).rval);
612
612
  } else { // content is a list
613
613
  rb_hash_aset(h, INT2FIX(key), nm_list_copy_to_hash(reinterpret_cast<const LIST*>(curr->val), dtype, recursions-1, default_value));
614
614
  }
@@ -33,6 +33,7 @@
33
33
  * Standard Includes
34
34
  */
35
35
 
36
+ #include <ruby.h>
36
37
  #include <type_traits>
37
38
  #include <cstdlib>
38
39
 
@@ -69,9 +70,9 @@ namespace nm { namespace list {
69
70
  // Lifecycle //
70
71
  ///////////////
71
72
 
72
- LIST* create(void);
73
- void del(LIST* list, size_t recursions);
74
- void mark(LIST* list, size_t recursions);
73
+ LIST* create(void);
74
+ void del(LIST* list, size_t recursions);
75
+ void mark(LIST* list, size_t recursions);
75
76
 
76
77
  ///////////////
77
78
  // Accessors //
@@ -90,25 +91,25 @@ bool node_is_within_slice(NODE* n, size_t coord, size_t len);
90
91
 
91
92
  template <typename Type>
92
93
  inline NODE* insert_helper(LIST* list, NODE* node, size_t key, Type val) {
93
- Type* val_mem = NM_ALLOC(Type);
94
- *val_mem = val;
95
-
96
- if (node == NULL) {
97
- return insert(list, false, key, val_mem);
98
-
99
- } else {
100
- return insert_after(node, key, val_mem);
101
- }
94
+ Type* val_mem = NM_ALLOC(Type);
95
+ *val_mem = val;
96
+
97
+ if (node == NULL) {
98
+ return insert(list, false, key, val_mem);
99
+
100
+ } else {
101
+ return insert_after(node, key, val_mem);
102
+ }
102
103
  }
103
104
 
104
105
  template <typename Type>
105
106
  inline NODE* insert_helper(LIST* list, NODE* node, size_t key, Type* ptr) {
106
- if (node == NULL) {
107
- return insert(list, false, key, ptr);
108
-
109
- } else {
110
- return insert_after(node, key, ptr);
111
- }
107
+ if (node == NULL) {
108
+ return insert(list, false, key, ptr);
109
+
110
+ } else {
111
+ return insert_after(node, key, ptr);
112
+ }
112
113
  }
113
114
 
114
115
  ///////////
@@ -71,7 +71,7 @@ module NMatrix::BLAS
71
71
  def gemm(a, b, c = nil, alpha = 1.0, beta = 0.0, transpose_a = false, transpose_b = false, m = nil, n = nil, k = nil, lda = nil, ldb = nil, ldc = nil)
72
72
  raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless a.is_a?(NMatrix) and b.is_a?(NMatrix) and a.stype == :dense and b.stype == :dense
73
73
  raise(ArgumentError, 'Expected nil or dense NMatrix as third argument.') unless c.nil? or (c.is_a?(NMatrix) and c.stype == :dense)
74
- raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == b.dtype and (c ? a.dtype == c.dtype : true)
74
+ raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == b.dtype and (c ? a.dtype == c.dtype : true)
75
75
 
76
76
  # First, set m, n, and k, which depend on whether we're taking the
77
77
  # transpose of a and b.
@@ -93,7 +93,7 @@ module NMatrix::BLAS
93
93
  end
94
94
 
95
95
  n ||= transpose_b ? b.shape[0] : b.shape[1]
96
- c = NMatrix.new([m, n], dtype: a.dtype)
96
+ c = NMatrix.new([m, n], dtype: a.dtype)
97
97
  end
98
98
 
99
99
  # I think these are independent of whether or not a transpose occurs.
@@ -143,7 +143,7 @@ module NMatrix::BLAS
143
143
  def gemv(a, x, y = nil, alpha = 1.0, beta = 0.0, transpose_a = false, m = nil, n = nil, lda = nil, incx = nil, incy = nil)
144
144
  raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless a.is_a?(NMatrix) and x.is_a?(NMatrix) and a.stype == :dense and x.stype == :dense
145
145
  raise(ArgumentError, 'Expected nil or dense NMatrix as third argument.') unless y.nil? or (y.is_a?(NMatrix) and y.stype == :dense)
146
- raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == x.dtype and (y ? a.dtype == y.dtype : true)
146
+ raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == x.dtype and (y ? a.dtype == y.dtype : true)
147
147
 
148
148
  m ||= transpose_a == :transpose ? a.shape[1] : a.shape[0]
149
149
  n ||= transpose_a == :transpose ? a.shape[0] : a.shape[1]
@@ -155,9 +155,9 @@ module NMatrix::BLAS
155
155
  y = NMatrix.new([m,1], dtype: a.dtype)
156
156
  end
157
157
 
158
- lda ||= a.shape[1]
159
- incx ||= 1
160
- incy ||= 1
158
+ lda ||= a.shape[1]
159
+ incx ||= 1
160
+ incy ||= 1
161
161
 
162
162
  ::NMatrix::BLAS.cblas_gemv(transpose_a, m, n, alpha, a, lda, x, incx, beta, y, incy)
163
163
 
@@ -188,7 +188,7 @@ module NMatrix::BLAS
188
188
  #
189
189
  def rot(x, y, c, s, incx = 1, incy = 1, n = nil, in_place=false)
190
190
  raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless x.is_a?(NMatrix) and y.is_a?(NMatrix) and x.stype == :dense and y.stype == :dense
191
- raise(ArgumentError, 'NMatrix dtype mismatch.') unless x.dtype == y.dtype
191
+ raise(ArgumentError, 'NMatrix dtype mismatch.') unless x.dtype == y.dtype
192
192
  raise(ArgumentError, 'Need to supply n for non-standard incx, incy values') if n.nil? && incx != 1 && incx != -1 && incy != 1 && incy != -1
193
193
 
194
194
  n ||= [x.size/incx.abs, y.size/incy.abs].min
@@ -41,8 +41,8 @@ module NMatrix::IO::Matlab
41
41
  attr_reader :byte_order
42
42
 
43
43
  def initialize(stream = nil, byte_order = nil, content_or_bytes = nil)
44
- @stream = stream
45
- @byte_order = byte_order
44
+ @stream = stream
45
+ @byte_order = byte_order
46
46
 
47
47
  if content_or_bytes.is_a?(String)
48
48
  @content = content_or_bytes
@@ -118,9 +118,9 @@ module NMatrix::IO::Matlab
118
118
  # See also to_nm, which is responsible for NMatrix instantiation.
119
119
  def to_ruby
120
120
  case matlab_class
121
- when :mxSPARSE then return to_nm
122
- when :mxCELL then return self.cells.collect { |c| c.to_ruby }
123
- else return to_nm
121
+ when :mxSPARSE then return to_nm
122
+ when :mxCELL then return self.cells.collect { |c| c.to_ruby }
123
+ else return to_nm
124
124
  end
125
125
  end
126
126
 
@@ -262,8 +262,8 @@ module NMatrix::IO::Matlab
262
262
  #
263
263
  def to_nm(dtype = nil)
264
264
  # Hardest part is figuring out from_dtype, from_index_dtype, and dtype.
265
- dtype ||= guess_dtype_from_mdtype
266
- from_dtype = MatReader::MDTYPE_TO_DTYPE[self.real_part.tag.data_type]
265
+ dtype ||= guess_dtype_from_mdtype
266
+ from_dtype = MatReader::MDTYPE_TO_DTYPE[self.real_part.tag.data_type]
267
267
 
268
268
  # Create the same kind of matrix that MATLAB saved.
269
269
  case matlab_class
@@ -297,8 +297,8 @@ module NMatrix::IO::Matlab
297
297
  self.dimensions = dimensions_tag_data.data
298
298
 
299
299
  begin
300
- name_tag_data = packedio.read([Element, options])
301
- self.matlab_name = name_tag_data.data.is_a?(Array) ? name_tag_data.data.collect { |i| i.chr }.join('') : name_tag_data.data.chr
300
+ name_tag_data = packedio.read([Element, options])
301
+ self.matlab_name = name_tag_data.data.is_a?(Array) ? name_tag_data.data.collect { |i| i.chr }.join('') : name_tag_data.data.chr
302
302
 
303
303
  rescue ElementDataIOError => e
304
304
  STDERR.puts "ERROR: Failure while trying to read Matlab variable name: #{name_tag_data.inspect}"
@@ -325,8 +325,8 @@ module NMatrix::IO::Matlab
325
325
  self.row_index = packedio.read(read_opts)
326
326
  end
327
327
 
328
- self.real_part = packedio.read(read_opts)
329
- self.imaginary_part = packedio.read(read_opts) if self.complex
328
+ self.real_part = packedio.read(read_opts)
329
+ self.imaginary_part = packedio.read(read_opts) if self.complex
330
330
  end
331
331
  end
332
332
 
@@ -338,8 +338,8 @@ module NMatrix::IO::Matlab
338
338
 
339
339
  MDTYPE_UNPACK_ARGS =
340
340
  MatReader::MDTYPE_UNPACK_ARGS.merge({
341
- :miCOMPRESSED => [Compressed, {}],
342
- :miMATRIX => [MatrixData, {}]
341
+ :miCOMPRESSED => [Compressed, {}],
342
+ :miMATRIX => [MatrixData, {}]
343
343
  })
344
344
 
345
345
  FIRST_TAG_FIELD_POS = 128
@@ -407,26 +407,26 @@ module NMatrix::IO::Matlab
407
407
 
408
408
  include Packable
409
409
 
410
- BYTE_ORDER_LENGTH = 2
411
- DESC_LENGTH = 116
412
- DATA_OFFSET_LENGTH = 8
413
- VERSION_LENGTH = 2
414
- BYTE_ORDER_POS = 126
410
+ BYTE_ORDER_LENGTH = 2
411
+ DESC_LENGTH = 116
412
+ DATA_OFFSET_LENGTH = 8
413
+ VERSION_LENGTH = 2
414
+ BYTE_ORDER_POS = 126
415
415
 
416
416
  # TODO: TEST WRITE.
417
417
  def write_packed(packedio, options)
418
- packedio << [desc, {:bytes => DESC_LENGTH }] <<
419
- [data_offset, {:bytes => DATA_OFFSET_LENGTH }] <<
420
- [version, {:bytes => VERSION_LENGTH }] <<
421
- [byte_order, {:bytes => BYTE_ORDER_LENGTH }]
418
+ packedio << [desc, {:bytes => DESC_LENGTH }] <<
419
+ [data_offset, {:bytes => DATA_OFFSET_LENGTH }] <<
420
+ [version, {:bytes => VERSION_LENGTH }] <<
421
+ [byte_order, {:bytes => BYTE_ORDER_LENGTH }]
422
422
  end
423
423
 
424
424
  def read_packed(packedio, options)
425
425
  self.desc, self.data_offset, self.version, self.endian = packedio >>
426
- [String, {:bytes => DESC_LENGTH }] >>
427
- [String, {:bytes => DATA_OFFSET_LENGTH }] >>
428
- [Integer, {:bytes => VERSION_LENGTH, :endian => options[:endian] }] >>
429
- [String, {:bytes => 2 }]
426
+ [String, {:bytes => DESC_LENGTH }] >>
427
+ [String, {:bytes => DATA_OFFSET_LENGTH }] >>
428
+ [Integer, {:bytes => VERSION_LENGTH, :endian => options[:endian] }] >>
429
+ [String, {:bytes => 2 }]
430
430
 
431
431
  self.desc.strip!
432
432
  self.data_offset.strip!
@@ -466,8 +466,8 @@ module NMatrix::IO::Matlab
466
466
  # Small data element format
467
467
  raise IOError, 'Small data element format indicated, but length is more than 4 bytes!' if upper > 4
468
468
 
469
- self.bytes = upper
470
- self.raw_data_type = lower
469
+ self.bytes = upper
470
+ self.raw_data_type = lower
471
471
 
472
472
  else
473
473
  self.bytes = packedio.read([Integer, BYTES_OPTS.merge(options)])
@@ -560,8 +560,8 @@ module NMatrix::IO::Matlab
560
560
  def read_packed(packedio, options)
561
561
  raise(ArgumentError, 'Missing mandatory option :endian.') unless options.has_key?(:endian)
562
562
 
563
- self.tag = packedio.read([Tag, {:endian => options[:endian] }])
564
- self.data = packedio.read([String, {:endian => options[:endian], :bytes => tag.bytes }])
563
+ self.tag = packedio.read([Tag, {:endian => options[:endian] }])
564
+ self.data = packedio.read([String, {:endian => options[:endian], :bytes => tag.bytes }])
565
565
 
566
566
  begin
567
567
  ignore_padding(packedio, (tag.bytes + tag.size) % 8) unless [:miMATRIX, :miCOMPRESSED].include?(tag.data_type)
@@ -258,36 +258,92 @@ class NMatrix
258
258
 
259
259
  # Solve the matrix equation AX = B, where A is +self+, B is the first
260
260
  # argument, and X is returned. A must be a nxn square matrix, while B must be
261
- # nxm. Only works with dense
262
- # matrices and non-integer, non-object data types.
261
+ # nxm. Only works with dense matrices and non-integer, non-object data types.
263
262
  #
263
+ # == Arguments
264
+ #
265
+ # * +b+ - the right hand side
266
+ #
267
+ # == Options
268
+ #
269
+ # * +form+ - Signifies the form of the matrix A in the linear system AX=B.
270
+ # If not set then it defaults to +:general+, which uses an LU solver.
271
+ # Other possible values are +:lower_tri+, +:upper_tri+ and +:pos_def+ (alternatively,
272
+ # non-abbreviated symbols +:lower_triangular+, +:upper_triangular+,
273
+ # and +:positive_definite+ can be used).
274
+ # If +:lower_tri+ or +:upper_tri+ is set, then a specialized linear solver for linear
275
+ # systems AX=B with a lower or upper triangular matrix A is used. If +:pos_def+ is chosen,
276
+ # then the linear system is solved via the Cholesky factorization.
277
+ # Note that when +:lower_tri+ or +:upper_tri+ is used, then the algorithm just assumes that
278
+ # all entries in the upper/lower triangle (respectively) of the matrix are zeros without checking (which
279
+ # can be useful in certain applications).
280
+ #
281
+ #
264
282
  # == Usage
265
283
  #
266
284
  # a = NMatrix.new [2,2], [3,1,1,2], dtype: dtype
267
285
  # b = NMatrix.new [2,1], [9,8], dtype: dtype
268
286
  # a.solve(b)
269
- def solve b
287
+ #
288
+ # # solve an upper triangular linear system more efficiently:
289
+ # require 'benchmark'
290
+ # require 'nmatrix/lapacke'
291
+ # rand_mat = NMatrix.random([10000, 10000], dtype: :float64)
292
+ # a = rand_mat.triu
293
+ # b = NMatrix.random([10000, 10], dtype: :float64)
294
+ # Benchmark.bm(10) do |bm|
295
+ # bm.report('general') { a.solve(b) }
296
+ # bm.report('upper_tri') { a.solve(b, form: :upper_tri) }
297
+ # end
298
+ # # user system total real
299
+ # # general 73.170000 0.670000 73.840000 ( 73.810086)
300
+ # # upper_tri 0.180000 0.000000 0.180000 ( 0.182491)
301
+ #
302
+ def solve(b, opts = {})
270
303
  raise(ShapeError, "Must be called on square matrix") unless self.dim == 2 && self.shape[0] == self.shape[1]
271
304
  raise(ShapeError, "number of rows of b must equal number of cols of self") if
272
305
  self.shape[1] != b.shape[0]
273
- raise ArgumentError, "only works with dense matrices" if self.stype != :dense
274
- raise ArgumentError, "only works for non-integer, non-object dtypes" if
306
+ raise(ArgumentError, "only works with dense matrices") if self.stype != :dense
307
+ raise(ArgumentError, "only works for non-integer, non-object dtypes") if
275
308
  integer_dtype? or object_dtype? or b.integer_dtype? or b.object_dtype?
276
309
 
277
- x = b.clone
278
- clone = self.clone
279
- n = self.shape[0]
310
+ opts = { form: :general }.merge(opts)
311
+ x = b.clone
312
+ n = self.shape[0]
280
313
  nrhs = b.shape[1]
281
314
 
282
- ipiv = NMatrix::LAPACK.clapack_getrf(:row, n, n, clone, n)
283
- # When we call clapack_getrs with :row, actually only the first matrix
284
- # (i.e. clone) is interpreted as row-major, while the other matrix (x)
285
- # is interpreted as column-major. See here: http://math-atlas.sourceforge.net/faq.html#RowSolve
286
- # So we must transpose x before and after
287
- # calling it.
288
- x = x.transpose
289
- NMatrix::LAPACK.clapack_getrs(:row, :no_transpose, n, nrhs, clone, n, ipiv, x, n)
290
- x.transpose
315
+ case opts[:form]
316
+ when :general
317
+ clone = self.clone
318
+ ipiv = NMatrix::LAPACK.clapack_getrf(:row, n, n, clone, n)
319
+ # When we call clapack_getrs with :row, actually only the first matrix
320
+ # (i.e. clone) is interpreted as row-major, while the other matrix (x)
321
+ # is interpreted as column-major. See here: http://math-atlas.sourceforge.net/faq.html#RowSolve
322
+ # So we must transpose x before and after
323
+ # calling it.
324
+ x = x.transpose
325
+ NMatrix::LAPACK.clapack_getrs(:row, :no_transpose, n, nrhs, clone, n, ipiv, x, n)
326
+ x.transpose
327
+ when :upper_tri, :upper_triangular
328
+ raise(ArgumentError, "upper triangular solver does not work with complex dtypes") if
329
+ complex_dtype? or b.complex_dtype?
330
+ # this is the correct function call; see https://github.com/SciRuby/nmatrix/issues/374
331
+ NMatrix::BLAS::cblas_trsm(:row, :left, :upper, false, :nounit, n, nrhs, 1.0, self, n, x, nrhs)
332
+ x
333
+ when :lower_tri, :lower_triangular
334
+ raise(ArgumentError, "lower triangular solver does not work with complex dtypes") if
335
+ complex_dtype? or b.complex_dtype?
336
+ # this is a workaround; see https://github.com/SciRuby/nmatrix/issues/422
337
+ x = x.transpose
338
+ NMatrix::BLAS::cblas_trsm(:row, :right, :lower, :transpose, :nounit, nrhs, n, 1.0, self, n, x, n)
339
+ x.transpose
340
+ when :pos_def, :positive_definite
341
+ u, l = self.factorize_cholesky
342
+ z = l.solve(b, form: :lower_tri)
343
+ u.solve(z, form: :upper_tri)
344
+ else
345
+ raise(ArgumentError, "#{opts[:form]} is not a valid form option")
346
+ end
291
347
  end
292
348
 
293
349
  #