memory-view-test-helper 0.0.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/Gemfile +3 -0
- data/LICENSE.txt +7 -0
- data/README.md +33 -0
- data/Rakefile +58 -0
- data/ext/memory-view-test-helper/extconf.rb +2 -0
- data/ext/memory-view-test-helper/memory-view-test-helper.c +832 -0
- data/lib/memory-view-test-helper.rb +162 -0
- data/lib/memory-view-test-helper/version.rb +3 -0
- data/memory-view-test-helper.gemspec +37 -0
- metadata +97 -0
checksums.yaml
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
---
|
2
|
+
SHA256:
|
3
|
+
metadata.gz: f2c2efc75bd337da021a3e0121ad01983e8c8ffd62c98338dbfe09c80b07cffb
|
4
|
+
data.tar.gz: 1a7e4655c7b78bbf5053fd106328abf25731afd9a1385fde36a7f7f6a67a1481
|
5
|
+
SHA512:
|
6
|
+
metadata.gz: f491b150d3467a8469e6660834d0617aa10261030b53dd7f4ea87638b34023c2d2d9b4cf8afb92dd326fcebe4c65ccee91ecb8c490ae05b1628308ec22ae69c7
|
7
|
+
data.tar.gz: 3a1940c8c6e9478eaa5a84f08b0aa244bcc100ce076a5a6d9c0a6c4bd50735f2e1bae411fa19b982f15c1fb1ea1768336f39d054ac86ba9cea4a3a3cb93bc66a
|
data/Gemfile
ADDED
data/LICENSE.txt
ADDED
@@ -0,0 +1,7 @@
|
|
1
|
+
Copyright 2021 Kenta Murata <mrkn@mrkn.jp>
|
2
|
+
|
3
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
|
4
|
+
|
5
|
+
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
|
6
|
+
|
7
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
data/README.md
ADDED
@@ -0,0 +1,33 @@
|
|
1
|
+
# MemoryViewTestHelper
|
2
|
+
|
3
|
+
## Description
|
4
|
+
|
5
|
+
MemoryViewTestHelper provides features that help test libraries that support MemoryView.
|
6
|
+
|
7
|
+
`MemoryViewTestHelper::NDArray` provides a simple multi-dimensional numeric array that can export a MemoryView.
|
8
|
+
|
9
|
+
## Install
|
10
|
+
|
11
|
+
```console
|
12
|
+
$ gem install memory-view-test-helper
|
13
|
+
```
|
14
|
+
|
15
|
+
## Usage
|
16
|
+
|
17
|
+
First, you need to require the `memory-view-test-helper` library.
|
18
|
+
|
19
|
+
```ruby
|
20
|
+
require "memory-view-test-helper"
|
21
|
+
```
|
22
|
+
|
23
|
+
You can create a multi-dimensional numeric array with `MemoryViewTestHelper::NDArray.new`.
|
24
|
+
|
25
|
+
```ruby
|
26
|
+
x = MemoryViewTestHelper::NDArray.new([[1, 2, 3], [4, 5, 6]], dtype: :float64)
|
27
|
+
```
|
28
|
+
|
29
|
+
After this expression, `x` refers to a 2x3 matrix of 64-bit floating point numbers.
|
30
|
+
|
31
|
+
## License
|
32
|
+
|
33
|
+
The MIT license. See [`LICENSE.txt`](LICENSE.txt) for details.
|
data/Rakefile
ADDED
@@ -0,0 +1,58 @@
|
|
1
|
+
require "bundler/gem_helper"
|
2
|
+
require "rake/clean"
|
3
|
+
|
4
|
+
# Root of the gem source tree: the directory that contains this Rakefile.
base_dir = File.dirname(__FILE__)

# Bundler's gem helper defines the standard build/install/release tasks and
# exposes the parsed gemspec, whose extensions are wired up below.
helper = Bundler::GemHelper.new(base_dir)
helper.install
spec = helper.gemspec
|
9
|
+
|
10
|
+
# Runs the extension's extconf.rb (located in +extension_dir+) with
# +build_dir+ as the working directory, so the generated Makefile and
# mkmf.log land in +build_dir+. Extra +arguments+ are passed to extconf.rb.
def run_extconf(build_dir, extension_dir, *arguments)
  extconf_path = File.join(extension_dir, "extconf.rb")
  cd(build_dir) { ruby(extconf_path, *arguments) }
end
|
15
|
+
|
16
|
+
# Defines configure/compile/clean tasks for every native extension listed in
# the gemspec. Builds happen in $BUILD_DIR/memory-view-test-helper when
# BUILD_DIR is set, otherwise in-place in the extension's source directory.
spec.extensions.each do |extension|
  extension_dir = File.join(base_dir, File.dirname(extension))
  build_root = ENV["BUILD_DIR"]
  build_dir = build_root ? File.join(build_root, "memory-view-test-helper") : extension_dir

  # Declare build_dir as a `directory` task in BOTH cases. Rake's directory
  # tasks only matter while the directory is missing and never make their
  # dependents out of date. Without one, Rake synthesizes a plain FileTask for
  # the existing extension_dir whose mtime changes whenever files are created
  # in it (e.g. object files from `make`), which would force extconf.rb to be
  # re-run on every compile.
  directory build_dir

  makefile = File.join(build_dir, "Makefile")
  file makefile => build_dir do
    run_extconf(build_dir, extension_dir)
  end

  CLOBBER << makefile
  CLOBBER << File.join(build_dir, "mkmf.log")

  desc "Configure"
  task configure: makefile

  desc "Compile"
  task compile: makefile do
    cd(build_dir) do
      sh("make")
    end
  end

  task :clean do
    # Nothing to clean until extconf.rb has produced a Makefile; checking
    # first also avoids failing to cd into a BUILD_DIR that was never created.
    next unless File.exist?(File.join(build_dir, "Makefile"))

    cd(build_dir) do
      sh("make", "clean")
    end
  end
end
|
50
|
+
|
51
|
+
# Runs the test suite from the gem root; `rake` with no arguments runs it too.
desc "Run tests"
task :test do
  cd(base_dir) { ruby("test/run-test.rb") }
end

task default: :test
|
@@ -0,0 +1,832 @@
|
|
1
|
+
#include <ruby.h>

#include <float.h>
#include <limits.h>
#include <stdint.h>
|
5
|
+
|
6
|
+
/*
 * Checked conversions from Ruby numerics to fixed-width C types. Values that
 * do not fit the target type raise RangeError (via the range-check helpers
 * below or Ruby's own NUM2INT/NUM2LONG family).
 */
#define NUM2INT8(num) num2int8(num)
#define NUM2UINT8(num) num2uint8(num)
#define NUM2INT16(num) num2int16(num)
#define NUM2UINT16(num) num2uint16(num)

/*
 * The range-check helpers are defined further below. Declare them here so
 * the fallback num2int32/num2uint32 in the #else branch can call them
 * without an implicit declaration.
 */
static long int_range_check(long num, long min, long max, const char *type);
static unsigned long uint_range_check(unsigned long num, unsigned long max, const char *type);

#if SIZEOF_INT32_T == SIZEOF_INT
# define NUM2INT32(num) ((int32_t)NUM2INT(num))
# define NUM2UINT32(num) ((uint32_t)NUM2UINT(num))
#elif SIZEOF_INT32_T == SIZEOF_LONG
# define NUM2INT32(num) ((int32_t)NUM2LONG(num))
/* Ruby provides NUM2ULONG; the previous NUM2LLONG does not exist. */
# define NUM2UINT32(num) ((uint32_t)NUM2ULONG(num))
#else
# define NUM2INT32(num) num2int32(num)
# define NUM2UINT32(num) num2uint32(num)
static int32_t
num2int32(VALUE num)
{
    return (int32_t)int_range_check(NUM2LONG(num), INT32_MIN, INT32_MAX, "int32_t");
}

static uint32_t
num2uint32(VALUE num)
{
    return (uint32_t)uint_range_check(NUM2ULONG(num), UINT32_MAX, "uint32_t");
}
#endif
#if SIZEOF_INT64_T == SIZEOF_INT
# define NUM2INT64(num) ((int64_t)NUM2INT(num))
# define NUM2UINT64(num) ((uint64_t)NUM2UINT(num))
#elif SIZEOF_INT64_T == SIZEOF_LONG
# define NUM2INT64(num) ((int64_t)NUM2LONG(num))
# define NUM2UINT64(num) ((uint64_t)NUM2ULONG(num))
#elif SIZEOF_INT64_T == SIZEOF_LONG_LONG
# define NUM2INT64(num) ((int64_t)NUM2LL(num))
# define NUM2UINT64(num) ((uint64_t)NUM2ULL(num))
#else
# error Unable to define NUM2INT64 and NUM2UINT64
#endif
#define NUM2FLT(num) num2flt(num)
|
44
|
+
|
45
|
+
static long
|
46
|
+
int_range_check(long num, long min, long max, const char *type)
|
47
|
+
{
|
48
|
+
if (min <= num && num <= max) return num;
|
49
|
+
rb_raise(rb_eRangeError, "integer %ld too %s to convert to `%s'",
|
50
|
+
num, num < 0 ? "small" : "big", type);
|
51
|
+
}
|
52
|
+
|
53
|
+
static unsigned long
|
54
|
+
uint_range_check(unsigned long num, unsigned long max, const char *type)
|
55
|
+
{
|
56
|
+
if (num > max) {
|
57
|
+
rb_raise(rb_eRangeError, "integer %lu too big to convert to `%s'", num, type);
|
58
|
+
}
|
59
|
+
return num;
|
60
|
+
}
|
61
|
+
|
62
|
+
static int8_t
|
63
|
+
num2int8(VALUE num)
|
64
|
+
{
|
65
|
+
return (int8_t)int_range_check(NUM2LONG(num), INT8_MIN, INT8_MAX, "int8_t");
|
66
|
+
}
|
67
|
+
|
68
|
+
static uint8_t
|
69
|
+
num2uint8(VALUE num)
|
70
|
+
{
|
71
|
+
return (uint8_t)uint_range_check(NUM2ULONG(num), UINT8_MAX, "uint8_t");
|
72
|
+
}
|
73
|
+
|
74
|
+
static int16_t
|
75
|
+
num2int16(VALUE num)
|
76
|
+
{
|
77
|
+
return (int16_t)int_range_check(NUM2LONG(num), INT16_MIN, INT16_MAX, "int16_t");
|
78
|
+
}
|
79
|
+
|
80
|
+
static uint16_t
|
81
|
+
num2uint16(VALUE num)
|
82
|
+
{
|
83
|
+
return (uint16_t)uint_range_check(NUM2ULONG(num), UINT16_MAX, "uint16_t");
|
84
|
+
}
|
85
|
+
|
86
|
+
static float
|
87
|
+
num2flt(VALUE num)
|
88
|
+
{
|
89
|
+
double dbl = NUM2DBL(num);
|
90
|
+
if (dbl < FLT_MIN || FLT_MAX < dbl) {
|
91
|
+
rb_raise(rb_eRangeError, "float %lf too %s to convert to `float'",
|
92
|
+
dbl, dbl < 0 ? "small" : "big");
|
93
|
+
}
|
94
|
+
return (float)dbl;
|
95
|
+
}
|
96
|
+
|
97
|
+
/*
 * Module and class objects created in Init_memory_view_test_helper. They are
 * file-local (static) so the extension does not export generic symbol names
 * such as `cNDArray` that could collide with globals of other extensions
 * loaded into the same process.
 */
static VALUE mMemoryViewTestHelper;
static VALUE cNDArray;

/* The symbols :row_major, :column_major and :auto, compared by check_order. */
static VALUE sym_row_major;
static VALUE sym_column_major;
static VALUE sym_auto;

/* Largest rank handled with on-stack index buffers; larger ranks use RB_ALLOCV_N. */
#define MAX_INLINE_DIM 32
|
105
|
+
|
106
|
+
/*
 * Element types supported by NDArray. The numeric values index the
 * ndarray_dtype_sizes and ndarray_dtype_ids tables. ndarray_dtype_none (0)
 * is the dtype of a freshly allocated, not yet initialized array.
 */
typedef enum {
    ndarray_dtype_none    = 0,
    ndarray_dtype_int8    = 1,
    ndarray_dtype_uint8   = 2,
    ndarray_dtype_int16   = 3,
    ndarray_dtype_uint16  = 4,
    ndarray_dtype_int32   = 5,
    ndarray_dtype_uint32  = 6,
    ndarray_dtype_int64   = 7,
    ndarray_dtype_uint64  = 8,
    ndarray_dtype_float32 = 9,
    ndarray_dtype_float64 = 10,

    /* Not a real dtype: its value equals the number of entries above. */
    ___ndarray_dtype_sentinel___ = 11
} ndarray_dtype_t;

/* Number of dtype slots, including ndarray_dtype_none. */
#define NDARRAY_NUM_DTYPES ((int)___ndarray_dtype_sentinel___)
|
123
|
+
|
124
|
+
static const int ndarray_dtype_sizes[] = {
|
125
|
+
0,
|
126
|
+
sizeof(int8_t),
|
127
|
+
sizeof(uint8_t),
|
128
|
+
sizeof(int16_t),
|
129
|
+
sizeof(uint16_t),
|
130
|
+
sizeof(int32_t),
|
131
|
+
sizeof(uint32_t),
|
132
|
+
sizeof(int64_t),
|
133
|
+
sizeof(uint64_t),
|
134
|
+
sizeof(float),
|
135
|
+
sizeof(double),
|
136
|
+
};
|
137
|
+
|
138
|
+
#define SIZEOF_DTYPE(type) (*(const int *)(&ndarray_dtype_sizes[type]))
|
139
|
+
|
140
|
+
static ID ndarray_dtype_ids[NDARRAY_NUM_DTYPES];
|
141
|
+
|
142
|
+
#define DTYPE_ID(type) (*(const ID *)(&ndarray_dtype_ids[type]))
|
143
|
+
|
144
|
+
static ndarray_dtype_t
|
145
|
+
ndarray_id_to_dtype_t(ID id)
|
146
|
+
{
|
147
|
+
int i;
|
148
|
+
for (i = 0; i < NDARRAY_NUM_DTYPES; ++i) {
|
149
|
+
if (ndarray_dtype_ids[i] == id) {
|
150
|
+
return (ndarray_dtype_t)i;
|
151
|
+
}
|
152
|
+
}
|
153
|
+
rb_raise(rb_eArgError, "unknown dtype: %"PRIsVALUE, ID2SYM(id));
|
154
|
+
}
|
155
|
+
|
156
|
+
static ndarray_dtype_t
|
157
|
+
ndarray_sym_to_dtype_t(VALUE sym)
|
158
|
+
{
|
159
|
+
assert(RB_TYPE_P(sym, T_SYMBOL));
|
160
|
+
ID id = SYM2ID(sym);
|
161
|
+
return ndarray_id_to_dtype_t(id);
|
162
|
+
}
|
163
|
+
|
164
|
+
static ndarray_dtype_t
|
165
|
+
ndarray_obj_to_dtype_t(VALUE obj)
|
166
|
+
{
|
167
|
+
while (!RB_TYPE_P(obj, T_SYMBOL)) {
|
168
|
+
if (RB_TYPE_P(obj, T_STRING) || rb_respond_to(obj, rb_intern("to_sym"))) {
|
169
|
+
obj = rb_funcallv(obj, rb_intern("to_sym"), 0, NULL);
|
170
|
+
if (!RB_TYPE_P(obj, T_SYMBOL)) {
|
171
|
+
goto type_error;
|
172
|
+
}
|
173
|
+
}
|
174
|
+
else if (!RB_TYPE_P(obj, T_STRING)) {
|
175
|
+
if (!rb_respond_to(obj, rb_intern("to_str"))) {
|
176
|
+
goto type_error;
|
177
|
+
}
|
178
|
+
else {
|
179
|
+
obj = rb_funcallv(obj, rb_intern("to_str"), 0, NULL);
|
180
|
+
}
|
181
|
+
}
|
182
|
+
}
|
183
|
+
return ndarray_sym_to_dtype_t(obj);
|
184
|
+
|
185
|
+
type_error:
|
186
|
+
rb_raise(rb_eTypeError, "dtype must be a symbol");
|
187
|
+
}
|
188
|
+
|
189
|
+
typedef struct {
|
190
|
+
void *data;
|
191
|
+
ssize_t byte_size;
|
192
|
+
|
193
|
+
ndarray_dtype_t dtype;
|
194
|
+
ssize_t ndim;
|
195
|
+
ssize_t *shape;
|
196
|
+
ssize_t *strides;
|
197
|
+
|
198
|
+
VALUE base;
|
199
|
+
} ndarray_t;
|
200
|
+
|
201
|
+
static void ndarray_mark(void *);
|
202
|
+
static void ndarray_free(void *);
|
203
|
+
static size_t ndarray_memsize(const void *);
|
204
|
+
|
205
|
+
static const rb_data_type_t ndarray_data_type = {
|
206
|
+
"memory-view-test-helper/ndarray",
|
207
|
+
{
|
208
|
+
ndarray_mark,
|
209
|
+
ndarray_free,
|
210
|
+
ndarray_memsize,
|
211
|
+
},
|
212
|
+
0, 0, RUBY_TYPED_FREE_IMMEDIATELY
|
213
|
+
};
|
214
|
+
|
215
|
+
static void
|
216
|
+
ndarray_mark(void *ptr)
|
217
|
+
{
|
218
|
+
ndarray_t *nar = (ndarray_t *)ptr;
|
219
|
+
if (nar->base)
|
220
|
+
rb_gc_mark(nar->base);
|
221
|
+
}
|
222
|
+
|
223
|
+
static void
|
224
|
+
ndarray_free(void *ptr)
|
225
|
+
{
|
226
|
+
ndarray_t *nar = (ndarray_t *)ptr;
|
227
|
+
if (!nar->base && nar->data) xfree(nar->data);
|
228
|
+
if (nar->shape) xfree(nar->shape);
|
229
|
+
if (nar->strides) xfree(nar->strides);
|
230
|
+
xfree(nar);
|
231
|
+
}
|
232
|
+
|
233
|
+
static size_t
|
234
|
+
ndarray_memsize(const void *ptr)
|
235
|
+
{
|
236
|
+
ndarray_t *nar = (ndarray_t *)ptr;
|
237
|
+
size_t size = sizeof(ndarray_t);
|
238
|
+
if (nar->data) size += nar->byte_size;
|
239
|
+
if (nar->shape) size += sizeof(ssize_t) * nar->ndim;
|
240
|
+
if (nar->strides) size += sizeof(ssize_t) * nar->ndim;
|
241
|
+
return size;
|
242
|
+
}
|
243
|
+
|
244
|
+
static VALUE
|
245
|
+
ndarray_s_allocate(VALUE klass)
|
246
|
+
{
|
247
|
+
ndarray_t *nar;
|
248
|
+
VALUE obj = TypedData_Make_Struct(klass, ndarray_t, &ndarray_data_type, nar);
|
249
|
+
nar->data = NULL;
|
250
|
+
nar->byte_size = 0;
|
251
|
+
nar->dtype = ndarray_dtype_none;
|
252
|
+
nar->ndim = 0;
|
253
|
+
nar->shape = NULL;
|
254
|
+
nar->strides = NULL;
|
255
|
+
nar->base = Qfalse;
|
256
|
+
return obj;
|
257
|
+
}
|
258
|
+
|
259
|
+
static void
|
260
|
+
ndarray_init_row_major_strides(const ndarray_dtype_t dtype, const ssize_t ndim,
|
261
|
+
const ssize_t *shape, ssize_t *out_strides)
|
262
|
+
{
|
263
|
+
const ssize_t item_size = SIZEOF_DTYPE(dtype);
|
264
|
+
out_strides[ndim - 1] = item_size;
|
265
|
+
|
266
|
+
int i;
|
267
|
+
for (i = ndim - 1; i > 0; --i) {
|
268
|
+
out_strides[i - 1] = out_strides[i] * shape[i];
|
269
|
+
}
|
270
|
+
}
|
271
|
+
|
272
|
+
static VALUE
|
273
|
+
ndarray_initialize(VALUE obj, VALUE shape_ary, VALUE dtype_name)
|
274
|
+
{
|
275
|
+
int i;
|
276
|
+
|
277
|
+
Check_Type(shape_ary, T_ARRAY);
|
278
|
+
|
279
|
+
const ssize_t ndim = (ssize_t)RARRAY_LEN(shape_ary);
|
280
|
+
for (i = 0; i < ndim; ++i) {
|
281
|
+
VALUE si = RARRAY_AREF(shape_ary, i);
|
282
|
+
Check_Type(si, T_FIXNUM);
|
283
|
+
}
|
284
|
+
|
285
|
+
ssize_t *shape = ALLOC_N(ssize_t, ndim);
|
286
|
+
for (i = 0; i < ndim; ++i) {
|
287
|
+
VALUE si = RARRAY_AREF(shape_ary, i);
|
288
|
+
shape[i] = NUM2SSIZET(si);
|
289
|
+
}
|
290
|
+
|
291
|
+
ndarray_dtype_t dtype = ndarray_obj_to_dtype_t(dtype_name);
|
292
|
+
|
293
|
+
ssize_t *strides = ALLOC_N(ssize_t, ndim);
|
294
|
+
ndarray_init_row_major_strides(dtype, ndim, shape, strides);
|
295
|
+
|
296
|
+
ndarray_t *nar;
|
297
|
+
TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, nar);
|
298
|
+
|
299
|
+
ssize_t byte_size = strides[0] * shape[0];
|
300
|
+
nar->data = ALLOC_N(uint8_t, byte_size);
|
301
|
+
nar->byte_size = byte_size;
|
302
|
+
nar->dtype = dtype;
|
303
|
+
nar->ndim = ndim;
|
304
|
+
nar->shape = shape;
|
305
|
+
nar->strides = strides;
|
306
|
+
|
307
|
+
return Qnil;
|
308
|
+
}
|
309
|
+
|
310
|
+
static VALUE
|
311
|
+
ndarray_get_byte_size(VALUE obj)
|
312
|
+
{
|
313
|
+
ndarray_t *nar;
|
314
|
+
TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, nar);
|
315
|
+
|
316
|
+
return SSIZET2NUM(nar->byte_size);
|
317
|
+
}
|
318
|
+
|
319
|
+
static VALUE
|
320
|
+
ndarray_get_dtype(VALUE obj)
|
321
|
+
{
|
322
|
+
ndarray_t *nar;
|
323
|
+
TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, nar);
|
324
|
+
|
325
|
+
if (ndarray_dtype_none < nar->dtype && nar->dtype < NDARRAY_NUM_DTYPES) {
|
326
|
+
return ID2SYM(DTYPE_ID(nar->dtype));
|
327
|
+
}
|
328
|
+
return Qnil;
|
329
|
+
}
|
330
|
+
|
331
|
+
static VALUE
|
332
|
+
ndarray_get_ndim(VALUE obj)
|
333
|
+
{
|
334
|
+
ndarray_t *nar;
|
335
|
+
TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, nar);
|
336
|
+
|
337
|
+
return SSIZET2NUM(nar->ndim);
|
338
|
+
}
|
339
|
+
|
340
|
+
static VALUE
|
341
|
+
ndarray_get_shape(VALUE obj)
|
342
|
+
{
|
343
|
+
ndarray_t *nar;
|
344
|
+
TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, nar);
|
345
|
+
|
346
|
+
VALUE ary = rb_ary_new_capa(nar->ndim);
|
347
|
+
int i;
|
348
|
+
for (i = 0; i < nar->ndim; ++i) {
|
349
|
+
rb_ary_push(ary, SSIZET2NUM(nar->shape[i]));
|
350
|
+
}
|
351
|
+
|
352
|
+
return ary;
|
353
|
+
}
|
354
|
+
|
355
|
+
static VALUE
|
356
|
+
ndarray_get_strides(VALUE obj)
|
357
|
+
{
|
358
|
+
ndarray_t *nar;
|
359
|
+
TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, nar);
|
360
|
+
|
361
|
+
if (nar->strides == NULL) {
|
362
|
+
return rb_ary_new_capa(0);
|
363
|
+
}
|
364
|
+
|
365
|
+
VALUE ary = rb_ary_new_capa(nar->ndim);
|
366
|
+
int i;
|
367
|
+
for (i = 0; i < nar->ndim; ++i) {
|
368
|
+
rb_ary_push(ary, SSIZET2NUM(nar->strides[i]));
|
369
|
+
}
|
370
|
+
|
371
|
+
return ary;
|
372
|
+
}
|
373
|
+
|
374
|
+
static VALUE
|
375
|
+
ndarray_get_value(const uint8_t *value_ptr, const ndarray_dtype_t dtype)
|
376
|
+
{
|
377
|
+
assert(value_ptr != NULL);
|
378
|
+
switch (dtype) {
|
379
|
+
case ndarray_dtype_int8:
|
380
|
+
return INT2NUM(*(int8_t *)value_ptr);
|
381
|
+
case ndarray_dtype_uint8:
|
382
|
+
return UINT2NUM(*(uint8_t *)value_ptr);
|
383
|
+
|
384
|
+
case ndarray_dtype_int16:
|
385
|
+
return INT2NUM(*(int16_t *)value_ptr);
|
386
|
+
case ndarray_dtype_uint16:
|
387
|
+
return UINT2NUM(*(uint16_t *)value_ptr);
|
388
|
+
|
389
|
+
case ndarray_dtype_int32:
|
390
|
+
return LONG2NUM(*(int32_t *)value_ptr);
|
391
|
+
case ndarray_dtype_uint32:
|
392
|
+
return ULONG2NUM(*(uint32_t *)value_ptr);
|
393
|
+
|
394
|
+
case ndarray_dtype_int64:
|
395
|
+
return LL2NUM(*(int64_t *)value_ptr);
|
396
|
+
case ndarray_dtype_uint64:
|
397
|
+
return ULL2NUM(*(uint64_t *)value_ptr);
|
398
|
+
|
399
|
+
case ndarray_dtype_float32:
|
400
|
+
return DBL2NUM(*(float *)value_ptr);
|
401
|
+
case ndarray_dtype_float64:
|
402
|
+
return DBL2NUM(*(double *)value_ptr);
|
403
|
+
|
404
|
+
default:
|
405
|
+
return Qnil;
|
406
|
+
}
|
407
|
+
}
|
408
|
+
|
409
|
+
static VALUE
|
410
|
+
ndarray_1d_aref(const ndarray_t *nar, ssize_t i)
|
411
|
+
{
|
412
|
+
assert(nar != NULL);
|
413
|
+
assert(nar->ndim == 1);
|
414
|
+
assert(0 <= i);
|
415
|
+
assert(i < nar->shape[0]);
|
416
|
+
|
417
|
+
uint8_t *p = ((uint8_t *)nar->data) + i * nar->strides[0];
|
418
|
+
return ndarray_get_value(p, nar->dtype);
|
419
|
+
}
|
420
|
+
|
421
|
+
static VALUE
|
422
|
+
ndarray_md_aref(const ndarray_t *nar, ssize_t *indices)
|
423
|
+
{
|
424
|
+
assert(nar != NULL);
|
425
|
+
assert(indices != NULL);
|
426
|
+
|
427
|
+
/* assume the size of indices equals to nar->ndim */
|
428
|
+
const ssize_t ndim = nar->ndim;
|
429
|
+
|
430
|
+
uint8_t *value_ptr = nar->data;
|
431
|
+
ssize_t i;
|
432
|
+
for (i = 0; i < ndim; ++i) {
|
433
|
+
value_ptr += indices[i] * nar->strides[i];
|
434
|
+
}
|
435
|
+
|
436
|
+
return ndarray_get_value(value_ptr, nar->dtype);
|
437
|
+
}
|
438
|
+
|
439
|
+
static VALUE
|
440
|
+
ndarray_aref(int argc, VALUE *argv, VALUE obj)
|
441
|
+
{
|
442
|
+
ndarray_t *nar;
|
443
|
+
TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, nar);
|
444
|
+
|
445
|
+
if (nar->ndim != argc) {
|
446
|
+
rb_raise(rb_eIndexError, "index dimension mismatched (%d for %"PRIdSIZE")", argc, nar->ndim);
|
447
|
+
}
|
448
|
+
|
449
|
+
const ssize_t ndim = nar->ndim;
|
450
|
+
if (ndim == 1) {
|
451
|
+
const ssize_t i = NUM2SSIZET(argv[0]);
|
452
|
+
return ndarray_1d_aref(nar, i);
|
453
|
+
}
|
454
|
+
else {
|
455
|
+
ssize_t inline_indices_buf[MAX_INLINE_DIM] = { 0, };
|
456
|
+
ssize_t *indices = inline_indices_buf;
|
457
|
+
|
458
|
+
VALUE heap_indices_buf = 0;
|
459
|
+
if (ndim > MAX_INLINE_DIM) {
|
460
|
+
indices = RB_ALLOCV_N(ssize_t, heap_indices_buf, ndim);
|
461
|
+
}
|
462
|
+
|
463
|
+
ssize_t i;
|
464
|
+
for (i = 0; i < ndim; ++i) {
|
465
|
+
indices[i] = NUM2SSIZET(argv[i]);
|
466
|
+
}
|
467
|
+
|
468
|
+
VALUE res = ndarray_md_aref(nar, indices);
|
469
|
+
RB_ALLOCV_END(heap_indices_buf);
|
470
|
+
return res;
|
471
|
+
}
|
472
|
+
}
|
473
|
+
|
474
|
+
static VALUE
|
475
|
+
ndarray_set_value(uint8_t *value_ptr, const ndarray_dtype_t dtype, const VALUE val)
|
476
|
+
{
|
477
|
+
assert(value_ptr != NULL);
|
478
|
+
switch (dtype) {
|
479
|
+
case ndarray_dtype_int8:
|
480
|
+
*(int8_t *)value_ptr = NUM2INT8(val);
|
481
|
+
break;
|
482
|
+
case ndarray_dtype_uint8:
|
483
|
+
*(uint8_t *)value_ptr = NUM2UINT8(val);
|
484
|
+
break;
|
485
|
+
|
486
|
+
case ndarray_dtype_int16:
|
487
|
+
*(int16_t *)value_ptr = NUM2INT16(val);
|
488
|
+
break;
|
489
|
+
case ndarray_dtype_uint16:
|
490
|
+
*(uint16_t *)value_ptr = NUM2UINT16(val);
|
491
|
+
break;
|
492
|
+
|
493
|
+
case ndarray_dtype_int32:
|
494
|
+
*(int32_t *)value_ptr = NUM2INT32(val);
|
495
|
+
case ndarray_dtype_uint32:
|
496
|
+
*(uint32_t *)value_ptr = NUM2UINT32(val);
|
497
|
+
|
498
|
+
case ndarray_dtype_int64:
|
499
|
+
*(int64_t *)value_ptr = NUM2INT64(val);
|
500
|
+
break;
|
501
|
+
case ndarray_dtype_uint64:
|
502
|
+
*(uint64_t *)value_ptr = NUM2UINT64(val);
|
503
|
+
break;
|
504
|
+
|
505
|
+
case ndarray_dtype_float32:
|
506
|
+
*(float *)value_ptr = NUM2FLT(val);
|
507
|
+
break;
|
508
|
+
case ndarray_dtype_float64:
|
509
|
+
*(double *)value_ptr = NUM2DBL(val);
|
510
|
+
break;
|
511
|
+
|
512
|
+
default:
|
513
|
+
return Qnil;
|
514
|
+
}
|
515
|
+
|
516
|
+
return val;
|
517
|
+
}
|
518
|
+
|
519
|
+
static VALUE
|
520
|
+
ndarray_md_aset(ndarray_t *nar, ssize_t *indices, VALUE val)
|
521
|
+
{
|
522
|
+
assert(nar != NULL);
|
523
|
+
assert(indices != NULL);
|
524
|
+
|
525
|
+
/* assume the size of indices equals to nar->ndim */
|
526
|
+
const ssize_t ndim = nar->ndim;
|
527
|
+
|
528
|
+
uint8_t *value_ptr = nar->data;
|
529
|
+
ssize_t i;
|
530
|
+
for (i = 0; i < ndim; ++i) {
|
531
|
+
value_ptr += indices[i] * nar->strides[i];
|
532
|
+
}
|
533
|
+
|
534
|
+
return ndarray_set_value(value_ptr, nar->dtype, val);
|
535
|
+
}
|
536
|
+
|
537
|
+
static VALUE
|
538
|
+
ndarray_aset(int argc, VALUE *argv, VALUE obj)
|
539
|
+
{
|
540
|
+
ndarray_t *nar;
|
541
|
+
TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, nar);
|
542
|
+
|
543
|
+
rb_check_frozen(obj);
|
544
|
+
|
545
|
+
if (nar->ndim != argc - 1) {
|
546
|
+
rb_raise(rb_eIndexError, "index dimension mismatched (%d for %"PRIdSIZE")", argc - 1, nar->ndim);
|
547
|
+
}
|
548
|
+
|
549
|
+
const VALUE val = argv[argc-1];
|
550
|
+
const int item_size = SIZEOF_DTYPE(nar->dtype);
|
551
|
+
|
552
|
+
const ssize_t ndim = nar->ndim;
|
553
|
+
if (ndim == 1) {
|
554
|
+
/* special case for 1-D array */
|
555
|
+
ssize_t i = NUM2SSIZET(argv[0]);
|
556
|
+
uint8_t *p = ((uint8_t *)nar->data) + i * item_size;
|
557
|
+
return ndarray_set_value(p, nar->dtype, val);
|
558
|
+
}
|
559
|
+
else {
|
560
|
+
ssize_t inline_indices_buf[MAX_INLINE_DIM] = { 0, };
|
561
|
+
ssize_t *indices = inline_indices_buf;
|
562
|
+
|
563
|
+
VALUE heap_indices_buf = 0;
|
564
|
+
if (ndim > MAX_INLINE_DIM) {
|
565
|
+
indices = RB_ALLOCV_N(ssize_t, heap_indices_buf, ndim);
|
566
|
+
}
|
567
|
+
|
568
|
+
ssize_t i;
|
569
|
+
for (i = 0; i < ndim; ++i) {
|
570
|
+
indices[i] = NUM2SSIZET(argv[i]);
|
571
|
+
}
|
572
|
+
|
573
|
+
VALUE res = ndarray_md_aset(nar, indices, val);
|
574
|
+
RB_ALLOCV_END(heap_indices_buf);
|
575
|
+
return res;
|
576
|
+
}
|
577
|
+
}
|
578
|
+
|
579
|
+
static int
|
580
|
+
increment_indices(const ndarray_t *nar, ssize_t *indices)
|
581
|
+
{
|
582
|
+
assert(nar != NULL);
|
583
|
+
assert(indices != NULL);
|
584
|
+
|
585
|
+
ssize_t i = nar->ndim - 1;
|
586
|
+
if (indices[i] + 1 < nar->shape[i]) {
|
587
|
+
++indices[i];
|
588
|
+
return 0;
|
589
|
+
}
|
590
|
+
else {
|
591
|
+
indices[i] = 0;
|
592
|
+
for (--i; i >= 0; --i) {
|
593
|
+
if (indices[i] + 1 == nar->shape[i]) {
|
594
|
+
indices[i] = 0;
|
595
|
+
}
|
596
|
+
else {
|
597
|
+
++indices[i];
|
598
|
+
return 0;
|
599
|
+
}
|
600
|
+
}
|
601
|
+
return 1; /* overflow */
|
602
|
+
}
|
603
|
+
}
|
604
|
+
|
605
|
+
static VALUE
|
606
|
+
ndarray_md_eq(const ndarray_t *nar1, const ndarray_t *nar2)
|
607
|
+
{
|
608
|
+
assert(nar1 != NULL);
|
609
|
+
assert(nar2 != NULL);
|
610
|
+
assert(nar1->ndim == nar2->ndim);
|
611
|
+
|
612
|
+
const ssize_t ndim = nar1->ndim;
|
613
|
+
ssize_t n_items = 1;
|
614
|
+
ssize_t i;
|
615
|
+
for (i = 0; i < ndim; ++i) {
|
616
|
+
if (nar1->shape[i] != nar2->shape[i])
|
617
|
+
return Qfalse;
|
618
|
+
|
619
|
+
n_items *= nar1->shape[i];
|
620
|
+
}
|
621
|
+
|
622
|
+
ssize_t inline_indices_buf[MAX_INLINE_DIM] = { 0, };
|
623
|
+
ssize_t *indices = inline_indices_buf;
|
624
|
+
|
625
|
+
VALUE heap_indices_buf = 0;
|
626
|
+
if (ndim > MAX_INLINE_DIM) {
|
627
|
+
indices = RB_ALLOCV_N(ssize_t, heap_indices_buf, ndim);
|
628
|
+
MEMZERO(indices, ssize_t, ndim);
|
629
|
+
}
|
630
|
+
|
631
|
+
ssize_t n = 0;
|
632
|
+
VALUE res = Qtrue;
|
633
|
+
for (; n < n_items; ++n) {
|
634
|
+
VALUE v1 = ndarray_md_aref(nar1, indices);
|
635
|
+
VALUE v2 = ndarray_md_aref(nar2, indices);
|
636
|
+
if (!rb_equal(v1, v2)) {
|
637
|
+
res = Qfalse;
|
638
|
+
break;
|
639
|
+
}
|
640
|
+
|
641
|
+
increment_indices(nar1, indices);
|
642
|
+
}
|
643
|
+
|
644
|
+
RB_ALLOCV_END(heap_indices_buf);
|
645
|
+
return res;
|
646
|
+
}
|
647
|
+
|
648
|
+
static VALUE
|
649
|
+
ndarray_eq(VALUE obj, VALUE other)
|
650
|
+
{
|
651
|
+
if (obj == other)
|
652
|
+
return Qtrue;
|
653
|
+
else if (!rb_typeddata_is_kind_of(other, &ndarray_data_type)) {
|
654
|
+
return Qfalse;
|
655
|
+
}
|
656
|
+
|
657
|
+
ndarray_t *nar1, *nar2;
|
658
|
+
TypedData_Get_Struct(obj, ndarray_t, &ndarray_data_type, nar1);
|
659
|
+
TypedData_Get_Struct(other, ndarray_t, &ndarray_data_type, nar2);
|
660
|
+
|
661
|
+
const ssize_t ndim = nar1->ndim;
|
662
|
+
if (ndim != nar2->ndim)
|
663
|
+
return Qfalse;
|
664
|
+
|
665
|
+
if (ndim == 1) {
|
666
|
+
const ssize_t n = nar1->shape[0];
|
667
|
+
if (n != nar2->shape[0])
|
668
|
+
return Qfalse;
|
669
|
+
|
670
|
+
ssize_t i;
|
671
|
+
for (i = 0; i < n; ++i) {
|
672
|
+
VALUE v1 = ndarray_1d_aref(nar1, i);
|
673
|
+
VALUE v2 = ndarray_1d_aref(nar2, i);
|
674
|
+
if (!rb_equal(v1, v2))
|
675
|
+
return Qfalse;
|
676
|
+
}
|
677
|
+
|
678
|
+
return Qtrue;
|
679
|
+
}
|
680
|
+
else {
|
681
|
+
return ndarray_md_eq(nar1, nar2);
|
682
|
+
}
|
683
|
+
}
|
684
|
+
|
685
|
+
static void
|
686
|
+
check_order(VALUE order)
|
687
|
+
{
|
688
|
+
if (order != sym_row_major && order != sym_column_major && order != sym_auto) {
|
689
|
+
rb_raise(rb_eArgError,
|
690
|
+
"order must be either :row_major, :column_major, or :auto (%"PRIsVALUE" given)",
|
691
|
+
order);
|
692
|
+
}
|
693
|
+
}
|
694
|
+
|
695
|
+
static VALUE
|
696
|
+
ndarray_reshape_impl(VALUE base, VALUE new_shape_v, VALUE order)
|
697
|
+
{
|
698
|
+
enum {
|
699
|
+
nothing,
|
700
|
+
zero_or_negative_size_in_shape,
|
701
|
+
incompatible_new_shape,
|
702
|
+
} failure_reason = nothing;
|
703
|
+
|
704
|
+
ndarray_t *nar_base;
|
705
|
+
TypedData_Get_Struct(base, ndarray_t, &ndarray_data_type, nar_base);
|
706
|
+
|
707
|
+
Check_Type(new_shape_v, T_ARRAY);
|
708
|
+
check_order(order);
|
709
|
+
|
710
|
+
if (order == sym_auto) {
|
711
|
+
rb_raise(rb_eNotImpError, ":auto order is not implemented");
|
712
|
+
}
|
713
|
+
else if (order == sym_column_major) {
|
714
|
+
rb_raise(rb_eNotImpError, ":column_major order is not implemented");
|
715
|
+
}
|
716
|
+
|
717
|
+
const ssize_t new_ndim = RARRAY_LEN(new_shape_v);
|
718
|
+
|
719
|
+
/* preparing the buffer for new_shape */
|
720
|
+
|
721
|
+
ssize_t inline_new_shape_buf[MAX_INLINE_DIM] = { 0, };
|
722
|
+
ssize_t *new_shape = inline_new_shape_buf;
|
723
|
+
|
724
|
+
if (new_ndim > MAX_INLINE_DIM) {
|
725
|
+
new_shape = ALLOC_N(ssize_t, new_ndim);
|
726
|
+
}
|
727
|
+
|
728
|
+
/* extracting new_shape */
|
729
|
+
|
730
|
+
ssize_t byte_size = SIZEOF_DTYPE(nar_base->dtype);
|
731
|
+
ssize_t i;
|
732
|
+
for (i = 0; i < new_ndim; ++i) {
|
733
|
+
ssize_t dim_size = NUM2SSIZET(RARRAY_AREF(new_shape_v, i));
|
734
|
+
if (dim_size <= 0) {
|
735
|
+
failure_reason = zero_or_negative_size_in_shape;
|
736
|
+
goto finish;
|
737
|
+
}
|
738
|
+
new_shape[i] = dim_size;
|
739
|
+
byte_size *= dim_size;
|
740
|
+
}
|
741
|
+
|
742
|
+
if (byte_size != nar_base->byte_size) {
|
743
|
+
failure_reason = incompatible_new_shape;
|
744
|
+
goto finish;
|
745
|
+
}
|
746
|
+
|
747
|
+
/* preparing view array */
|
748
|
+
|
749
|
+
VALUE view = ndarray_s_allocate(CLASS_OF(base));
|
750
|
+
|
751
|
+
ndarray_t *nar;
|
752
|
+
TypedData_Get_Struct(view, ndarray_t, &ndarray_data_type, nar);
|
753
|
+
|
754
|
+
nar->data = nar_base->data;
|
755
|
+
nar->byte_size = nar_base->byte_size;
|
756
|
+
nar->dtype = nar_base->dtype;
|
757
|
+
nar->base = base;
|
758
|
+
nar->ndim = new_ndim;
|
759
|
+
|
760
|
+
if (new_shape == inline_new_shape_buf) {
|
761
|
+
nar->shape = ALLOC_N(ssize_t, new_ndim);
|
762
|
+
MEMCPY(nar->shape, new_shape, ssize_t, new_ndim);
|
763
|
+
}
|
764
|
+
else {
|
765
|
+
nar->shape = new_shape;
|
766
|
+
}
|
767
|
+
|
768
|
+
nar->strides = ALLOC_N(ssize_t, new_ndim);
|
769
|
+
|
770
|
+
if (order == sym_row_major) {
|
771
|
+
ndarray_init_row_major_strides(nar->dtype, new_ndim, nar->shape, nar->strides);
|
772
|
+
}
|
773
|
+
|
774
|
+
finish:
|
775
|
+
if (failure_reason != nothing) {
|
776
|
+
if (new_shape && new_shape != inline_new_shape_buf) {
|
777
|
+
xfree(new_shape);
|
778
|
+
}
|
779
|
+
}
|
780
|
+
|
781
|
+
switch (failure_reason) {
|
782
|
+
case zero_or_negative_size_in_shape:
|
783
|
+
rb_raise(rb_eArgError, "zero or negative size is given in new_shape");
|
784
|
+
|
785
|
+
case incompatible_new_shape:
|
786
|
+
rb_raise(rb_eArgError,
|
787
|
+
"new_shape is incompatible with the base shape (%"PRIsVALUE" for %"PRIsVALUE")",
|
788
|
+
new_shape_v, ndarray_get_shape(base));
|
789
|
+
|
790
|
+
default:
|
791
|
+
break;
|
792
|
+
}
|
793
|
+
|
794
|
+
return view;
|
795
|
+
}
|
796
|
+
|
797
|
+
void
|
798
|
+
Init_memory_view_test_helper(void)
|
799
|
+
{
|
800
|
+
mMemoryViewTestHelper = rb_define_module("MemoryViewTestHelper");
|
801
|
+
cNDArray = rb_define_class_under(mMemoryViewTestHelper, "NDArray", rb_cObject);
|
802
|
+
|
803
|
+
rb_define_alloc_func(cNDArray, ndarray_s_allocate);
|
804
|
+
rb_define_method(cNDArray, "initialize", ndarray_initialize, 2);
|
805
|
+
rb_define_method(cNDArray, "byte_size", ndarray_get_byte_size, 0);
|
806
|
+
rb_define_method(cNDArray, "dtype", ndarray_get_dtype, 0);
|
807
|
+
rb_define_method(cNDArray, "ndim", ndarray_get_ndim, 0);
|
808
|
+
rb_define_method(cNDArray, "shape", ndarray_get_shape, 0);
|
809
|
+
rb_define_method(cNDArray, "strides", ndarray_get_strides, 0);
|
810
|
+
rb_define_method(cNDArray, "[]", ndarray_aref, -1);
|
811
|
+
rb_define_method(cNDArray, "[]=", ndarray_aset, -1);
|
812
|
+
rb_define_method(cNDArray, "==", ndarray_eq, 1);
|
813
|
+
|
814
|
+
rb_define_private_method(cNDArray, "reshape_impl", ndarray_reshape_impl, 2);
|
815
|
+
|
816
|
+
ndarray_dtype_ids[ndarray_dtype_int8] = rb_intern("int8");
|
817
|
+
ndarray_dtype_ids[ndarray_dtype_uint8] = rb_intern("uint8");
|
818
|
+
ndarray_dtype_ids[ndarray_dtype_int16] = rb_intern("int16");
|
819
|
+
ndarray_dtype_ids[ndarray_dtype_uint16] = rb_intern("uint16");
|
820
|
+
ndarray_dtype_ids[ndarray_dtype_int32] = rb_intern("int32");
|
821
|
+
ndarray_dtype_ids[ndarray_dtype_uint32] = rb_intern("uint32");
|
822
|
+
ndarray_dtype_ids[ndarray_dtype_int64] = rb_intern("int64");
|
823
|
+
ndarray_dtype_ids[ndarray_dtype_uint64] = rb_intern("uint64");
|
824
|
+
ndarray_dtype_ids[ndarray_dtype_float32] = rb_intern("float32");
|
825
|
+
ndarray_dtype_ids[ndarray_dtype_float64] = rb_intern("float64");
|
826
|
+
|
827
|
+
sym_row_major = ID2SYM(rb_intern("row_major"));
|
828
|
+
sym_column_major = ID2SYM(rb_intern("column_major"));
|
829
|
+
sym_auto = ID2SYM(rb_intern("auto"));
|
830
|
+
|
831
|
+
(void)ndarray_dtype_sizes; /* TODO: to be deleted */
|
832
|
+
}
|
@@ -0,0 +1,162 @@
|
|
1
|
+
require "memory_view_test_helper.so"
|
2
|
+
require "memory-view-test-helper/version"
|
3
|
+
require "set"
|
4
|
+
|
5
|
+
module MemoryViewTestHelper
  class NDArray
    # Build an NDArray from a (possibly nested) array-like object.
    #
    # The shape is inferred from the nesting, which must be rectangular, and
    # the dtype is inferred from the scalar leaves unless +dtype+ is given.
    #
    # @param obj [#to_ary] nested array-like source data
    # @param dtype [Symbol, nil] forced element type (e.g. :int32); inferred when nil
    # @param order [Symbol] NOTE(review): accepted but not used by this method —
    #   confirm whether it is meant to influence the memory layout
    # @return [NDArray]
    # @raise [ArgumentError] if +obj+ cannot be converted by +to_ary+, or the
    #   nesting is inhomogeneous / ragged
    # @raise [TypeError] if a leaf has an unsupported type, or integer dtypes of
    #   equal size but different signedness must be promoted
    def self.try_convert(obj, dtype: nil, order: :row_major)
      message = "the argument must be converted to an Array by to_ary (#{obj.class} given)"
      # Objects without #to_ary raise NoMethodError (not TypeError) when called,
      # so check up front to report the documented ArgumentError instead.
      raise ArgumentError, message unless obj.respond_to?(:to_ary)

      ary = begin
        obj.to_ary
      rescue TypeError
        raise ArgumentError, message
      end

      dtype, shape, cache = detect_dtype_and_shape(ary, dtype)
      nar = new(shape, dtype)
      assign_cache(nar, cache)
      nar
    end

    # Copy the scalar leaves recorded in +cache+ into +nar+.
    #
    # @param nar [NDArray] freshly allocated destination array
    # @param cache [Array<Hash>] pre-order list of {obj:, ary:, dim:} entries
    #   produced by detect_dtype_and_shape
    private_class_method def self.assign_cache(nar, cache)
      if nar.ndim == 1
        # A 1-D array has a single cache entry holding all leaves.
        cache[0][:ary].each_with_index do |x, i|
          nar[i] = x
        end
      else
        assign_cache_recursive(nar, [], cache, 0)
      end
    end

    # Depth-first assignment that mirrors the pre-order layout of +cache+.
    #
    # @param nar [NDArray] destination array
    # @param idx [Array<Integer>] index prefix of the sub-array at cache[k]
    # @param cache [Array<Hash>] pre-order list of {obj:, ary:, dim:} entries
    # @param k [Integer] index into +cache+ of the sub-array to assign
    # @return [Integer] index of the last cache entry consumed, so the caller
    #   continues with the entry right after it
    private_class_method def self.assign_cache_recursive(nar, idx, cache, k)
      entry = cache[k]
      if entry[:dim] + 1 == nar.ndim
        # Innermost level: the entry's elements are the scalar leaves.
        entry[:ary].each_with_index do |x, i|
          nar[*idx, i] = x
        end
      else
        entry[:ary].each_index do |i|
          k = assign_cache_recursive(nar, [*idx, i], cache, k + 1)
        end
      end
      k
    end

    # Infer the dtype and shape of the nested array +ary+.
    #
    # @param ary [Array] root array
    # @param dtype [Symbol, nil] forced dtype, or nil to infer one
    # @return [Array] [dtype, shape, cache], where cache is the pre-order list
    #   of {obj:, ary:, dim:} entries for every array-like node
    private_class_method def self.detect_dtype_and_shape(ary, dtype)
      _max_dim, dtype, shape, cache = detect_dtype_and_shape_recursive(ary, 0, nil, dtype, [], [])
      [dtype, shape, cache]
    end

    # Walk +obj+ depth-first, checking that every leaf sits at the same depth
    # and that every array at a given depth has the same length.
    #
    # @param obj [Object] current node (scalar leaf or array-like)
    # @param dim [Integer] depth of +obj+ (0 for the root)
    # @param max_dim [Integer, nil] depth of the first leaf seen, or nil if none yet
    # @param fixed_dtype [Symbol, nil] caller-forced dtype; disables promotion
    # @param out_shape [Array<Integer>] shape accumulated so far (mutated in place)
    # @param conversion_cache [Array<Hash>] pre-order cache (appended in place)
    # @return [Array] [max_dim, dtype, out_shape, conversion_cache]
    # @raise [ArgumentError] on inhomogeneous depth or mismatched sizes
    private_class_method def self.detect_dtype_and_shape_recursive(obj, dim, max_dim, fixed_dtype, out_shape, conversion_cache)
      dtype = detect_dtype(obj)
      unless dtype.nil?
        # obj is scalar
        # TODO handle scalar object
        max_dim ||= dim # the first leaf fixes the depth all leaves must share
        # A leaf is misplaced if leaves were seen at another depth, or if an
        # array was already seen at this depth (e.g. [[], 1]).
        if dim != max_dim || out_shape.length > dim
          dim_failed = [dim, max_dim].min
          raise ArgumentError, "inhomogeneous array detected at the #{dim_failed}#{ordinal(dim_failed)} dimension"
        end
        return max_dim, dtype, out_shape, conversion_cache
      end

      # obj is array-like.
      # An array at or below the leaf depth is inhomogeneous even when it is
      # empty (e.g. [1, []]), which the leaf check alone would never see.
      if max_dim && dim >= max_dim
        raise ArgumentError, "inhomogeneous array detected at the #{max_dim}#{ordinal(max_dim)} dimension"
      end

      ary = Array(obj)
      conversion_cache << {obj: obj, ary: ary, dim: dim}

      dim_size = ary.length
      if out_shape.length <= dim
        out_shape[dim] = dim_size # the first array at this depth defines its size
      elsif out_shape[dim] != dim_size
        raise ArgumentError, "size mismatch at the #{dim}#{ordinal(dim)} dimension (#{dim_size} for #{out_shape[dim]})"
      end

      ary.each do |sub|
        max_dim, dtype_sub, = detect_dtype_and_shape_recursive(sub, dim + 1, max_dim, fixed_dtype, out_shape, conversion_cache)
        dtype = promote_dtype(dtype, dtype_sub) unless fixed_dtype
      end

      return max_dim, (fixed_dtype || dtype), out_shape, conversion_cache
    end

    # Map a scalar to its dtype.
    #
    # @param obj [Object]
    # @return [Symbol, nil] :int64 for Integer; :float64 for Float/Rational;
    #   the dtype of the real part for a Complex with zero imaginary part;
    #   nil for array-like objects (Enumerable or responding to #to_ary)
    # @raise [TypeError] for any other object
    private_class_method def self.detect_dtype(obj)
      case obj
      when Integer
        :int64
      when Float, Rational
        :float64
      when ->(x) { x.is_a?(Complex) && x.imag == 0 }
        detect_dtype(obj.real)
      when Enumerable, proc { obj.respond_to?(:to_ary) }
        nil
      else
        raise TypeError, "#{obj.class} is unsupported"
      end
    end

    INTEGER_TYPES = Set[:int8, :uint8, :int16, :uint16, :int32, :uint32, :int64, :uint64].freeze

    SIZEOF_DTYPE = {
      int8: 1, uint8: 1,
      int16: 2, uint16: 2,
      int32: 4, uint32: 4,
      int64: 8, uint64: 8,
      float32: 4,
      float64: 8
    }.freeze

    # Pick a dtype able to hold values of both +dtype_a+ and +dtype_b+.
    #
    # nil acts as "no information yet". Integer pairs promote to the wider
    # type, integer/float pairs promote to the float type, and float pairs
    # promote to the wider float.
    #
    # @raise [TypeError] for distinct integer dtypes of the same size
    #   (signed vs unsigned)
    private_class_method def self.promote_dtype(dtype_a, dtype_b)
      if dtype_a == dtype_b
        dtype_a
      elsif dtype_a.nil? || dtype_b.nil?
        dtype_a || dtype_b
      else
        sizeof_a = SIZEOF_DTYPE[dtype_a]
        sizeof_b = SIZEOF_DTYPE[dtype_b]

        if INTEGER_TYPES.include?(dtype_a) && INTEGER_TYPES.include?(dtype_b)
          # both are integer
          if sizeof_a > sizeof_b
            dtype_a
          elsif sizeof_b > sizeof_a
            dtype_b
          else
            raise TypeError, "auto promotion between signed and unsigned is not supported"
          end
        elsif INTEGER_TYPES.include?(dtype_a)
          # b is float
          dtype_b
        elsif INTEGER_TYPES.include?(dtype_b)
          # a is float
          dtype_a
        else
          # both are float
          sizeof_a > sizeof_b ? dtype_a : dtype_b
        end
      end
    end

    # English ordinal suffix for +n+ ("st", "nd", "rd", "th").
    # 11-13 (and 111-113, etc.) always take "th".
    private_class_method def self.ordinal(n)
      return "th" if (11..13).cover?(n % 100)

      case n % 10
      when 1 then "st"
      when 2 then "nd"
      when 3 then "rd"
      else "th"
      end
    end

    # Return a reshaped array; the work is done by the native reshape_impl.
    #
    # @param new_shape [#to_ary] target shape
    # @param order [#to_sym] memory order, :row_major by default
    def reshape(new_shape, order: :row_major)
      reshape_impl(new_shape.to_ary, order.to_sym)
    end
  end
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
lib = File.expand_path("../lib", __FILE__)
$LOAD_PATH.unshift(lib) unless $LOAD_PATH.include?(lib)
require "memory-view-test-helper/version"

# Strip leading/trailing blank lines from a README section and terminate it
# with exactly one newline.
clean_white_space = lambda do |entry|
  entry.gsub(/(\A\n+|\n+\z)/, "") + "\n"
end

Gem::Specification.new do |spec|
  spec.name = "memory-view-test-helper"
  spec.version = MemoryViewTestHelper::VERSION
  spec.homepage = "https://github.com/mrkn/memory-view-test-helper"
  # Matches the copyright holder named in LICENSE.txt.
  spec.authors = ["Kenta Murata"]
  spec.email = ["mrkn@mrkn.jp"]

  # Resolve files relative to this gemspec rather than the current working
  # directory, so the spec also loads when evaluated from elsewhere.
  base_dir = File.expand_path("..", __FILE__)

  # The "## Description" section of README.md supplies the summary (first
  # paragraph) and the description (second paragraph).
  readme = File.read(File.join(base_dir, "README.md"), encoding: "UTF-8")
  entries = readme.split(/^\#\#\s(.*)$/)
  description_index = entries.index("Description")
  raise "README.md has no '## Description' section" if description_index.nil?
  description = clean_white_space.call(entries[description_index + 1])
  spec.summary, spec.description, = description.split(/\n\n+/, 3)
  spec.license = "MIT"
  spec.files = [
    "README.md",
    "LICENSE.txt",
    "Rakefile",
    "Gemfile",
    "#{spec.name}.gemspec",
  ]
  spec.files += Dir.glob("lib/**/*.rb", base: base_dir)
  spec.files += Dir.glob("ext/**/*.{c,h,rb}", base: base_dir)
  spec.files += Dir.glob("ext/**/depend", base: base_dir)
  spec.extensions = ["ext/memory-view-test-helper/extconf.rb"]

  spec.add_development_dependency("bundler")
  spec.add_development_dependency("rake")
  spec.add_development_dependency("test-unit")
end
|
metadata
ADDED
@@ -0,0 +1,97 @@
|
|
1
|
+
--- !ruby/object:Gem::Specification
|
2
|
+
name: memory-view-test-helper
|
3
|
+
version: !ruby/object:Gem::Version
|
4
|
+
version: 0.0.1
|
5
|
+
platform: ruby
|
6
|
+
authors:
|
7
|
+
- Kenta Murata
|
8
|
+
autorequire:
|
9
|
+
bindir: bin
|
10
|
+
cert_chain: []
|
11
|
+
date: 2021-03-17 00:00:00.000000000 Z
|
12
|
+
dependencies:
|
13
|
+
- !ruby/object:Gem::Dependency
|
14
|
+
name: bundler
|
15
|
+
requirement: !ruby/object:Gem::Requirement
|
16
|
+
requirements:
|
17
|
+
- - ">="
|
18
|
+
- !ruby/object:Gem::Version
|
19
|
+
version: '0'
|
20
|
+
type: :development
|
21
|
+
prerelease: false
|
22
|
+
version_requirements: !ruby/object:Gem::Requirement
|
23
|
+
requirements:
|
24
|
+
- - ">="
|
25
|
+
- !ruby/object:Gem::Version
|
26
|
+
version: '0'
|
27
|
+
- !ruby/object:Gem::Dependency
|
28
|
+
name: rake
|
29
|
+
requirement: !ruby/object:Gem::Requirement
|
30
|
+
requirements:
|
31
|
+
- - ">="
|
32
|
+
- !ruby/object:Gem::Version
|
33
|
+
version: '0'
|
34
|
+
type: :development
|
35
|
+
prerelease: false
|
36
|
+
version_requirements: !ruby/object:Gem::Requirement
|
37
|
+
requirements:
|
38
|
+
- - ">="
|
39
|
+
- !ruby/object:Gem::Version
|
40
|
+
version: '0'
|
41
|
+
- !ruby/object:Gem::Dependency
|
42
|
+
name: test-unit
|
43
|
+
requirement: !ruby/object:Gem::Requirement
|
44
|
+
requirements:
|
45
|
+
- - ">="
|
46
|
+
- !ruby/object:Gem::Version
|
47
|
+
version: '0'
|
48
|
+
type: :development
|
49
|
+
prerelease: false
|
50
|
+
version_requirements: !ruby/object:Gem::Requirement
|
51
|
+
requirements:
|
52
|
+
- - ">="
|
53
|
+
- !ruby/object:Gem::Version
|
54
|
+
version: '0'
|
55
|
+
description: "`MemoryViewTestHelper::NDArray` provides simple multi-dimensional numeric
|
56
|
+
array that can export MemoryView.\n"
|
57
|
+
email:
|
58
|
+
- mrkn@mrkn.jp
|
59
|
+
executables: []
|
60
|
+
extensions:
|
61
|
+
- ext/memory-view-test-helper/extconf.rb
|
62
|
+
extra_rdoc_files: []
|
63
|
+
files:
|
64
|
+
- Gemfile
|
65
|
+
- LICENSE.txt
|
66
|
+
- README.md
|
67
|
+
- Rakefile
|
68
|
+
- ext/memory-view-test-helper/extconf.rb
|
69
|
+
- ext/memory-view-test-helper/memory-view-test-helper.c
|
70
|
+
- lib/memory-view-test-helper.rb
|
71
|
+
- lib/memory-view-test-helper/version.rb
|
72
|
+
- memory-view-test-helper.gemspec
|
73
|
+
homepage: https://github.com/mrkn/memory-view-test-helper
|
74
|
+
licenses:
|
75
|
+
- MIT
|
76
|
+
metadata: {}
|
77
|
+
post_install_message:
|
78
|
+
rdoc_options: []
|
79
|
+
require_paths:
|
80
|
+
- lib
|
81
|
+
required_ruby_version: !ruby/object:Gem::Requirement
|
82
|
+
requirements:
|
83
|
+
- - ">="
|
84
|
+
- !ruby/object:Gem::Version
|
85
|
+
version: '0'
|
86
|
+
required_rubygems_version: !ruby/object:Gem::Requirement
|
87
|
+
requirements:
|
88
|
+
- - ">="
|
89
|
+
- !ruby/object:Gem::Version
|
90
|
+
version: '0'
|
91
|
+
requirements: []
|
92
|
+
rubygems_version: 3.2.3
|
93
|
+
signing_key:
|
94
|
+
specification_version: 4
|
95
|
+
summary: MemoryViewTestHelper provides features that help to test libraries that include
|
96
|
+
MemoryView support.
|
97
|
+
test_files: []
|