nmatrix-fftw 0.2.1 → 0.2.3

Sign up to get free protection for your applications and to get access to all the features.
@@ -81,10 +81,14 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
81
81
  // (row-major) trsm: left upper trans nonunit m=3 n=1 1/1 a 3 b 3
82
82
 
83
83
  if (m == 0 || n == 0) return; /* Quick return if possible. */
84
+
85
+ // Apply necessary offset
86
+ a -= 1 + lda;
87
+ b -= 1 + ldb;
84
88
 
85
89
  if (alpha == 0) { // Handle alpha == 0
86
- for (int j = 0; j < n; ++j) {
87
- for (int i = 0; i < m; ++i) {
90
+ for (int j = 1; j <= n; ++j) {
91
+ for (int i = 1; i <= m; ++i) {
88
92
  b[i + j * ldb] = 0;
89
93
  }
90
94
  }
@@ -96,37 +100,37 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
96
100
 
97
101
  /* Form B := alpha*inv( A )*B. */
98
102
  if (uplo == CblasUpper) {
99
- for (int j = 0; j < n; ++j) {
103
+ for (int j = 1; j <= n; ++j) {
100
104
  if (alpha != 1) {
101
- for (int i = 0; i < m; ++i) {
105
+ for (int i = 1; i <= m; ++i) {
102
106
  b[i + j * ldb] = alpha * b[i + j * ldb];
103
107
  }
104
108
  }
105
- for (int k = m-1; k >= 0; --k) {
109
+ for (int k = m; k >= 1; --k) {
106
110
  if (b[k + j * ldb] != 0) {
107
111
  if (diag == CblasNonUnit) {
108
112
  b[k + j * ldb] /= a[k + k * lda];
109
113
  }
110
114
 
111
- for (int i = 0; i < k-1; ++i) {
115
+ for (int i = 1; i <= k-1; ++i) {
112
116
  b[i + j * ldb] -= b[k + j * ldb] * a[i + k * lda];
113
117
  }
114
118
  }
115
119
  }
116
120
  }
117
121
  } else {
118
- for (int j = 0; j < n; ++j) {
122
+ for (int j = 1; j <= n; ++j) {
119
123
  if (alpha != 1) {
120
- for (int i = 0; i < m; ++i) {
124
+ for (int i = 1; i <= m; ++i) {
121
125
  b[i + j * ldb] = alpha * b[i + j * ldb];
122
126
  }
123
127
  }
124
- for (int k = 0; k < m; ++k) {
128
+ for (int k = 1; k <= m; ++k) {
125
129
  if (b[k + j * ldb] != 0.) {
126
130
  if (diag == CblasNonUnit) {
127
131
  b[k + j * ldb] /= a[k + k * lda];
128
132
  }
129
- for (int i = k+1; i < m; ++i) {
133
+ for (int i = k+1; i <= m; ++i) {
130
134
  b[i + j * ldb] -= b[k + j * ldb] * a[i + k * lda];
131
135
  }
132
136
  }
@@ -137,10 +141,10 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
137
141
 
138
142
  /* Form B := alpha*inv( A**T )*B. */
139
143
  if (uplo == CblasUpper) {
140
- for (int j = 0; j < n; ++j) {
141
- for (int i = 0; i < m; ++i) {
144
+ for (int j = 1; j <= n; ++j) {
145
+ for (int i = 1; i <= m; ++i) {
142
146
  DType temp = alpha * b[i + j * ldb];
143
- for (int k = 0; k < i; ++k) { // limit was i-1. Lots of similar bugs in this code, probably.
147
+ for (int k = 1; k <= i-1; ++k) { // limit was i-1. Lots of similar bugs in this code, probably.
144
148
  temp -= a[k + i * lda] * b[k + j * ldb];
145
149
  }
146
150
  if (diag == CblasNonUnit) {
@@ -150,10 +154,10 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
150
154
  }
151
155
  }
152
156
  } else {
153
- for (int j = 0; j < n; ++j) {
154
- for (int i = m-1; i >= 0; --i) {
157
+ for (int j = 1; j <= n; ++j) {
158
+ for (int i = m; i >= 1; --i) {
155
159
  DType temp= alpha * b[i + j * ldb];
156
- for (int k = i+1; k < m; ++k) {
160
+ for (int k = i+1; k <= m; ++k) {
157
161
  temp -= a[k + i * lda] * b[k + j * ldb];
158
162
  }
159
163
  if (diag == CblasNonUnit) {
@@ -171,37 +175,37 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
171
175
  /* Form B := alpha*B*inv( A ). */
172
176
 
173
177
  if (uplo == CblasUpper) {
174
- for (int j = 0; j < n; ++j) {
178
+ for (int j = 1; j <= n; ++j) {
175
179
  if (alpha != 1) {
176
- for (int i = 0; i < m; ++i) {
180
+ for (int i = 1; i <= m; ++i) {
177
181
  b[i + j * ldb] = alpha * b[i + j * ldb];
178
182
  }
179
183
  }
180
- for (int k = 0; k < j-1; ++k) {
184
+ for (int k = 1; k <= j-1; ++k) {
181
185
  if (a[k + j * lda] != 0) {
182
- for (int i = 0; i < m; ++i) {
186
+ for (int i = 1; i <= m; ++i) {
183
187
  b[i + j * ldb] -= a[k + j * lda] * b[i + k * ldb];
184
188
  }
185
189
  }
186
190
  }
187
191
  if (diag == CblasNonUnit) {
188
192
  DType temp = 1 / a[j + j * lda];
189
- for (int i = 0; i < m; ++i) {
193
+ for (int i = 1; i <= m; ++i) {
190
194
  b[i + j * ldb] = temp * b[i + j * ldb];
191
195
  }
192
196
  }
193
197
  }
194
198
  } else {
195
- for (int j = n-1; j >= 0; --j) {
199
+ for (int j = n; j >= 1; --j) {
196
200
  if (alpha != 1) {
197
- for (int i = 0; i < m; ++i) {
201
+ for (int i = 1; i <= m; ++i) {
198
202
  b[i + j * ldb] = alpha * b[i + j * ldb];
199
203
  }
200
204
  }
201
205
 
202
- for (int k = j+1; k < n; ++k) {
206
+ for (int k = j+1; k <= n; ++k) {
203
207
  if (a[k + j * lda] != 0.) {
204
- for (int i = 0; i < m; ++i) {
208
+ for (int i = 1; i <= m; ++i) {
205
209
  b[i + j * ldb] -= a[k + j * lda] * b[i + k * ldb];
206
210
  }
207
211
  }
@@ -209,7 +213,7 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
209
213
  if (diag == CblasNonUnit) {
210
214
  DType temp = 1 / a[j + j * lda];
211
215
 
212
- for (int i = 0; i < m; ++i) {
216
+ for (int i = 1; i <= m; ++i) {
213
217
  b[i + j * ldb] = temp * b[i + j * ldb];
214
218
  }
215
219
  }
@@ -220,45 +224,45 @@ inline void trsm_nothrow(const enum CBLAS_SIDE side, const enum CBLAS_UPLO uplo,
220
224
  /* Form B := alpha*B*inv( A**T ). */
221
225
 
222
226
  if (uplo == CblasUpper) {
223
- for (int k = n-1; k >= 0; --k) {
227
+ for (int k = n; k >= 1; --k) {
224
228
  if (diag == CblasNonUnit) {
225
229
  DType temp= 1 / a[k + k * lda];
226
- for (int i = 0; i < m; ++i) {
230
+ for (int i = 1; i <= m; ++i) {
227
231
  b[i + k * ldb] = temp * b[i + k * ldb];
228
232
  }
229
233
  }
230
- for (int j = 0; j < k-1; ++j) {
234
+ for (int j = 1; j <= k-1; ++j) {
231
235
  if (a[j + k * lda] != 0.) {
232
236
  DType temp= a[j + k * lda];
233
- for (int i = 0; i < m; ++i) {
237
+ for (int i = 1; i <= m; ++i) {
234
238
  b[i + j * ldb] -= temp * b[i + k * ldb];
235
239
  }
236
240
  }
237
241
  }
238
242
  if (alpha != 1) {
239
- for (int i = 0; i < m; ++i) {
243
+ for (int i = 1; i <= m; ++i) {
240
244
  b[i + k * ldb] = alpha * b[i + k * ldb];
241
245
  }
242
246
  }
243
247
  }
244
248
  } else {
245
- for (int k = 0; k < n; ++k) {
249
+ for (int k = 1; k <= n; ++k) {
246
250
  if (diag == CblasNonUnit) {
247
251
  DType temp = 1 / a[k + k * lda];
248
- for (int i = 0; i < m; ++i) {
252
+ for (int i = 1; i <= m; ++i) {
249
253
  b[i + k * ldb] = temp * b[i + k * ldb];
250
254
  }
251
255
  }
252
- for (int j = k+1; j < n; ++j) {
256
+ for (int j = k+1; j <= n; ++j) {
253
257
  if (a[j + k * lda] != 0.) {
254
258
  DType temp = a[j + k * lda];
255
- for (int i = 0; i < m; ++i) {
259
+ for (int i = 1; i <= m; ++i) {
256
260
  b[i + j * ldb] -= temp * b[i + k * ldb];
257
261
  }
258
262
  }
259
263
  }
260
264
  if (alpha != 1) {
261
- for (int i = 0; i < m; ++i) {
265
+ for (int i = 1; i <= m; ++i) {
262
266
  b[i + k * ldb] = alpha * b[i + k * ldb];
263
267
  }
264
268
  }
@@ -70,6 +70,20 @@ static inline enum CBLAS_SIDE blas_side_sym(VALUE op) {
70
70
  return CblasLeft;
71
71
  }
72
72
 
73
+ /*
74
+ * Interprets the LAPACK side argument which could be :left or :right
75
+ *
76
+ * Related to obtaining Q in QR factorization after calling lapack_geqrf
77
+ */
78
+
79
+ static inline char lapacke_side_sym(VALUE op) {
80
+ ID op_id = rb_to_id(op);
81
+ if (op_id == nm_rb_left) return 'L';
82
+ if (op_id == nm_rb_right) return 'R';
83
+ else rb_raise(rb_eArgError, "Expected :left or :right for side argument");
84
+ return 'L';
85
+ }
86
+
73
87
  /*
74
88
  * Interprets cblas argument which could be :upper or :lower
75
89
  *
@@ -33,6 +33,7 @@
33
33
  */
34
34
 
35
35
  #include <ruby.h>
36
+ #include "ruby_constants.h"
36
37
 
37
38
  #ifdef __cplusplus
38
39
  #include <cmath>
@@ -57,6 +58,28 @@
57
58
  #include "nm_memory.h"
58
59
  #endif
59
60
 
61
+ #ifndef RB_BUILTIN_TYPE
62
+ # define RB_BUILTIN_TYPE(obj) BUILTIN_TYPE(obj)
63
+ #endif
64
+
65
+ #ifndef RB_FLOAT_TYPE_P
66
+ /* NOTE: assume flonum doesn't exist */
67
+ # define RB_FLOAT_TYPE_P(obj) ( \
68
+ (!SPECIAL_CONST_P(obj) && BUILTIN_TYPE(obj) == T_FLOAT))
69
+ #endif
70
+
71
+ #ifndef RB_TYPE_P
72
+ # define RB_TYPE_P(obj, type) ( \
73
+ ((type) == T_FIXNUM) ? FIXNUM_P(obj) : \
74
+ ((type) == T_TRUE) ? ((obj) == Qtrue) : \
75
+ ((type) == T_FALSE) ? ((obj) == Qfalse) : \
76
+ ((type) == T_NIL) ? ((obj) == Qnil) : \
77
+ ((type) == T_UNDEF) ? ((obj) == Qundef) : \
78
+ ((type) == T_SYMBOL) ? SYMBOL_P(obj) : \
79
+ ((type) == T_FLOAT) ? RB_FLOAT_TYPE_P(obj) : \
80
+ (!SPECIAL_CONST_P(obj) && BUILTIN_TYPE(obj) == (type)))
81
+ #endif
82
+
60
83
  #ifndef FIX_CONST_VALUE_PTR
61
84
  # if defined(__fcc__) || defined(__fcc_version) || \
62
85
  defined(__FCC__) || defined(__FCC_VERSION)
@@ -343,11 +366,25 @@ NM_DEF_STRUCT_POST(NM_GC_HOLDER); // };
343
366
 
344
367
  #define NM_SRC(val) (NM_STORAGE(val)->src)
345
368
  #define NM_DIM(val) (NM_STORAGE(val)->dim)
369
+
370
+ // Returns an int corresponding to the data type of the nmatrix. See the dtype_t
371
+ // enum for a list of possible data types.
346
372
  #define NM_DTYPE(val) (NM_STORAGE(val)->dtype)
373
+
374
+ // Returns a number corresponding to the storage type of the nmatrix. See the stype_t
375
+ // enum for a list of possible storage types.
347
376
  #define NM_STYPE(val) (NM_STRUCT(val)->stype)
377
+
378
+ // Get the shape of the ith dimension (int)
348
379
  #define NM_SHAPE(val,i) (NM_STORAGE(val)->shape[(i)])
380
+
381
+ // Get the shape of the 0th dimension (int)
349
382
  #define NM_SHAPE0(val) (NM_STORAGE(val)->shape[0])
383
+
384
+ // Get the shape of the 1st dimension (int)
350
385
  #define NM_SHAPE1(val) (NM_STORAGE(val)->shape[1])
386
+
387
+ // Get the default value assigned to the nmatrix.
351
388
  #define NM_DEFAULT_VAL(val) (NM_STORAGE_LIST(val)->default_val)
352
389
 
353
390
  // Number of elements in a dense nmatrix.
@@ -366,7 +403,8 @@ NM_DEF_STRUCT_POST(NM_GC_HOLDER); // };
366
403
 
367
404
  #define RB_FILE_EXISTS(fn) (rb_funcall(rb_const_get(rb_cObject, rb_intern("File")), rb_intern("exists?"), 1, (fn)) == Qtrue)
368
405
 
369
- #define CheckNMatrixType(v) if (TYPE(v) != T_DATA || (RDATA(v)->dfree != (RUBY_DATA_FUNC)nm_delete && RDATA(v)->dfree != (RUBY_DATA_FUNC)nm_delete_ref)) rb_raise(rb_eTypeError, "expected NMatrix on left-hand side of operation");
406
+ #define IsNMatrixType(v) (RB_TYPE_P(v, T_DATA) && (RDATA(v)->dfree == (RUBY_DATA_FUNC)nm_delete || RDATA(v)->dfree == (RUBY_DATA_FUNC)nm_delete_ref))
407
+ #define CheckNMatrixType(v) if (!IsNMatrixType(v)) rb_raise(rb_eTypeError, "expected NMatrix on left-hand side of operation");
370
408
 
371
409
  #define NM_IsNMatrix(obj) \
372
410
  (rb_obj_is_kind_of(obj, cNMatrix) == Qtrue)
@@ -34,6 +34,7 @@
34
34
 
35
35
  #include <ruby.h>
36
36
  #include <cmath> // pow().
37
+ #include <type_traits>
37
38
 
38
39
  /*
39
40
  * Project Includes
@@ -45,6 +46,11 @@
45
46
  * Macros
46
47
  */
47
48
 
49
+ #define u_int8_t static_assert(false, "Please use uint8_t for cross-platform support and consistency."); uint8_t
50
+ #define u_int16_t static_assert(false, "Please use uint16_t for cross-platform support and consistency."); uint16_t
51
+ #define u_int32_t static_assert(false, "Please use uint32_t for cross-platform support and consistency."); uint32_t
52
+ #define u_int64_t static_assert(false, "Please use uint64_t for cross-platform support and consistency."); uint64_t
53
+
48
54
  extern "C" {
49
55
 
50
56
  /*
@@ -152,7 +158,7 @@ namespace nm {
152
158
  EWOP_INT_INT_DIV(int16_t, int32_t)
153
159
  EWOP_INT_INT_DIV(int16_t, int64_t)
154
160
  EWOP_INT_INT_DIV(int8_t, int8_t)
155
- EWOP_INT_UINT_DIV(int8_t, u_int8_t)
161
+ EWOP_INT_UINT_DIV(int8_t, uint8_t)
156
162
  EWOP_INT_INT_DIV(int8_t, int16_t)
157
163
  EWOP_INT_INT_DIV(int8_t, int32_t)
158
164
  EWOP_INT_INT_DIV(int8_t, int64_t)
@@ -162,12 +168,12 @@ namespace nm {
162
168
  EWOP_UINT_INT_DIV(uint8_t, int32_t)
163
169
  EWOP_UINT_INT_DIV(uint8_t, int64_t)
164
170
  EWOP_FLOAT_INT_DIV(float, int8_t)
165
- EWOP_FLOAT_INT_DIV(float, u_int8_t)
171
+ EWOP_FLOAT_INT_DIV(float, uint8_t)
166
172
  EWOP_FLOAT_INT_DIV(float, int16_t)
167
173
  EWOP_FLOAT_INT_DIV(float, int32_t)
168
174
  EWOP_FLOAT_INT_DIV(float, int64_t)
169
175
  EWOP_FLOAT_INT_DIV(double, int8_t)
170
- EWOP_FLOAT_INT_DIV(double, u_int8_t)
176
+ EWOP_FLOAT_INT_DIV(double, uint8_t)
171
177
  EWOP_FLOAT_INT_DIV(double, int16_t)
172
178
  EWOP_FLOAT_INT_DIV(double, int32_t)
173
179
  EWOP_FLOAT_INT_DIV(double, int64_t)
@@ -376,7 +376,7 @@ public:
376
376
  v = reinterpret_cast<D*>(s->elements);
377
377
  v_size = nm_storage_count_max_elements(s);
378
378
 
379
- } else if (TYPE(right) == T_ARRAY) {
379
+ } else if (RB_TYPE_P(right, T_ARRAY)) {
380
380
  v_size = RARRAY_LEN(right);
381
381
  v = NM_ALLOC_N(D, v_size);
382
382
  if (dtype() == nm::RUBYOBJ) {
@@ -24,55 +24,7 @@
24
24
  #
25
25
  # This file checks FFTW3 and other necessary headers/shared objects.
26
26
 
27
- require 'mkmf'
28
-
29
- # Function derived from NArray's extconf.rb.
30
- def create_conf_h(file) #:nodoc:
31
- print "creating #{file}\n"
32
- File.open(file, 'w') do |hfile|
33
- header_guard = file.upcase.sub(/\s|\./, '_')
34
-
35
- hfile.puts "#ifndef #{header_guard}"
36
- hfile.puts "#define #{header_guard}"
37
- hfile.puts
38
-
39
- # FIXME: Find a better way to do this:
40
- hfile.puts "#define RUBY_2 1" if RUBY_VERSION >= '2.0'
41
-
42
- for line in $defs
43
- line =~ /^-D(.*)/
44
- hfile.printf "#define %s 1\n", $1
45
- end
46
-
47
- hfile.puts
48
- hfile.puts "#endif"
49
- end
50
- end
51
-
52
- def find_newer_gplusplus #:nodoc:
53
- print "checking for apparent GNU g++ binary with C++0x/C++11 support... "
54
- [9,8,7,6,5,4,3].each do |minor|
55
- ver = "4.#{minor}"
56
- gpp = "g++-#{ver}"
57
- result = `which #{gpp}`
58
- next if result.empty?
59
- CONFIG['CXX'] = gpp
60
- puts ver
61
- return CONFIG['CXX']
62
- end
63
- false
64
- end
65
-
66
- def gplusplus_version
67
- cxxvar = proc { |n| `#{CONFIG['CXX']} -E -dM - </dev/null | grep #{n}`.chomp.split(' ')[2] }
68
- major = cxxvar.call('__GNUC__')
69
- minor = cxxvar.call('__GNUC_MINOR__')
70
- patch = cxxvar.call('__GNUC_PATCHLEVEL__')
71
-
72
- raise("unable to determine g++ version (match to get version was nil)") if major.nil? || minor.nil? || patch.nil?
73
-
74
- "#{major}.#{minor}.#{patch}"
75
- end
27
+ require 'nmatrix/mkmf'
76
28
 
77
29
  fftw_libdir = RbConfig::CONFIG['libdir']
78
30
  fftw_incdir = RbConfig::CONFIG['includedir']
@@ -82,28 +34,6 @@ $CFLAGS = ["-Wall -Werror=return-type -I$(srcdir)/../nmatrix -I$(srcdir)/lapacke
82
34
  $CXXFLAGS = ["-Wall -Werror=return-type -I$(srcdir)/../nmatrix -I$(srcdir)/lapacke/include -std=c++11",$CXXFLAGS].join(" ")
83
35
  $CPPFLAGS = ["-Wall -Werror=return-type -I$(srcdir)/../nmatrix -I$(srcdir)/lapacke/include -std=c++11",$CPPFLAGS].join(" ")
84
36
 
85
- CONFIG['CXX'] = 'g++'
86
-
87
- if CONFIG['CXX'] == 'clang++'
88
- $CPP_STANDARD = 'c++11'
89
- else
90
- version = gplusplus_version
91
- if version < '4.3.0' && CONFIG['CXX'] == 'g++' # see if we can find a newer G++, unless it's been overridden by user
92
- if !find_newer_gplusplus
93
- raise("You need a version of g++ which supports -std=c++0x or -std=c++11. If you're on a Mac and using Homebrew, we recommend using mac-brew-gcc.sh to install a more recent g++.")
94
- end
95
- version = gplusplus_version
96
- end
97
-
98
- if version < '4.7.0'
99
- $CPP_STANDARD = 'c++0x'
100
- else
101
- $CPP_STANDARD = 'c++11'
102
- end
103
- puts "using C++ standard... #{$CPP_STANDARD}"
104
- puts "g++ reports version... " + `#{CONFIG['CXX']} --version|head -n 1|cut -f 3 -d " "`
105
- end
106
-
107
37
  flags = " --include=#{fftw_incdir} --libdir=#{fftw_libdir}"
108
38
 
109
39
  if have_library("fftw3")
@@ -424,6 +424,13 @@ describe 'NMatrix' do
424
424
  expect(n.reshape!([8,2]).eql?(n)).to eq(true) # because n itself changes
425
425
  end
426
426
 
427
+ it "should do the reshape operation in place, changing dimension" do
428
+ n = NMatrix.seq(4)
429
+ a = n.reshape!([4,2,2])
430
+ expect(n).to eq(NMatrix.seq([4,2,2]))
431
+ expect(a).to eq(NMatrix.seq([4,2,2]))
432
+ end
433
+
427
434
  it "reshape and reshape! must produce same result" do
428
435
  n = NMatrix.seq(4)+1
429
436
  a = NMatrix.seq(4)+1
@@ -432,7 +439,7 @@ describe 'NMatrix' do
432
439
 
433
440
  it "should prevent a resize in place" do
434
441
  n = NMatrix.seq(4)+1
435
- expect { n.reshape([5,2]) }.to raise_error(ArgumentError)
442
+ expect { n.reshape!([5,2]) }.to raise_error(ArgumentError)
436
443
  end
437
444
  end
438
445
 
@@ -534,6 +541,26 @@ describe 'NMatrix' do
534
541
  n = NMatrix.new([1,3,1], [1,2,3])
535
542
  expect(n.dconcat(n)).to eq(NMatrix.new([1,3,2], [1,1,2,2,3,3]))
536
543
  end
544
+
545
+ it "should work on matrices with different size along concat dim" do
546
+ n = N[[1, 2, 3],
547
+ [4, 5, 6]]
548
+ m = N[[7],
549
+ [8]]
550
+
551
+ expect(n.hconcat(m)).to eq N[[1, 2, 3, 7], [4, 5, 6, 8]]
552
+ expect(m.hconcat(n)).to eq N[[7, 1, 2, 3], [8, 4, 5, 6]]
553
+ end
554
+
555
+ it "should work on matrices with different size along concat dim" do
556
+ n = N[[1, 2, 3],
557
+ [4, 5, 6]]
558
+
559
+ m = N[[7, 8, 9]]
560
+
561
+ expect(n.vconcat(m)).to eq N[[1, 2, 3], [4, 5, 6], [7, 8, 9]]
562
+ expect(m.vconcat(n)).to eq N[[7, 8, 9], [1, 2, 3], [4, 5, 6]]
563
+ end
537
564
  end
538
565
 
539
566
  context "#[]" do
@@ -631,6 +658,23 @@ describe 'NMatrix' do
631
658
  end
632
659
  end
633
660
 
661
+ context "#last" do
662
+ it "returns the last element of a 1-dimensional NMatrix" do
663
+ n = NMatrix.new([1,4], [1,2,3,4])
664
+ expect(n.last).to eq(4)
665
+ end
666
+
667
+ it "returns the last element of a 2-dimensional NMatrix" do
668
+ n = NMatrix.new([2,2], [4,8,12,16])
669
+ expect(n.last).to eq(16)
670
+ end
671
+
672
+ it "returns the last element of a 3-dimensional NMatrix" do
673
+ n = NMatrix.new([2,2,2], [1,2,3,4,5,6,7,8])
674
+ expect(n.last).to eq(8)
675
+ end
676
+ end
677
+
634
678
  context "#diagonal" do
635
679
  ALL_DTYPES.each do |dtype|
636
680
  before do
@@ -682,6 +726,11 @@ describe 'NMatrix' do
682
726
  expect(@sample_matrix.repeat(2, 0)).to eq(NMatrix.new([4, 2], [1, 2, 3, 4, 1, 2, 3, 4]))
683
727
  expect(@sample_matrix.repeat(2, 1)).to eq(NMatrix.new([2, 4], [1, 2, 1, 2, 3, 4, 3, 4]))
684
728
  end
729
+
730
+ it "preserves dtype" do
731
+ expect(@sample_matrix.repeat(2, 0).dtype).to eq(@sample_matrix.dtype)
732
+ expect(@sample_matrix.repeat(2, 1).dtype).to eq(@sample_matrix.dtype)
733
+ end
685
734
  end
686
735
 
687
736
  context "#meshgrid" do