nmatrix 0.1.0.rc5 → 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (44):
  1. checksums.yaml +4 -4
  2. data/.travis.yml +0 -1
  3. data/Gemfile +0 -2
  4. data/History.txt +39 -4
  5. data/LICENSE.txt +3 -1
  6. data/Manifest.txt +2 -0
  7. data/README.rdoc +6 -14
  8. data/Rakefile +4 -1
  9. data/ext/nmatrix/data/data.cpp +1 -1
  10. data/ext/nmatrix/data/data.h +2 -1
  11. data/ext/nmatrix/data/rational.h +230 -226
  12. data/ext/nmatrix/extconf.rb +7 -4
  13. data/ext/nmatrix/math.cpp +259 -172
  14. data/ext/nmatrix/math/getri.h +2 -2
  15. data/ext/nmatrix/math/math.h +1 -1
  16. data/ext/nmatrix/ruby_constants.cpp +0 -1
  17. data/ext/nmatrix/ruby_nmatrix.c +55 -32
  18. data/ext/nmatrix/storage/dense/dense.cpp +1 -0
  19. data/ext/nmatrix/storage/yale/yale.cpp +12 -14
  20. data/ext/nmatrix/ttable_helper.rb +0 -1
  21. data/lib/nmatrix.rb +5 -0
  22. data/lib/nmatrix/homogeneous.rb +98 -0
  23. data/lib/nmatrix/io/fortran_format.rb +135 -0
  24. data/lib/nmatrix/io/harwell_boeing.rb +220 -0
  25. data/lib/nmatrix/io/market.rb +18 -8
  26. data/lib/nmatrix/io/mat5_reader.rb +16 -111
  27. data/lib/nmatrix/io/mat_reader.rb +3 -5
  28. data/lib/nmatrix/io/point_cloud.rb +27 -28
  29. data/lib/nmatrix/lapack.rb +3 -1
  30. data/lib/nmatrix/math.rb +112 -43
  31. data/lib/nmatrix/monkeys.rb +67 -11
  32. data/lib/nmatrix/nmatrix.rb +56 -33
  33. data/lib/nmatrix/rspec.rb +2 -2
  34. data/lib/nmatrix/shortcuts.rb +42 -25
  35. data/lib/nmatrix/version.rb +4 -4
  36. data/nmatrix.gemspec +4 -3
  37. data/spec/03_nmatrix_monkeys_spec.rb +72 -0
  38. data/spec/blas_spec.rb +4 -0
  39. data/spec/homogeneous_spec.rb +12 -4
  40. data/spec/io/fortran_format_spec.rb +88 -0
  41. data/spec/io/harwell_boeing_spec.rb +98 -0
  42. data/spec/io/test.rua +9 -0
  43. data/spec/math_spec.rb +51 -9
  44. metadata +38 -9
@@ -68,7 +68,7 @@ inline int getri(const enum CBLAS_ORDER order, const int n, DType* a, const int
68
68
  return 0;
69
69
  }
70
70
 
71
- #ifdef HAVE_CLAPACK_H
71
+ #if defined (HAVE_CLAPACK_H) || defined (HAVE_ATLAS_CLAPACK_H)
72
72
  template <>
73
73
  inline int getri(const enum CBLAS_ORDER order, const int n, float* a, const int lda, const int* ipiv) {
74
74
  return clapack_sgetri(order, n, a, lda, ipiv);
@@ -105,4 +105,4 @@ inline int clapack_getri(const enum CBLAS_ORDER order, const int n, void* a, con
105
105
 
106
106
  } } // end nm::math
107
107
 
108
- #endif // GETRI_H
108
+ #endif // GETRI_H
@@ -104,10 +104,10 @@ extern "C" {
104
104
  * C accessors.
105
105
  */
106
106
  void nm_math_det_exact(const int M, const void* elements, const int lda, nm::dtype_t dtype, void* result);
107
+ void nm_math_inverse(const int M, void* A_elements, nm::dtype_t dtype);
107
108
  void nm_math_inverse_exact(const int M, const void* A_elements, const int lda, void* B_elements, const int ldb, nm::dtype_t dtype);
108
109
  void nm_math_transpose_generic(const size_t M, const size_t N, const void* A, const int lda, void* B, const int ldb, size_t element_size);
109
110
  void nm_math_init_blas(void);
110
-
111
111
  }
112
112
 
113
113
 
@@ -148,7 +148,6 @@ void nm_init_ruby_constants(void) {
148
148
  nm_rb_column = rb_intern("column");
149
149
  nm_rb_row = rb_intern("row");
150
150
 
151
- //Added by Ryan
152
151
  nm_rb_both = rb_intern("both");
153
152
  nm_rb_none = rb_intern("none");
154
153
  }
@@ -134,6 +134,7 @@ DECL_UNARY_RUBY_ACCESSOR(gamma)
134
134
  DECL_UNARY_RUBY_ACCESSOR(negate)
135
135
  DECL_UNARY_RUBY_ACCESSOR(floor)
136
136
  DECL_UNARY_RUBY_ACCESSOR(ceil)
137
+ DECL_UNARY_RUBY_ACCESSOR(round)
137
138
  DECL_NONCOM_ELEMENTWISE_RUBY_ACCESSOR(atan2)
138
139
  DECL_NONCOM_ELEMENTWISE_RUBY_ACCESSOR(ldexp)
139
140
  DECL_NONCOM_ELEMENTWISE_RUBY_ACCESSOR(hypot)
@@ -154,7 +155,8 @@ static VALUE matrix_multiply_scalar(NMATRIX* left, VALUE scalar);
154
155
  static VALUE matrix_multiply(NMATRIX* left, NMATRIX* right);
155
156
  static VALUE nm_multiply(VALUE left_v, VALUE right_v);
156
157
  static VALUE nm_det_exact(VALUE self);
157
- static VALUE nm_inverse_exact(VALUE self, VALUE inverse);
158
+ static VALUE nm_inverse(VALUE self, VALUE inverse, VALUE bang);
159
+ static VALUE nm_inverse_exact(VALUE self, VALUE inverse, VALUE lda, VALUE ldb);
158
160
  static VALUE nm_complex_conjugate_bang(VALUE self);
159
161
  static VALUE nm_complex_conjugate(VALUE self);
160
162
  static VALUE nm_reshape_bang(VALUE self, VALUE arg);
@@ -208,6 +210,7 @@ void Init_nmatrix() {
208
210
  nm_eNotInvertibleError = rb_define_class("NotInvertibleError", rb_eStandardError);
209
211
 
210
212
  /*
213
+ * :nodoc:
211
214
  * Class that holds values in use by the C code.
212
215
  */
213
216
  cNMatrix_GC_holder = rb_define_class("NMGCHolder", rb_cObject);
@@ -260,7 +263,8 @@ void Init_nmatrix() {
260
263
  rb_define_method(cNMatrix, "supershape", (METHOD)nm_supershape, 0);
261
264
  rb_define_method(cNMatrix, "offset", (METHOD)nm_offset, 0);
262
265
  rb_define_method(cNMatrix, "det_exact", (METHOD)nm_det_exact, 0);
263
- rb_define_protected_method(cNMatrix, "__inverse_exact__", (METHOD)nm_inverse_exact, 1);
266
+ rb_define_protected_method(cNMatrix, "__inverse__", (METHOD)nm_inverse, 2);
267
+ rb_define_protected_method(cNMatrix, "__inverse_exact__", (METHOD)nm_inverse_exact, 3);
264
268
  rb_define_method(cNMatrix, "complex_conjugate!", (METHOD)nm_complex_conjugate_bang, 0);
265
269
  rb_define_method(cNMatrix, "complex_conjugate", (METHOD)nm_complex_conjugate, 0);
266
270
  rb_define_protected_method(cNMatrix, "reshape_bang", (METHOD)nm_reshape_bang, 1);
@@ -316,6 +320,8 @@ void Init_nmatrix() {
316
320
  rb_define_method(cNMatrix, "-@", (METHOD)nm_unary_negate,0);
317
321
  rb_define_method(cNMatrix, "floor", (METHOD)nm_unary_floor, 0);
318
322
  rb_define_method(cNMatrix, "ceil", (METHOD)nm_unary_ceil, 0);
323
+ rb_define_method(cNMatrix, "round", (METHOD)nm_unary_round, 0);
324
+
319
325
 
320
326
  rb_define_method(cNMatrix, "=~", (METHOD)nm_ew_eqeq, 1);
321
327
  rb_define_method(cNMatrix, "!~", (METHOD)nm_ew_neq, 1);
@@ -445,11 +451,11 @@ static VALUE nm_capacity(VALUE self) {
445
451
  break;
446
452
 
447
453
  default:
448
- NM_CONSERVATIVE(nm_unregister_value(self));
454
+ NM_CONSERVATIVE(nm_unregister_value(&self));
449
455
  rb_raise(nm_eStorageTypeError, "unrecognized stype in nm_capacity()");
450
456
  }
451
457
 
452
- NM_CONSERVATIVE(nm_unregister_value(self));
458
+ NM_CONSERVATIVE(nm_unregister_value(&self));
453
459
  return cap;
454
460
  }
455
461
 
@@ -740,11 +746,11 @@ static VALUE nm_each_with_indices(VALUE nmatrix) {
740
746
  to_return = nm_list_each_with_indices(nmatrix, false);
741
747
  break;
742
748
  default:
743
- NM_CONSERVATIVE(nm_unregister_value(nmatrix));
749
+ NM_CONSERVATIVE(nm_unregister_value(&nmatrix));
744
750
  rb_raise(nm_eDataTypeError, "Not a proper storage type");
745
751
  }
746
752
 
747
- NM_CONSERVATIVE(nm_unregister_value(nmatrix));
753
+ NM_CONSERVATIVE(nm_unregister_value(&nmatrix));
748
754
  return to_return;
749
755
  }
750
756
 
@@ -771,11 +777,11 @@ static VALUE nm_each_stored_with_indices(VALUE nmatrix) {
771
777
  to_return = nm_list_each_with_indices(nmatrix, true);
772
778
  break;
773
779
  default:
774
- NM_CONSERVATIVE(nm_unregister_value(nmatrix));
780
+ NM_CONSERVATIVE(nm_unregister_value(&nmatrix));
775
781
  rb_raise(nm_eDataTypeError, "Not a proper storage type");
776
782
  }
777
783
 
778
- NM_CONSERVATIVE(nm_unregister_value(nmatrix));
784
+ NM_CONSERVATIVE(nm_unregister_value(&nmatrix));
779
785
  return to_return;
780
786
  }
781
787
 
@@ -803,11 +809,11 @@ static VALUE nm_map_stored(VALUE nmatrix) {
803
809
  to_return = nm_list_map_stored(nmatrix, Qnil);
804
810
  break;
805
811
  default:
806
- NM_CONSERVATIVE(nm_unregister_value(nmatrix));
812
+ NM_CONSERVATIVE(nm_unregister_value(&nmatrix));
807
813
  rb_raise(nm_eDataTypeError, "Not a proper storage type");
808
814
  }
809
815
 
810
- NM_CONSERVATIVE(nm_unregister_value(nmatrix));
816
+ NM_CONSERVATIVE(nm_unregister_value(&nmatrix));
811
817
  return to_return;
812
818
  }
813
819
 
@@ -833,11 +839,11 @@ static VALUE nm_each_ordered_stored_with_indices(VALUE nmatrix) {
833
839
  to_return = nm_list_each_with_indices(nmatrix, true);
834
840
  break;
835
841
  default:
836
- NM_CONSERVATIVE(nm_unregister_value(nmatrix));
842
+ NM_CONSERVATIVE(nm_unregister_value(&nmatrix));
837
843
  rb_raise(nm_eDataTypeError, "Not a proper storage type");
838
844
  }
839
845
 
840
- NM_CONSERVATIVE(nm_unregister_value(nmatrix));
846
+ NM_CONSERVATIVE(nm_unregister_value(&nmatrix));
841
847
  return to_return;
842
848
  }
843
849
 
@@ -932,6 +938,7 @@ DEF_UNARY_RUBY_ACCESSOR(GAMMA, gamma)
932
938
  DEF_UNARY_RUBY_ACCESSOR(NEGATE, negate)
933
939
  DEF_UNARY_RUBY_ACCESSOR(FLOOR, floor)
934
940
  DEF_UNARY_RUBY_ACCESSOR(CEIL, ceil)
941
+ DEF_UNARY_RUBY_ACCESSOR(ROUND, round)
935
942
 
936
943
  DEF_NONCOM_ELEMENTWISE_RUBY_ACCESSOR(ATAN2, atan2)
937
944
  DEF_NONCOM_ELEMENTWISE_RUBY_ACCESSOR(LDEXP, ldexp)
@@ -1109,24 +1116,8 @@ static VALUE nm_init_new_version(int argc, VALUE* argv, VALUE self) {
1109
1116
  VALUE shape_ary, initial_ary, hash;
1110
1117
  //VALUE shape_ary, default_val, capacity, initial_ary, dtype_sym, stype_sym;
1111
1118
  // Mandatory args: shape, dtype, stype
1112
- // FIXME: This is the one line of code standing between Ruby 1.9.2 and 1.9.3.
1113
- #ifndef OLD_RB_SCAN_ARGS // Ruby 1.9.3 and higher
1114
1119
  rb_scan_args(argc, argv, "11:", &shape_ary, &initial_ary, &hash); // &stype_sym, &dtype_sym, &default_val, &capacity);
1115
- #else // Ruby 1.9.2 and lower
1116
- if (argc == 3)
1117
- rb_scan_args(argc, argv, "12", &shape_ary, &initial_ary, &hash);
1118
- else if (argc == 2) {
1119
- VALUE unknown_arg;
1120
- rb_scan_args(argc, argv, "11", &shape_ary, &unknown_arg);
1121
- if (!NIL_P(unknown_arg) && TYPE(unknown_arg) == T_HASH) {
1122
- hash = unknown_arg;
1123
- initial_ary = Qnil;
1124
- } else {
1125
- initial_ary = unknown_arg;
1126
- hash = Qnil;
1127
- }
1128
- }
1129
- #endif
1120
+
1130
1121
  NM_CONSERVATIVE(nm_register_value(&shape_ary));
1131
1122
  NM_CONSERVATIVE(nm_register_value(&initial_ary));
1132
1123
  NM_CONSERVATIVE(nm_register_value(&hash));
@@ -2947,12 +2938,43 @@ static VALUE matrix_multiply(NMATRIX* left, NMATRIX* right) {
2947
2938
  }
2948
2939
 
2949
2940
 
2941
+ /*
2942
+ * Calculate the inverse of a matrix with in-place Gauss-Jordan elimination.
2943
+ * Inverse will fail if the largest element in any column in zero.
2944
+ *
2945
+ * LAPACK free.
2946
+ */
2947
+ static VALUE nm_inverse(VALUE self, VALUE inverse, VALUE bang) {
2948
+
2949
+ if (NM_STYPE(self) != nm::DENSE_STORE) {
2950
+ rb_raise(rb_eNotImpError, "needs exact determinant implementation for this matrix stype");
2951
+ return Qnil;
2952
+ }
2953
+
2954
+ if (NM_DIM(self) != 2 || NM_SHAPE0(self) != NM_SHAPE1(self)) {
2955
+ rb_raise(nm_eShapeError, "matrices must be square to have an inverse defined");
2956
+ return Qnil;
2957
+ }
2958
+
2959
+ if (bang == Qtrue) {
2960
+ nm_math_inverse(NM_SHAPE0(self), NM_STORAGE_DENSE(self)->elements,
2961
+ NM_DTYPE(self));
2962
+
2963
+ return self;
2964
+ }
2965
+
2966
+ nm_math_inverse(NM_SHAPE0(inverse), NM_STORAGE_DENSE(inverse)->elements,
2967
+ NM_DTYPE(inverse));
2968
+
2969
+ return inverse;
2970
+ }
2971
+
2950
2972
  /*
2951
2973
  * Calculate the exact inverse of a 2x2 or 3x3 matrix.
2952
2974
  *
2953
2975
  * Does not test for invertibility!
2954
2976
  */
2955
- static VALUE nm_inverse_exact(VALUE self, VALUE inverse) {
2977
+ static VALUE nm_inverse_exact(VALUE self, VALUE inverse, VALUE lda, VALUE ldb) {
2956
2978
 
2957
2979
  if (NM_STYPE(self) != nm::DENSE_STORE) {
2958
2980
  rb_raise(rb_eNotImpError, "needs exact determinant implementation for this matrix stype");
@@ -2964,8 +2986,9 @@ static VALUE nm_inverse_exact(VALUE self, VALUE inverse) {
2964
2986
  return Qnil;
2965
2987
  }
2966
2988
 
2967
- // Calculate the exact inverse.
2968
- nm_math_inverse_exact(NM_SHAPE0(self), NM_STORAGE_DENSE(self)->elements, NM_SHAPE0(self), NM_STORAGE_DENSE(inverse)->elements, NM_SHAPE0(inverse), NM_DTYPE(self));
2989
+ nm_math_inverse_exact(NM_SHAPE0(self),
2990
+ NM_STORAGE_DENSE(self)->elements, FIX2INT(lda),
2991
+ NM_STORAGE_DENSE(inverse)->elements, FIX2INT(ldb), NM_DTYPE(self));
2969
2992
 
2970
2993
  return inverse;
2971
2994
  }
@@ -979,6 +979,7 @@ bool eqeq(const DENSE_STORAGE* left, const DENSE_STORAGE* right) {
979
979
  if (left->dim != right->dim) {
980
980
  nm_dense_storage_unregister(right);
981
981
  nm_dense_storage_unregister(left);
982
+
982
983
  return false;
983
984
  }
984
985
 
@@ -549,7 +549,7 @@ static char vector_insert_resize(YALE_STORAGE* s, size_t current_size, size_t po
549
549
  NM_FREE(s->ija);
550
550
  nm_yale_storage_unregister(s);
551
551
  NM_FREE(s->a);
552
-
552
+
553
553
  if (s->dtype == nm::RUBYOBJ)
554
554
  nm_yale_storage_unregister_a(new_a, new_capacity);
555
555
 
@@ -943,7 +943,7 @@ static VALUE map_stored(VALUE self) {
943
943
  NM_CONSERVATIVE(nm_register_value(&self));
944
944
  YALE_STORAGE* s = NM_STORAGE_YALE(self);
945
945
  YaleStorage<D> y(s);
946
-
946
+
947
947
  RETURN_SIZED_ENUMERATOR_PRE
948
948
  NM_CONSERVATIVE(nm_unregister_value(&self));
949
949
  RETURN_SIZED_ENUMERATOR(self, 0, 0, nm_yale_stored_enumerator_length);
@@ -1014,7 +1014,7 @@ static VALUE stored_diagonal_each_with_indices(VALUE nm) {
1014
1014
  RETURN_SIZED_ENUMERATOR_PRE
1015
1015
  NM_CONSERVATIVE(nm_unregister_value(&nm));
1016
1016
  RETURN_SIZED_ENUMERATOR(nm, 0, 0, nm_yale_stored_diagonal_length); // FIXME: need diagonal length
1017
-
1017
+
1018
1018
  for (typename YaleStorage<DType>::const_stored_diagonal_iterator d = y.csdbegin(); d != y.csdend(); ++d) {
1019
1019
  rb_yield_values(3, ~d, d.rb_i(), d.rb_j());
1020
1020
  }
@@ -1106,10 +1106,8 @@ static bool is_pos_default_value(YALE_STORAGE* s, size_t apos) {
1106
1106
  return y.is_pos_default_value(apos);
1107
1107
  }
1108
1108
 
1109
-
1110
1109
  } // end of namespace nm::yale_storage
1111
1110
 
1112
-
1113
1111
  } // end of namespace nm.
1114
1112
 
1115
1113
  ///////////////////
@@ -1123,7 +1121,7 @@ extern "C" {
1123
1121
  void nm_init_yale_functions() {
1124
1122
  /*
1125
1123
  * This module stores methods that are useful for debugging Yale matrices,
1126
- * i.e. the ones with +:yale+ stype.
1124
+ * i.e. the ones with +:yale+ stype.
1127
1125
  */
1128
1126
  cNMatrix_YaleFunctions = rb_define_module_under(cNMatrix, "YaleFunctions");
1129
1127
 
@@ -1141,10 +1139,13 @@ void nm_init_yale_functions() {
1141
1139
 
1142
1140
  rb_define_method(cNMatrix_YaleFunctions, "yale_nd_row", (METHOD)nm_nd_row, -1);
1143
1141
 
1142
+ /* Document-const:
1143
+ * Defines the growth rate of the sparse NMatrix's size. Default is 1.5.
1144
+ */
1144
1145
  rb_define_const(cNMatrix_YaleFunctions, "YALE_GROWTH_CONSTANT", rb_float_new(nm::yale_storage::GROWTH_CONSTANT));
1145
1146
 
1146
1147
  // This is so the user can easily check the IType size, mostly for debugging.
1147
- size_t itype_size = sizeof(IType);
1148
+ size_t itype_size = sizeof(IType);
1148
1149
  VALUE itype_dtype;
1149
1150
  if (itype_size == sizeof(uint64_t)) {
1150
1151
  itype_dtype = ID2SYM(rb_intern("int64"));
@@ -1158,12 +1159,10 @@ void nm_init_yale_functions() {
1158
1159
  rb_define_const(cNMatrix, "INDEX_DTYPE", itype_dtype);
1159
1160
  }
1160
1161
 
1161
-
1162
1162
  /////////////////
1163
1163
  // C ACCESSORS //
1164
1164
  /////////////////
1165
1165
 
1166
-
1167
1166
  /* C interface for NMatrix#each_with_indices (Yale) */
1168
1167
  VALUE nm_yale_each_with_indices(VALUE nmatrix) {
1169
1168
  NAMED_DTYPE_TEMPLATE_TABLE(ttable, nm::yale_storage::each_with_indices, VALUE, VALUE)
@@ -1555,7 +1554,7 @@ static bool is_pos_default_value(YALE_STORAGE* s, size_t apos) {
1555
1554
  * Only checks the stored indices; does not care about matrix default value.
1556
1555
  */
1557
1556
  static VALUE nm_row_keys_intersection(VALUE m1, VALUE ii1, VALUE m2, VALUE ii2) {
1558
-
1557
+
1559
1558
  NM_CONSERVATIVE(nm_register_value(&m1));
1560
1559
  NM_CONSERVATIVE(nm_register_value(&m2));
1561
1560
 
@@ -1658,7 +1657,7 @@ static VALUE nm_a(int argc, VALUE* argv, VALUE self) {
1658
1657
  VALUE* vals = NM_ALLOCA_N(VALUE, size);
1659
1658
 
1660
1659
  nm_register_values(vals, size);
1661
-
1660
+
1662
1661
  if (NM_DTYPE(self) == nm::RUBYOBJ) {
1663
1662
  for (size_t i = 0; i < size; ++i) {
1664
1663
  vals[i] = reinterpret_cast<VALUE*>(s->a)[i];
@@ -1786,7 +1785,7 @@ static VALUE nm_ia(VALUE self) {
1786
1785
  vals[i] = INT2FIX(s->ija[i]);
1787
1786
  }
1788
1787
 
1789
- NM_CONSERVATIVE(nm_unregister_value(&self));
1788
+ NM_CONSERVATIVE(nm_unregister_value(&self));
1790
1789
 
1791
1790
  return rb_ary_new4(s->shape[0]+1, vals);
1792
1791
  }
@@ -1887,11 +1886,10 @@ static VALUE nm_ija(int argc, VALUE* argv, VALUE self) {
1887
1886
  static VALUE nm_nd_row(int argc, VALUE* argv, VALUE self) {
1888
1887
 
1889
1888
  NM_CONSERVATIVE(nm_register_value(&self));
1890
-
1891
1889
  if (NM_SRC(self) != NM_STORAGE(self)) {
1892
1890
  NM_CONSERVATIVE(nm_unregister_value(&self));
1893
1891
  rb_raise(rb_eNotImpError, "must be called on a real matrix and not a slice");
1894
- }
1892
+ }
1895
1893
 
1896
1894
  VALUE i_, as;
1897
1895
  rb_scan_args(argc, argv, "11", &i_, &as);
@@ -2,7 +2,6 @@
2
2
 
3
3
  # A helper file for generating and maintaining template tables.
4
4
 
5
-
6
5
  DTYPES = [
7
6
  :uint8_t,
8
7
  :int8_t,
@@ -34,6 +34,11 @@ else
34
34
  require "nmatrix.so"
35
35
  end
36
36
 
37
+ require 'nmatrix/io/mat_reader'
38
+ require 'nmatrix/io/mat5_reader'
39
+ require 'nmatrix/io/market'
40
+ require 'nmatrix/io/point_cloud'
41
+
37
42
  require 'nmatrix/nmatrix.rb'
38
43
  require 'nmatrix/version.rb'
39
44
  require 'nmatrix/blas.rb'
@@ -140,4 +140,102 @@ class NMatrix
140
140
  n
141
141
  end
142
142
  end
143
+
144
#
# call-seq:
#     quaternion -> NMatrix
#
# Find the quaternion for a 3D rotation matrix.
#
# Accepts either a 3x3 rotation matrix or a 4x4 homogeneous matrix. For a 4x4
# input, element [3,3] is used as the homogeneous weight +w+ (a 3x3 input
# behaves as if w == 1). When w != 1, the resulting quaternion is scaled by
# 1/sqrt(w), as in Shoemake's Qt_FromMatrix, for both the positive- and
# negative-trace cases.
#
# Code borrowed from: http://courses.cms.caltech.edu/cs171/quatut.pdf
#
# * *Returns* :
#   - A length-4 NMatrix [w, x, y, z] representing the corresponding
#     quaternion. Its dtype is :float32 for :float32 input, :float64 otherwise.
# * *Raises* :
#   - +ShapeError+ -> if the matrix is not square, or is not 3x3 or 4x4.
#
# Examples:
#
#     n.quaternion # => [1, 0, 0, 0]
#
def quaternion
  raise(ShapeError, "Expected square matrix") if self.shape[0] != self.shape[1]
  raise(ShapeError, "Expected 3x3 rotation (or 4x4 homogeneous) matrix") if self.shape[0] > 4 || self.shape[0] < 3

  q_dtype = self.dtype == :float32 ? :float32 : :float64
  q = NMatrix.new([4], dtype: q_dtype)

  # Homogeneous weight; a plain 3x3 rotation matrix acts as w == 1.
  self_w = self.shape[0] == 4 ? self[3,3] : 1.0

  rotation_trace = self[0,0] + self[1,1] + self[2,2]
  if rotation_trace >= 0
    root_of_homogeneous_trace = Math.sqrt(rotation_trace + self_w)
    q[0] = root_of_homogeneous_trace * 0.5
    s = 0.5 / root_of_homogeneous_trace
    q[1] = (self[2,1] - self[1,2]) * s
    q[2] = (self[0,2] - self[2,0]) * s
    q[3] = (self[1,0] - self[0,1]) * s
  else
    # Work from the largest diagonal element.
    h = 0
    h = 1 if self[1,1] > self[0,0]
    h = 2 if self[2,2] > self[h,h]

    # i, j, k are quaternion component slots; ii, jj, kk are the matching
    # matrix axes, taken in cyclic order starting from the largest diagonal.
    # Every call writes all four components of q (slots i, j, k and 0).
    case_macro = Proc.new do |i,j,k,ii,jj,kk|
      root = Math.sqrt( (self[ii,ii] - (self[jj,jj] + self[kk,kk])) + self_w)
      q[i] = root * 0.5
      factor = 0.5 / root
      q[j] = (self[ii,jj] + self[jj,ii]) * factor
      q[k] = (self[kk,ii] + self[ii,kk]) * factor
      q[0] = (self[kk,jj] - self[jj,kk]) * factor
    end

    case h
    when 0 then case_macro.call(1,2,3, 0,1,2)
    when 1 then case_macro.call(2,3,1, 1,2,0)
    when 2 then case_macro.call(3,1,2, 2,0,1)
    end
  end

  # Homogeneous rescale; applies regardless of which branch produced q.
  if self_w != 1
    s = 1.0 / Math.sqrt(self_w)
    q[0] *= s
    q[1] *= s
    q[2] *= s
    q[3] *= s
  end

  q
end
211
+
212
#
# call-seq:
#     angle_vector -> [angle, about_vector]
#
# Find the angle vector for a quaternion [w, x, y, z]. Assumes the quaternion
# has unit length.
#
# Source: http://www.euclideanspace.com/maths/geometry/rotations/conversions/quaternionToAngle/
#
# * *Returns* :
#   - An angle (in radians) describing the rotation about the +about_vector+.
#   - A length-3 NMatrix (+about_vector+) holding the rotation axis, with the
#     same dtype as the receiver. When sqrt(1 - w*w) is below 0.001 (angle near
#     zero), the x, y, z parts are returned unscaled to avoid dividing by
#     (nearly) zero.
# * *Raises* :
#   - +ShapeError+ -> if the first dimension of the receiver is not 4.
#   - +RuntimeError+ -> if the scalar part w lies outside [-1, 1].
#
# Examples:
#
#     q = NMatrix.new([4], [1.0, 0.0, 0.0, 0.0])
#     angle, axis = q.angle_vector # angle is 0.0; axis holds [0.0, 0.0, 0.0]
#
def angle_vector
  raise(ShapeError, "Expected length-4 vector or matrix (quaternion)") if self.shape[0] != 4
  # Math.acos is only defined on [-1, 1]; anything outside cannot be a unit quaternion.
  raise("Expected unit quaternion") if self[0].abs > 1

  xyz = NMatrix.new([3], dtype: self.dtype)

  angle = 2 * Math.acos(self[0])
  s = Math.sqrt(1.0 - self[0]*self[0])

  xyz[0..2] = self[1..3]
  xyz /= s if s >= 0.001 # avoid divide by zero
  return [angle, xyz]
end
143
241
  end