nmatrix 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. checksums.yaml +4 -4
  2. data/ext/nmatrix/data/complex.h +183 -159
  3. data/ext/nmatrix/data/data.cpp +113 -112
  4. data/ext/nmatrix/data/data.h +306 -292
  5. data/ext/nmatrix/data/ruby_object.h +193 -193
  6. data/ext/nmatrix/extconf.rb +11 -9
  7. data/ext/nmatrix/math.cpp +9 -11
  8. data/ext/nmatrix/math/math.h +3 -2
  9. data/ext/nmatrix/math/trsm.h +152 -152
  10. data/ext/nmatrix/nmatrix.h +30 -0
  11. data/ext/nmatrix/ruby_constants.cpp +67 -67
  12. data/ext/nmatrix/ruby_constants.h +35 -35
  13. data/ext/nmatrix/ruby_nmatrix.c +168 -183
  14. data/ext/nmatrix/storage/common.h +4 -3
  15. data/ext/nmatrix/storage/dense/dense.cpp +50 -50
  16. data/ext/nmatrix/storage/dense/dense.h +8 -7
  17. data/ext/nmatrix/storage/list/list.cpp +16 -16
  18. data/ext/nmatrix/storage/list/list.h +7 -6
  19. data/ext/nmatrix/storage/storage.cpp +32 -32
  20. data/ext/nmatrix/storage/storage.h +12 -11
  21. data/ext/nmatrix/storage/yale/class.h +2 -2
  22. data/ext/nmatrix/storage/yale/iterators/base.h +2 -1
  23. data/ext/nmatrix/storage/yale/iterators/iterator.h +2 -1
  24. data/ext/nmatrix/storage/yale/iterators/row.h +2 -1
  25. data/ext/nmatrix/storage/yale/iterators/row_stored.h +2 -1
  26. data/ext/nmatrix/storage/yale/iterators/row_stored_nd.h +1 -0
  27. data/ext/nmatrix/storage/yale/iterators/stored_diagonal.h +2 -1
  28. data/ext/nmatrix/storage/yale/yale.cpp +27 -27
  29. data/ext/nmatrix/storage/yale/yale.h +7 -6
  30. data/ext/nmatrix/ttable_helper.rb +10 -10
  31. data/ext/nmatrix/types.h +3 -2
  32. data/ext/nmatrix/util/io.cpp +7 -7
  33. data/ext/nmatrix/util/sl_list.cpp +26 -26
  34. data/ext/nmatrix/util/sl_list.h +19 -18
  35. data/lib/nmatrix/blas.rb +7 -7
  36. data/lib/nmatrix/io/mat5_reader.rb +30 -30
  37. data/lib/nmatrix/math.rb +73 -17
  38. data/lib/nmatrix/nmatrix.rb +10 -8
  39. data/lib/nmatrix/shortcuts.rb +3 -3
  40. data/lib/nmatrix/version.rb +3 -3
  41. data/spec/00_nmatrix_spec.rb +6 -0
  42. data/spec/math_spec.rb +77 -0
  43. data/spec/spec_helper.rb +9 -0
  44. metadata +2 -2
@@ -119,10 +119,10 @@ void mark(LIST* list, size_t recursions) {
119
119
  next = curr->next;
120
120
 
121
121
  if (recursions == 0) {
122
- rb_gc_mark(*((VALUE*)(curr->val)));
123
-
122
+ rb_gc_mark(*((VALUE*)(curr->val)));
123
+
124
124
  } else {
125
- mark((LIST*)curr->val, recursions - 1);
125
+ mark((LIST*)curr->val, recursions - 1);
126
126
  }
127
127
 
128
128
  curr = next;
@@ -174,8 +174,8 @@ NODE* insert(LIST* list, bool replace, size_t key, void* val) {
174
174
  NODE *ins;
175
175
 
176
176
  if (list->first == NULL) {
177
- // List is empty
178
-
177
+ // List is empty
178
+
179
179
  //if (!(ins = malloc(sizeof(NODE)))) return NULL;
180
180
  ins = NM_ALLOC(NODE);
181
181
  ins->next = NULL;
@@ -186,8 +186,8 @@ NODE* insert(LIST* list, bool replace, size_t key, void* val) {
186
186
  return ins;
187
187
 
188
188
  } else if (key < list->first->key) {
189
- // Goes at the beginning of the list
190
-
189
+ // Goes at the beginning of the list
190
+
191
191
  //if (!(ins = malloc(sizeof(NODE)))) return NULL;
192
192
  ins = NM_ALLOC(NODE);
193
193
  ins->next = list->first;
@@ -214,7 +214,7 @@ NODE* insert(LIST* list, bool replace, size_t key, void* val) {
214
214
  return ins;
215
215
 
216
216
  } else {
217
- return insert_after(ins, key, val);
217
+ return insert_after(ins, key, val);
218
218
  }
219
219
  }
220
220
 
@@ -305,7 +305,7 @@ void* remove_by_key(LIST* list, size_t key) {
305
305
  void* val;
306
306
 
307
307
  if (!list->first || list->first->key > key) { // empty list or def. not present
308
- return NULL;
308
+ return NULL;
309
309
  }
310
310
 
311
311
  if (list->first->key == key) {
@@ -320,7 +320,7 @@ void* remove_by_key(LIST* list, size_t key) {
320
320
 
321
321
  f = find_preceding_from_node(list->first, key);
322
322
  if (!f || !f->next) { // not found, end of list
323
- return NULL;
323
+ return NULL;
324
324
  }
325
325
 
326
326
  if (f->next->key == key) {
@@ -411,15 +411,15 @@ bool remove_recursive(LIST* list, const size_t* coords, const size_t* offsets, c
411
411
  NODE* find(LIST* list, size_t key) {
412
412
  NODE* f;
413
413
  if (!list->first) {
414
- // empty list -- does not exist
415
- return NULL;
414
+ // empty list -- does not exist
415
+ return NULL;
416
416
  }
417
417
 
418
418
  // see if we can find it.
419
419
  f = find_nearest_from(list->first, key);
420
420
 
421
421
  if (!f || f->key == key) {
422
- return f;
422
+ return f;
423
423
  }
424
424
 
425
425
  return NULL;
@@ -458,10 +458,10 @@ NODE* find_preceding_from_node(NODE* prev, size_t key) {
458
458
  NODE* curr = prev->next;
459
459
 
460
460
  if (!curr || key <= curr->key) {
461
- return prev;
462
-
461
+ return prev;
462
+
463
463
  } else {
464
- return find_preceding_from_node(curr, key);
464
+ return find_preceding_from_node(curr, key);
465
465
  }
466
466
  }
467
467
 
@@ -491,19 +491,19 @@ NODE* find_nearest_from(NODE* prev, size_t key) {
491
491
  NODE* f;
492
492
 
493
493
  if (prev && prev->key == key) {
494
- return prev;
494
+ return prev;
495
495
  }
496
496
 
497
497
  f = find_preceding_from_node(prev, key);
498
498
 
499
499
  if (!f->next) { // key exceeds final node; return final node.
500
- return f;
501
-
500
+ return f;
501
+
502
502
  } else if (key == f->next->key) { // node already present; return location
503
- return f->next;
503
+ return f->next;
504
504
 
505
505
  } else {
506
- return f;
506
+ return f;
507
507
  }
508
508
  }
509
509
 
@@ -528,14 +528,14 @@ void cast_copy_contents(LIST* lhs, const LIST* rhs, size_t recursions) {
528
528
  lcurr->key = rcurr->key;
529
529
 
530
530
  if (recursions == 0) {
531
- // contents is some kind of value
531
+ // contents is some kind of value
532
532
 
533
533
  lcurr->val = NM_ALLOC( LDType );
534
534
 
535
535
  *reinterpret_cast<LDType*>(lcurr->val) = *reinterpret_cast<RDType*>( rcurr->val );
536
536
 
537
537
  } else {
538
- // contents is a list
538
+ // contents is a list
539
539
 
540
540
  lcurr->val = NM_ALLOC( LIST );
541
541
 
@@ -547,10 +547,10 @@ void cast_copy_contents(LIST* lhs, const LIST* rhs, size_t recursions) {
547
547
  }
548
548
 
549
549
  if (rcurr->next) {
550
- lcurr->next = NM_ALLOC( NODE );
550
+ lcurr->next = NM_ALLOC( NODE );
551
551
 
552
552
  } else {
553
- lcurr->next = NULL;
553
+ lcurr->next = NULL;
554
554
  }
555
555
 
556
556
  lcurr = lcurr->next;
@@ -608,7 +608,7 @@ extern "C" {
608
608
  size_t key = curr->key;
609
609
 
610
610
  if (recursions == 0) { // content is some kind of value
611
- rb_hash_aset(h, INT2FIX(key), rubyobj_from_cval(curr->val, dtype).rval);
611
+ rb_hash_aset(h, INT2FIX(key), nm::rubyobj_from_cval(curr->val, dtype).rval);
612
612
  } else { // content is a list
613
613
  rb_hash_aset(h, INT2FIX(key), nm_list_copy_to_hash(reinterpret_cast<const LIST*>(curr->val), dtype, recursions-1, default_value));
614
614
  }
@@ -33,6 +33,7 @@
33
33
  * Standard Includes
34
34
  */
35
35
 
36
+ #include <ruby.h>
36
37
  #include <type_traits>
37
38
  #include <cstdlib>
38
39
 
@@ -69,9 +70,9 @@ namespace nm { namespace list {
69
70
  // Lifecycle //
70
71
  ///////////////
71
72
 
72
- LIST* create(void);
73
- void del(LIST* list, size_t recursions);
74
- void mark(LIST* list, size_t recursions);
73
+ LIST* create(void);
74
+ void del(LIST* list, size_t recursions);
75
+ void mark(LIST* list, size_t recursions);
75
76
 
76
77
  ///////////////
77
78
  // Accessors //
@@ -90,25 +91,25 @@ bool node_is_within_slice(NODE* n, size_t coord, size_t len);
90
91
 
91
92
  template <typename Type>
92
93
  inline NODE* insert_helper(LIST* list, NODE* node, size_t key, Type val) {
93
- Type* val_mem = NM_ALLOC(Type);
94
- *val_mem = val;
95
-
96
- if (node == NULL) {
97
- return insert(list, false, key, val_mem);
98
-
99
- } else {
100
- return insert_after(node, key, val_mem);
101
- }
94
+ Type* val_mem = NM_ALLOC(Type);
95
+ *val_mem = val;
96
+
97
+ if (node == NULL) {
98
+ return insert(list, false, key, val_mem);
99
+
100
+ } else {
101
+ return insert_after(node, key, val_mem);
102
+ }
102
103
  }
103
104
 
104
105
  template <typename Type>
105
106
  inline NODE* insert_helper(LIST* list, NODE* node, size_t key, Type* ptr) {
106
- if (node == NULL) {
107
- return insert(list, false, key, ptr);
108
-
109
- } else {
110
- return insert_after(node, key, ptr);
111
- }
107
+ if (node == NULL) {
108
+ return insert(list, false, key, ptr);
109
+
110
+ } else {
111
+ return insert_after(node, key, ptr);
112
+ }
112
113
  }
113
114
 
114
115
  ///////////
@@ -71,7 +71,7 @@ module NMatrix::BLAS
71
71
  def gemm(a, b, c = nil, alpha = 1.0, beta = 0.0, transpose_a = false, transpose_b = false, m = nil, n = nil, k = nil, lda = nil, ldb = nil, ldc = nil)
72
72
  raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless a.is_a?(NMatrix) and b.is_a?(NMatrix) and a.stype == :dense and b.stype == :dense
73
73
  raise(ArgumentError, 'Expected nil or dense NMatrix as third argument.') unless c.nil? or (c.is_a?(NMatrix) and c.stype == :dense)
74
- raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == b.dtype and (c ? a.dtype == c.dtype : true)
74
+ raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == b.dtype and (c ? a.dtype == c.dtype : true)
75
75
 
76
76
  # First, set m, n, and k, which depend on whether we're taking the
77
77
  # transpose of a and b.
@@ -93,7 +93,7 @@ module NMatrix::BLAS
93
93
  end
94
94
 
95
95
  n ||= transpose_b ? b.shape[0] : b.shape[1]
96
- c = NMatrix.new([m, n], dtype: a.dtype)
96
+ c = NMatrix.new([m, n], dtype: a.dtype)
97
97
  end
98
98
 
99
99
  # I think these are independent of whether or not a transpose occurs.
@@ -143,7 +143,7 @@ module NMatrix::BLAS
143
143
  def gemv(a, x, y = nil, alpha = 1.0, beta = 0.0, transpose_a = false, m = nil, n = nil, lda = nil, incx = nil, incy = nil)
144
144
  raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless a.is_a?(NMatrix) and x.is_a?(NMatrix) and a.stype == :dense and x.stype == :dense
145
145
  raise(ArgumentError, 'Expected nil or dense NMatrix as third argument.') unless y.nil? or (y.is_a?(NMatrix) and y.stype == :dense)
146
- raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == x.dtype and (y ? a.dtype == y.dtype : true)
146
+ raise(ArgumentError, 'NMatrix dtype mismatch.') unless a.dtype == x.dtype and (y ? a.dtype == y.dtype : true)
147
147
 
148
148
  m ||= transpose_a == :transpose ? a.shape[1] : a.shape[0]
149
149
  n ||= transpose_a == :transpose ? a.shape[0] : a.shape[1]
@@ -155,9 +155,9 @@ module NMatrix::BLAS
155
155
  y = NMatrix.new([m,1], dtype: a.dtype)
156
156
  end
157
157
 
158
- lda ||= a.shape[1]
159
- incx ||= 1
160
- incy ||= 1
158
+ lda ||= a.shape[1]
159
+ incx ||= 1
160
+ incy ||= 1
161
161
 
162
162
  ::NMatrix::BLAS.cblas_gemv(transpose_a, m, n, alpha, a, lda, x, incx, beta, y, incy)
163
163
 
@@ -188,7 +188,7 @@ module NMatrix::BLAS
188
188
  #
189
189
  def rot(x, y, c, s, incx = 1, incy = 1, n = nil, in_place=false)
190
190
  raise(ArgumentError, 'Expected dense NMatrices as first two arguments.') unless x.is_a?(NMatrix) and y.is_a?(NMatrix) and x.stype == :dense and y.stype == :dense
191
- raise(ArgumentError, 'NMatrix dtype mismatch.') unless x.dtype == y.dtype
191
+ raise(ArgumentError, 'NMatrix dtype mismatch.') unless x.dtype == y.dtype
192
192
  raise(ArgumentError, 'Need to supply n for non-standard incx, incy values') if n.nil? && incx != 1 && incx != -1 && incy != 1 && incy != -1
193
193
 
194
194
  n ||= [x.size/incx.abs, y.size/incy.abs].min
@@ -41,8 +41,8 @@ module NMatrix::IO::Matlab
41
41
  attr_reader :byte_order
42
42
 
43
43
  def initialize(stream = nil, byte_order = nil, content_or_bytes = nil)
44
- @stream = stream
45
- @byte_order = byte_order
44
+ @stream = stream
45
+ @byte_order = byte_order
46
46
 
47
47
  if content_or_bytes.is_a?(String)
48
48
  @content = content_or_bytes
@@ -118,9 +118,9 @@ module NMatrix::IO::Matlab
118
118
  # See also to_nm, which is responsible for NMatrix instantiation.
119
119
  def to_ruby
120
120
  case matlab_class
121
- when :mxSPARSE then return to_nm
122
- when :mxCELL then return self.cells.collect { |c| c.to_ruby }
123
- else return to_nm
121
+ when :mxSPARSE then return to_nm
122
+ when :mxCELL then return self.cells.collect { |c| c.to_ruby }
123
+ else return to_nm
124
124
  end
125
125
  end
126
126
 
@@ -262,8 +262,8 @@ module NMatrix::IO::Matlab
262
262
  #
263
263
  def to_nm(dtype = nil)
264
264
  # Hardest part is figuring out from_dtype, from_index_dtype, and dtype.
265
- dtype ||= guess_dtype_from_mdtype
266
- from_dtype = MatReader::MDTYPE_TO_DTYPE[self.real_part.tag.data_type]
265
+ dtype ||= guess_dtype_from_mdtype
266
+ from_dtype = MatReader::MDTYPE_TO_DTYPE[self.real_part.tag.data_type]
267
267
 
268
268
  # Create the same kind of matrix that MATLAB saved.
269
269
  case matlab_class
@@ -297,8 +297,8 @@ module NMatrix::IO::Matlab
297
297
  self.dimensions = dimensions_tag_data.data
298
298
 
299
299
  begin
300
- name_tag_data = packedio.read([Element, options])
301
- self.matlab_name = name_tag_data.data.is_a?(Array) ? name_tag_data.data.collect { |i| i.chr }.join('') : name_tag_data.data.chr
300
+ name_tag_data = packedio.read([Element, options])
301
+ self.matlab_name = name_tag_data.data.is_a?(Array) ? name_tag_data.data.collect { |i| i.chr }.join('') : name_tag_data.data.chr
302
302
 
303
303
  rescue ElementDataIOError => e
304
304
  STDERR.puts "ERROR: Failure while trying to read Matlab variable name: #{name_tag_data.inspect}"
@@ -325,8 +325,8 @@ module NMatrix::IO::Matlab
325
325
  self.row_index = packedio.read(read_opts)
326
326
  end
327
327
 
328
- self.real_part = packedio.read(read_opts)
329
- self.imaginary_part = packedio.read(read_opts) if self.complex
328
+ self.real_part = packedio.read(read_opts)
329
+ self.imaginary_part = packedio.read(read_opts) if self.complex
330
330
  end
331
331
  end
332
332
 
@@ -338,8 +338,8 @@ module NMatrix::IO::Matlab
338
338
 
339
339
  MDTYPE_UNPACK_ARGS =
340
340
  MatReader::MDTYPE_UNPACK_ARGS.merge({
341
- :miCOMPRESSED => [Compressed, {}],
342
- :miMATRIX => [MatrixData, {}]
341
+ :miCOMPRESSED => [Compressed, {}],
342
+ :miMATRIX => [MatrixData, {}]
343
343
  })
344
344
 
345
345
  FIRST_TAG_FIELD_POS = 128
@@ -407,26 +407,26 @@ module NMatrix::IO::Matlab
407
407
 
408
408
  include Packable
409
409
 
410
- BYTE_ORDER_LENGTH = 2
411
- DESC_LENGTH = 116
412
- DATA_OFFSET_LENGTH = 8
413
- VERSION_LENGTH = 2
414
- BYTE_ORDER_POS = 126
410
+ BYTE_ORDER_LENGTH = 2
411
+ DESC_LENGTH = 116
412
+ DATA_OFFSET_LENGTH = 8
413
+ VERSION_LENGTH = 2
414
+ BYTE_ORDER_POS = 126
415
415
 
416
416
  # TODO: TEST WRITE.
417
417
  def write_packed(packedio, options)
418
- packedio << [desc, {:bytes => DESC_LENGTH }] <<
419
- [data_offset, {:bytes => DATA_OFFSET_LENGTH }] <<
420
- [version, {:bytes => VERSION_LENGTH }] <<
421
- [byte_order, {:bytes => BYTE_ORDER_LENGTH }]
418
+ packedio << [desc, {:bytes => DESC_LENGTH }] <<
419
+ [data_offset, {:bytes => DATA_OFFSET_LENGTH }] <<
420
+ [version, {:bytes => VERSION_LENGTH }] <<
421
+ [byte_order, {:bytes => BYTE_ORDER_LENGTH }]
422
422
  end
423
423
 
424
424
  def read_packed(packedio, options)
425
425
  self.desc, self.data_offset, self.version, self.endian = packedio >>
426
- [String, {:bytes => DESC_LENGTH }] >>
427
- [String, {:bytes => DATA_OFFSET_LENGTH }] >>
428
- [Integer, {:bytes => VERSION_LENGTH, :endian => options[:endian] }] >>
429
- [String, {:bytes => 2 }]
426
+ [String, {:bytes => DESC_LENGTH }] >>
427
+ [String, {:bytes => DATA_OFFSET_LENGTH }] >>
428
+ [Integer, {:bytes => VERSION_LENGTH, :endian => options[:endian] }] >>
429
+ [String, {:bytes => 2 }]
430
430
 
431
431
  self.desc.strip!
432
432
  self.data_offset.strip!
@@ -466,8 +466,8 @@ module NMatrix::IO::Matlab
466
466
  # Small data element format
467
467
  raise IOError, 'Small data element format indicated, but length is more than 4 bytes!' if upper > 4
468
468
 
469
- self.bytes = upper
470
- self.raw_data_type = lower
469
+ self.bytes = upper
470
+ self.raw_data_type = lower
471
471
 
472
472
  else
473
473
  self.bytes = packedio.read([Integer, BYTES_OPTS.merge(options)])
@@ -560,8 +560,8 @@ module NMatrix::IO::Matlab
560
560
  def read_packed(packedio, options)
561
561
  raise(ArgumentError, 'Missing mandatory option :endian.') unless options.has_key?(:endian)
562
562
 
563
- self.tag = packedio.read([Tag, {:endian => options[:endian] }])
564
- self.data = packedio.read([String, {:endian => options[:endian], :bytes => tag.bytes }])
563
+ self.tag = packedio.read([Tag, {:endian => options[:endian] }])
564
+ self.data = packedio.read([String, {:endian => options[:endian], :bytes => tag.bytes }])
565
565
 
566
566
  begin
567
567
  ignore_padding(packedio, (tag.bytes + tag.size) % 8) unless [:miMATRIX, :miCOMPRESSED].include?(tag.data_type)
@@ -258,36 +258,92 @@ class NMatrix
258
258
 
259
259
  # Solve the matrix equation AX = B, where A is +self+, B is the first
260
260
  # argument, and X is returned. A must be an nxn square matrix, while B must be
261
- # nxm. Only works with dense
262
- # matrices and non-integer, non-object data types.
261
+ # nxm. Only works with dense matrices and non-integer, non-object data types.
263
262
  #
263
+ # == Arguments
264
+ #
265
+ # * +b+ - the right hand side
266
+ #
267
+ # == Options
268
+ #
269
+ # * +form+ - Signifies the form of the matrix A in the linear system AX=B.
270
+ # If not set then it defaults to +:general+, which uses an LU solver.
271
+ # Other possible values are +:lower_tri+, +:upper_tri+ and +:pos_def+ (alternatively,
272
+ # non-abbreviated symbols +:lower_triangular+, +:upper_triangular+,
273
+ # and +:positive_definite+ can be used).
274
+ # If +:lower_tri+ or +:upper_tri+ is set, then a specialized linear solver for linear
275
+ # systems AX=B with a lower or upper triangular matrix A is used. If +:pos_def+ is chosen,
276
+ # then the linear system is solved via the Cholesky factorization.
277
+ # Note that when +:lower_tri+ or +:upper_tri+ is used, then the algorithm just assumes that
278
+ # all entries in the upper/lower triangle of the matrix (respectively) are zeros without checking (which
279
+ # can be useful in certain applications).
280
+ #
281
+ #
264
282
  # == Usage
265
283
  #
266
284
  # a = NMatrix.new [2,2], [3,1,1,2], dtype: dtype
267
285
  # b = NMatrix.new [2,1], [9,8], dtype: dtype
268
286
  # a.solve(b)
269
- def solve b
287
+ #
288
+ # # solve an upper triangular linear system more efficiently:
289
+ # require 'benchmark'
290
+ # require 'nmatrix/lapacke'
291
+ # rand_mat = NMatrix.random([10000, 10000], dtype: :float64)
292
+ # a = rand_mat.triu
293
+ # b = NMatrix.random([10000, 10], dtype: :float64)
294
+ # Benchmark.bm(10) do |bm|
295
+ # bm.report('general') { a.solve(b) }
296
+ # bm.report('upper_tri') { a.solve(b, form: :upper_tri) }
297
+ # end
298
+ # # user system total real
299
+ # # general 73.170000 0.670000 73.840000 ( 73.810086)
300
+ # # upper_tri 0.180000 0.000000 0.180000 ( 0.182491)
301
+ #
302
+ def solve(b, opts = {})
270
303
  raise(ShapeError, "Must be called on square matrix") unless self.dim == 2 && self.shape[0] == self.shape[1]
271
304
  raise(ShapeError, "number of rows of b must equal number of cols of self") if
272
305
  self.shape[1] != b.shape[0]
273
- raise ArgumentError, "only works with dense matrices" if self.stype != :dense
274
- raise ArgumentError, "only works for non-integer, non-object dtypes" if
306
+ raise(ArgumentError, "only works with dense matrices") if self.stype != :dense
307
+ raise(ArgumentError, "only works for non-integer, non-object dtypes") if
275
308
  integer_dtype? or object_dtype? or b.integer_dtype? or b.object_dtype?
276
309
 
277
- x = b.clone
278
- clone = self.clone
279
- n = self.shape[0]
310
+ opts = { form: :general }.merge(opts)
311
+ x = b.clone
312
+ n = self.shape[0]
280
313
  nrhs = b.shape[1]
281
314
 
282
- ipiv = NMatrix::LAPACK.clapack_getrf(:row, n, n, clone, n)
283
- # When we call clapack_getrs with :row, actually only the first matrix
284
- # (i.e. clone) is interpreted as row-major, while the other matrix (x)
285
- # is interpreted as column-major. See here: http://math-atlas.sourceforge.net/faq.html#RowSolve
286
- # So we must transpose x before and after
287
- # calling it.
288
- x = x.transpose
289
- NMatrix::LAPACK.clapack_getrs(:row, :no_transpose, n, nrhs, clone, n, ipiv, x, n)
290
- x.transpose
315
+ case opts[:form]
316
+ when :general
317
+ clone = self.clone
318
+ ipiv = NMatrix::LAPACK.clapack_getrf(:row, n, n, clone, n)
319
+ # When we call clapack_getrs with :row, actually only the first matrix
320
+ # (i.e. clone) is interpreted as row-major, while the other matrix (x)
321
+ # is interpreted as column-major. See here: http://math-atlas.sourceforge.net/faq.html#RowSolve
322
+ # So we must transpose x before and after
323
+ # calling it.
324
+ x = x.transpose
325
+ NMatrix::LAPACK.clapack_getrs(:row, :no_transpose, n, nrhs, clone, n, ipiv, x, n)
326
+ x.transpose
327
+ when :upper_tri, :upper_triangular
328
+ raise(ArgumentError, "upper triangular solver does not work with complex dtypes") if
329
+ complex_dtype? or b.complex_dtype?
330
+ # this is the correct function call; see https://github.com/SciRuby/nmatrix/issues/374
331
+ NMatrix::BLAS::cblas_trsm(:row, :left, :upper, false, :nounit, n, nrhs, 1.0, self, n, x, nrhs)
332
+ x
333
+ when :lower_tri, :lower_triangular
334
+ raise(ArgumentError, "lower triangular solver does not work with complex dtypes") if
335
+ complex_dtype? or b.complex_dtype?
336
+ # this is a workaround; see https://github.com/SciRuby/nmatrix/issues/422
337
+ x = x.transpose
338
+ NMatrix::BLAS::cblas_trsm(:row, :right, :lower, :transpose, :nounit, nrhs, n, 1.0, self, n, x, n)
339
+ x.transpose
340
+ when :pos_def, :positive_definite
341
+ u, l = self.factorize_cholesky
342
+ z = l.solve(b, form: :lower_tri)
343
+ u.solve(z, form: :upper_tri)
344
+ else
345
+ raise(ArgumentError, "#{opts[:form]} is not a valid form option")
346
+ end
291
347
  end
292
348
 
293
349
  #