cumo 0.4.3 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.gitignore +3 -0
- data/.rubocop.yml +15 -0
- data/.rubocop_todo.yml +1272 -0
- data/3rd_party/mkmf-cu/Gemfile +2 -0
- data/3rd_party/mkmf-cu/Rakefile +2 -1
- data/3rd_party/mkmf-cu/bin/mkmf-cu-nvcc +2 -0
- data/3rd_party/mkmf-cu/lib/mkmf-cu/cli.rb +36 -7
- data/3rd_party/mkmf-cu/lib/mkmf-cu/nvcc.rb +51 -45
- data/3rd_party/mkmf-cu/lib/mkmf-cu.rb +2 -0
- data/3rd_party/mkmf-cu/mkmf-cu.gemspec +3 -1
- data/3rd_party/mkmf-cu/test/test_mkmf-cu.rb +5 -3
- data/CHANGELOG.md +69 -0
- data/Gemfile +6 -1
- data/README.md +2 -10
- data/Rakefile +8 -11
- data/bench/broadcast_fp32.rb +28 -26
- data/bench/cumo_bench.rb +18 -16
- data/bench/numo_bench.rb +18 -16
- data/bench/reduction_fp32.rb +14 -12
- data/bin/console +1 -0
- data/cumo.gemspec +5 -8
- data/ext/cumo/cuda/cudnn.c +2 -2
- data/ext/cumo/cumo.c +7 -3
- data/ext/cumo/depend.erb +15 -13
- data/ext/cumo/extconf.rb +32 -46
- data/ext/cumo/include/cumo/cuda/cudnn.h +3 -1
- data/ext/cumo/include/cumo/intern.h +1 -0
- data/ext/cumo/include/cumo/narray.h +13 -1
- data/ext/cumo/include/cumo/template.h +2 -4
- data/ext/cumo/include/cumo/types/complex_macro.h +1 -1
- data/ext/cumo/include/cumo/types/float_macro.h +2 -2
- data/ext/cumo/include/cumo/types/xint_macro.h +3 -2
- data/ext/cumo/include/cumo.h +2 -2
- data/ext/cumo/narray/array.c +3 -3
- data/ext/cumo/narray/data.c +23 -2
- data/ext/cumo/narray/gen/cogen.rb +8 -7
- data/ext/cumo/narray/gen/cogen_kernel.rb +8 -7
- data/ext/cumo/narray/gen/def/bit.rb +3 -1
- data/ext/cumo/narray/gen/def/dcomplex.rb +2 -0
- data/ext/cumo/narray/gen/def/dfloat.rb +2 -0
- data/ext/cumo/narray/gen/def/int16.rb +2 -0
- data/ext/cumo/narray/gen/def/int32.rb +2 -0
- data/ext/cumo/narray/gen/def/int64.rb +2 -0
- data/ext/cumo/narray/gen/def/int8.rb +2 -0
- data/ext/cumo/narray/gen/def/robject.rb +2 -0
- data/ext/cumo/narray/gen/def/scomplex.rb +2 -0
- data/ext/cumo/narray/gen/def/sfloat.rb +2 -0
- data/ext/cumo/narray/gen/def/uint16.rb +2 -0
- data/ext/cumo/narray/gen/def/uint32.rb +2 -0
- data/ext/cumo/narray/gen/def/uint64.rb +2 -0
- data/ext/cumo/narray/gen/def/uint8.rb +2 -0
- data/ext/cumo/narray/gen/erbln.rb +9 -7
- data/ext/cumo/narray/gen/erbpp2.rb +26 -24
- data/ext/cumo/narray/gen/narray_def.rb +13 -11
- data/ext/cumo/narray/gen/spec.rb +58 -55
- data/ext/cumo/narray/gen/tmpl/alloc_func.c +1 -1
- data/ext/cumo/narray/gen/tmpl/at.c +34 -0
- data/ext/cumo/narray/gen/tmpl/batch_norm.c +1 -1
- data/ext/cumo/narray/gen/tmpl/batch_norm_backward.c +2 -2
- data/ext/cumo/narray/gen/tmpl/conv.c +1 -1
- data/ext/cumo/narray/gen/tmpl/conv_grad_w.c +3 -1
- data/ext/cumo/narray/gen/tmpl/conv_transpose.c +1 -1
- data/ext/cumo/narray/gen/tmpl/fixed_batch_norm.c +1 -1
- data/ext/cumo/narray/gen/tmpl/init_class.c +1 -0
- data/ext/cumo/narray/gen/tmpl/pooling_backward.c +1 -1
- data/ext/cumo/narray/gen/tmpl/pooling_forward.c +1 -1
- data/ext/cumo/narray/gen/tmpl/qsort.c +1 -5
- data/ext/cumo/narray/gen/tmpl/sort.c +1 -1
- data/ext/cumo/narray/gen/tmpl_bit/binary.c +42 -14
- data/ext/cumo/narray/gen/tmpl_bit/bit_count.c +5 -0
- data/ext/cumo/narray/gen/tmpl_bit/bit_reduce.c +5 -0
- data/ext/cumo/narray/gen/tmpl_bit/mask.c +27 -7
- data/ext/cumo/narray/gen/tmpl_bit/store_bit.c +21 -7
- data/ext/cumo/narray/gen/tmpl_bit/unary.c +21 -7
- data/ext/cumo/narray/index.c +243 -39
- data/ext/cumo/narray/index_kernel.cu +84 -0
- data/ext/cumo/narray/narray.c +38 -1
- data/ext/cumo/narray/ndloop.c +1 -1
- data/ext/cumo/narray/struct.c +1 -1
- data/lib/cumo/cuda/compile_error.rb +1 -1
- data/lib/cumo/cuda/compiler.rb +23 -22
- data/lib/cumo/cuda/cudnn.rb +1 -1
- data/lib/cumo/cuda/device.rb +1 -1
- data/lib/cumo/cuda/link_state.rb +2 -2
- data/lib/cumo/cuda/module.rb +1 -2
- data/lib/cumo/cuda/nvrtc_program.rb +3 -2
- data/lib/cumo/cuda.rb +2 -0
- data/lib/cumo/linalg.rb +2 -0
- data/lib/cumo/narray/extra.rb +137 -185
- data/lib/cumo/narray.rb +2 -0
- data/lib/cumo.rb +3 -1
- data/test/bit_test.rb +157 -0
- data/test/cuda/compiler_test.rb +69 -0
- data/test/cuda/device_test.rb +30 -0
- data/test/cuda/memory_pool_test.rb +45 -0
- data/test/cuda/nvrtc_test.rb +51 -0
- data/test/cuda/runtime_test.rb +28 -0
- data/test/cudnn_test.rb +498 -0
- data/test/cumo_test.rb +27 -0
- data/test/narray_test.rb +745 -0
- data/test/ractor_test.rb +52 -0
- data/test/test_helper.rb +31 -0
- metadata +31 -54
- data/.travis.yml +0 -5
- data/numo-narray-version +0 -1
data/lib/cumo/narray/extra.rb
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
# frozen_string_literal: true
|
|
2
|
+
|
|
1
3
|
module Cumo
|
|
2
4
|
class NArray
|
|
3
5
|
|
|
@@ -23,12 +25,12 @@ module Cumo
|
|
|
23
25
|
|
|
24
26
|
# Convert angles from radians to degrees.
|
|
25
27
|
def rad2deg
|
|
26
|
-
self * (180/Math::PI)
|
|
28
|
+
self * (180 / Math::PI)
|
|
27
29
|
end
|
|
28
30
|
|
|
29
31
|
# Convert angles from degrees to radians.
|
|
30
32
|
def deg2rad
|
|
31
|
-
self * (Math::PI/180)
|
|
33
|
+
self * (Math::PI / 180)
|
|
32
34
|
end
|
|
33
35
|
|
|
34
36
|
# Flip each row in the left/right direction.
|
|
@@ -43,56 +45,6 @@ module Cumo
|
|
|
43
45
|
reverse(0)
|
|
44
46
|
end
|
|
45
47
|
|
|
46
|
-
# Multi-dimensional array indexing.
|
|
47
|
-
# Same as [] for one-dimensional NArray.
|
|
48
|
-
# Similar to numpy's tuple indexing, i.e., `a[[1,2,..],[3,4,..]]`
|
|
49
|
-
# (This method will be rewritten in C)
|
|
50
|
-
# @return [Cumo::NArray] one-dimensional view of self.
|
|
51
|
-
# @example
|
|
52
|
-
# p x = Cumo::DFloat.new(3,3,3).seq
|
|
53
|
-
# # Cumo::DFloat#shape=[3,3,3]
|
|
54
|
-
# # [[[0, 1, 2],
|
|
55
|
-
# # [3, 4, 5],
|
|
56
|
-
# # [6, 7, 8]],
|
|
57
|
-
# # [[9, 10, 11],
|
|
58
|
-
# # [12, 13, 14],
|
|
59
|
-
# # [15, 16, 17]],
|
|
60
|
-
# # [[18, 19, 20],
|
|
61
|
-
# # [21, 22, 23],
|
|
62
|
-
# # [24, 25, 26]]]
|
|
63
|
-
#
|
|
64
|
-
# p x.at([0,1,2],[0,1,2],[-1,-2,-3])
|
|
65
|
-
# # Cumo::DFloat(view)#shape=[3]
|
|
66
|
-
# # [2, 13, 24]
|
|
67
|
-
def at(*indices)
|
|
68
|
-
if indices.size != ndim
|
|
69
|
-
raise DimensionError, "argument length does not match dimension size"
|
|
70
|
-
end
|
|
71
|
-
idx = nil
|
|
72
|
-
stride = 1
|
|
73
|
-
(indices.size-1).downto(0) do |i|
|
|
74
|
-
ix = Int64.cast(indices[i])
|
|
75
|
-
if ix.ndim != 1
|
|
76
|
-
raise DimensionError, "index array is not one-dimensional"
|
|
77
|
-
end
|
|
78
|
-
ix[ix < 0] += shape[i]
|
|
79
|
-
if ((ix < 0) & (ix >= shape[i])).any?
|
|
80
|
-
raise IndexError, "index array is out of range"
|
|
81
|
-
end
|
|
82
|
-
if idx
|
|
83
|
-
if idx.size != ix.size
|
|
84
|
-
raise ShapeError, "index array sizes mismatch"
|
|
85
|
-
end
|
|
86
|
-
idx += ix * stride
|
|
87
|
-
stride *= shape[i]
|
|
88
|
-
else
|
|
89
|
-
idx = ix
|
|
90
|
-
stride = shape[i]
|
|
91
|
-
end
|
|
92
|
-
end
|
|
93
|
-
self[idx]
|
|
94
|
-
end
|
|
95
|
-
|
|
96
48
|
# Rotate in the plane specified by axes.
|
|
97
49
|
# @example
|
|
98
50
|
# p a = Cumo::Int32.new(2,2).seq
|
|
@@ -114,7 +66,7 @@ module Cumo
|
|
|
114
66
|
# # Cumo::Int32(view)#shape=[2,2]
|
|
115
67
|
# # [[2, 0],
|
|
116
68
|
# # [3, 1]]
|
|
117
|
-
def rot90(k=1,axes=[0,1])
|
|
69
|
+
def rot90(k=1, axes=[0, 1])
|
|
118
70
|
case k % 4
|
|
119
71
|
when 0
|
|
120
72
|
view
|
|
@@ -128,7 +80,7 @@ module Cumo
|
|
|
128
80
|
end
|
|
129
81
|
|
|
130
82
|
def to_i
|
|
131
|
-
if size==1
|
|
83
|
+
if size == 1
|
|
132
84
|
self.extract_cpu.to_i
|
|
133
85
|
else
|
|
134
86
|
# convert to Int?
|
|
@@ -137,7 +89,7 @@ module Cumo
|
|
|
137
89
|
end
|
|
138
90
|
|
|
139
91
|
def to_f
|
|
140
|
-
if size==1
|
|
92
|
+
if size == 1
|
|
141
93
|
self.extract_cpu.to_f
|
|
142
94
|
else
|
|
143
95
|
# convert to DFloat?
|
|
@@ -146,7 +98,7 @@ module Cumo
|
|
|
146
98
|
end
|
|
147
99
|
|
|
148
100
|
def to_c
|
|
149
|
-
if size==1
|
|
101
|
+
if size == 1
|
|
150
102
|
Complex(self.extract_cpu)
|
|
151
103
|
else
|
|
152
104
|
# convert to DComplex?
|
|
@@ -163,7 +115,7 @@ module Cumo
|
|
|
163
115
|
case a
|
|
164
116
|
when NArray
|
|
165
117
|
(a.ndim == 0) ? a[:new] : a
|
|
166
|
-
when Numeric,Range
|
|
118
|
+
when Numeric, Range
|
|
167
119
|
self[a]
|
|
168
120
|
else
|
|
169
121
|
cast(a)
|
|
@@ -201,7 +153,7 @@ module Cumo
|
|
|
201
153
|
end
|
|
202
154
|
a << b if !b.empty?
|
|
203
155
|
end
|
|
204
|
-
if a.size==1
|
|
156
|
+
if a.size == 1
|
|
205
157
|
self.cast(a[0])
|
|
206
158
|
else
|
|
207
159
|
self.cast(a)
|
|
@@ -237,18 +189,18 @@ module Cumo
|
|
|
237
189
|
|
|
238
190
|
def each_over_axis(axis=0)
|
|
239
191
|
unless block_given?
|
|
240
|
-
return to_enum(:each_over_axis,axis)
|
|
192
|
+
return to_enum(:each_over_axis, axis)
|
|
241
193
|
end
|
|
242
194
|
if ndim == 0
|
|
243
195
|
if axis != 0
|
|
244
|
-
raise ArgumentError,"axis=#{axis} is invalid"
|
|
196
|
+
raise ArgumentError, "axis=#{axis} is invalid"
|
|
245
197
|
end
|
|
246
198
|
niter = 1
|
|
247
199
|
else
|
|
248
200
|
axis = check_axis(axis)
|
|
249
201
|
niter = shape[axis]
|
|
250
202
|
end
|
|
251
|
-
idx = [true]*ndim
|
|
203
|
+
idx = [true] * ndim
|
|
252
204
|
niter.times do |i|
|
|
253
205
|
idx[axis] = i
|
|
254
206
|
yield(self[*idx])
|
|
@@ -275,15 +227,15 @@ module Cumo
|
|
|
275
227
|
# p a.append([7, 8, 9], axis:0)
|
|
276
228
|
# # in `append': dimension mismatch (Cumo::NArray::DimensionError)
|
|
277
229
|
|
|
278
|
-
def append(other,axis:nil)
|
|
230
|
+
def append(other, axis:nil)
|
|
279
231
|
other = self.class.cast(other)
|
|
280
232
|
if axis
|
|
281
233
|
if ndim != other.ndim
|
|
282
234
|
raise DimensionError, "dimension mismatch"
|
|
283
235
|
end
|
|
284
|
-
return concatenate(other,axis:axis)
|
|
236
|
+
return concatenate(other, axis:axis)
|
|
285
237
|
else
|
|
286
|
-
a = self.class.zeros(size+other.size)
|
|
238
|
+
a = self.class.zeros(size + other.size)
|
|
287
239
|
a[0...size] = self[true]
|
|
288
240
|
a[size..-1] = other[true]
|
|
289
241
|
return a
|
|
@@ -310,11 +262,11 @@ module Cumo
|
|
|
310
262
|
# # Cumo::DFloat(view)#shape=[9]
|
|
311
263
|
# # [1, 3, 5, 7, 8, 9, 10, 11, 12]
|
|
312
264
|
|
|
313
|
-
def delete(indice,axis=nil)
|
|
265
|
+
def delete(indice, axis=nil)
|
|
314
266
|
if axis
|
|
315
267
|
bit = Bit.ones(shape[axis])
|
|
316
268
|
bit[indice] = 0
|
|
317
|
-
idx = [true]*ndim
|
|
269
|
+
idx = [true] * ndim
|
|
318
270
|
idx[axis] = bit.where
|
|
319
271
|
return self[*idx].copy
|
|
320
272
|
else
|
|
@@ -395,14 +347,14 @@ module Cumo
|
|
|
395
347
|
# # [[0, 999, 1, 2, 999, 3],
|
|
396
348
|
# # [4, 999, 5, 6, 999, 7]]
|
|
397
349
|
|
|
398
|
-
def insert(indice,values,axis:nil)
|
|
350
|
+
def insert(indice, values, axis:nil)
|
|
399
351
|
if axis
|
|
400
352
|
values = self.class.asarray(values)
|
|
401
353
|
nd = values.ndim
|
|
402
|
-
midx = [:new]*(ndim-nd) + [true]*nd
|
|
354
|
+
midx = [:new] * (ndim - nd) + [true] * nd
|
|
403
355
|
case indice
|
|
404
356
|
when Numeric
|
|
405
|
-
midx[-nd-1] = true
|
|
357
|
+
midx[-nd - 1] = true
|
|
406
358
|
midx[axis] = :new
|
|
407
359
|
end
|
|
408
360
|
values = values[*midx]
|
|
@@ -412,27 +364,27 @@ module Cumo
|
|
|
412
364
|
idx = Int64.asarray(indice)
|
|
413
365
|
nidx = idx.size
|
|
414
366
|
if nidx == 1
|
|
415
|
-
nidx = values.shape[axis||0]
|
|
367
|
+
nidx = values.shape[axis || 0]
|
|
416
368
|
idx = idx + Int64.new(nidx).seq
|
|
417
369
|
else
|
|
418
370
|
sidx = idx.sort_index
|
|
419
371
|
idx[sidx] += Int64.new(nidx).seq
|
|
420
372
|
end
|
|
421
373
|
if axis
|
|
422
|
-
bit = Bit.ones(shape[axis]+nidx)
|
|
374
|
+
bit = Bit.ones(shape[axis] + nidx)
|
|
423
375
|
bit[idx] = 0
|
|
424
376
|
new_shape = shape
|
|
425
377
|
new_shape[axis] += nidx
|
|
426
378
|
a = self.class.zeros(new_shape)
|
|
427
|
-
mdidx = [true]*ndim
|
|
379
|
+
mdidx = [true] * ndim
|
|
428
380
|
mdidx[axis] = bit.where
|
|
429
381
|
a[*mdidx] = self
|
|
430
382
|
mdidx[axis] = idx
|
|
431
383
|
a[*mdidx] = values
|
|
432
384
|
else
|
|
433
|
-
bit = Bit.ones(size+nidx)
|
|
385
|
+
bit = Bit.ones(size + nidx)
|
|
434
386
|
bit[idx] = 0
|
|
435
|
-
a = self.class.zeros(size+nidx)
|
|
387
|
+
a = self.class.zeros(size + nidx)
|
|
436
388
|
a[bit.where] = self.flatten
|
|
437
389
|
a[idx] = values
|
|
438
390
|
end
|
|
@@ -461,8 +413,8 @@ module Cumo
|
|
|
461
413
|
# # [[1, 2, 5],
|
|
462
414
|
# # [3, 4, 6]]
|
|
463
415
|
|
|
464
|
-
def concatenate(arrays,axis:0)
|
|
465
|
-
klass = (self==NArray) ? NArray.array_type(arrays) : self
|
|
416
|
+
def concatenate(arrays, axis:0)
|
|
417
|
+
klass = (self == NArray) ? NArray.array_type(arrays) : self
|
|
466
418
|
nd = 0
|
|
467
419
|
arrays = arrays.map do |a|
|
|
468
420
|
case a
|
|
@@ -473,7 +425,7 @@ module Cumo
|
|
|
473
425
|
when Array
|
|
474
426
|
a = klass.cast(a)
|
|
475
427
|
else
|
|
476
|
-
raise TypeError,"not Cumo::NArray: #{a.inspect[0..48]}"
|
|
428
|
+
raise TypeError, "not Cumo::NArray: #{a.inspect[0..48]}"
|
|
477
429
|
end
|
|
478
430
|
if a.ndim > nd
|
|
479
431
|
nd = a.ndim
|
|
@@ -484,31 +436,31 @@ module Cumo
|
|
|
484
436
|
axis += nd
|
|
485
437
|
end
|
|
486
438
|
if axis < 0 || axis >= nd
|
|
487
|
-
raise ArgumentError,"axis is out of range"
|
|
439
|
+
raise ArgumentError, "axis is out of range"
|
|
488
440
|
end
|
|
489
441
|
new_shape = nil
|
|
490
442
|
sum_size = 0
|
|
491
443
|
arrays.each do |a|
|
|
492
444
|
a_shape = a.shape
|
|
493
445
|
if nd != a_shape.size
|
|
494
|
-
a_shape = [1]*(nd-a_shape.size) + a_shape
|
|
446
|
+
a_shape = [1] * (nd - a_shape.size) + a_shape
|
|
495
447
|
end
|
|
496
448
|
sum_size += a_shape.delete_at(axis)
|
|
497
449
|
if new_shape
|
|
498
450
|
if new_shape != a_shape
|
|
499
|
-
raise ShapeError,"shape mismatch"
|
|
451
|
+
raise ShapeError, "shape mismatch"
|
|
500
452
|
end
|
|
501
453
|
else
|
|
502
454
|
new_shape = a_shape
|
|
503
455
|
end
|
|
504
456
|
end
|
|
505
|
-
new_shape.insert(axis,sum_size)
|
|
457
|
+
new_shape.insert(axis, sum_size)
|
|
506
458
|
result = klass.zeros(*new_shape)
|
|
507
459
|
lst = 0
|
|
508
460
|
refs = [true] * nd
|
|
509
461
|
arrays.each do |a|
|
|
510
462
|
fst = lst
|
|
511
|
-
lst = fst + (a.shape[axis-nd]||1)
|
|
463
|
+
lst = fst + (a.shape[axis - nd] || 1)
|
|
512
464
|
refs[axis] = fst...lst
|
|
513
465
|
result[*refs] = a
|
|
514
466
|
end
|
|
@@ -539,7 +491,7 @@ module Cumo
|
|
|
539
491
|
arys = arrays.map do |a|
|
|
540
492
|
_atleast_2d(cast(a))
|
|
541
493
|
end
|
|
542
|
-
concatenate(arys,axis:0)
|
|
494
|
+
concatenate(arys, axis:0)
|
|
543
495
|
end
|
|
544
496
|
|
|
545
497
|
# Stack arrays horizontally (column wise).
|
|
@@ -559,7 +511,7 @@ module Cumo
|
|
|
559
511
|
# # [3, 4]]
|
|
560
512
|
|
|
561
513
|
def hstack(arrays)
|
|
562
|
-
klass = (self==NArray) ? NArray.array_type(arrays) : self
|
|
514
|
+
klass = (self == NArray) ? NArray.array_type(arrays) : self
|
|
563
515
|
nd = 0
|
|
564
516
|
arys = arrays.map do |a|
|
|
565
517
|
a = klass.cast(a)
|
|
@@ -567,7 +519,7 @@ module Cumo
|
|
|
567
519
|
a
|
|
568
520
|
end
|
|
569
521
|
dim = (nd >= 2) ? 1 : 0
|
|
570
|
-
concatenate(arys,axis:dim)
|
|
522
|
+
concatenate(arys, axis:dim)
|
|
571
523
|
end
|
|
572
524
|
|
|
573
525
|
# Stack arrays in depth wise (along third axis).
|
|
@@ -592,7 +544,7 @@ module Cumo
|
|
|
592
544
|
arys = arrays.map do |a|
|
|
593
545
|
_atleast_3d(cast(a))
|
|
594
546
|
end
|
|
595
|
-
concatenate(arys,axis:2)
|
|
547
|
+
concatenate(arys, axis:2)
|
|
596
548
|
end
|
|
597
549
|
|
|
598
550
|
# Stack 1-d arrays into columns of a 2-d array.
|
|
@@ -609,20 +561,20 @@ module Cumo
|
|
|
609
561
|
arys = arrays.map do |a|
|
|
610
562
|
a = cast(a)
|
|
611
563
|
case a.ndim
|
|
612
|
-
when 0; a[:new,:new]
|
|
613
|
-
when 1; a[true,:new]
|
|
564
|
+
when 0; a[:new, :new]
|
|
565
|
+
when 1; a[true, :new]
|
|
614
566
|
else; a
|
|
615
567
|
end
|
|
616
568
|
end
|
|
617
|
-
concatenate(arys,axis:1)
|
|
569
|
+
concatenate(arys, axis:1)
|
|
618
570
|
end
|
|
619
571
|
|
|
620
572
|
private
|
|
621
573
|
# Return an narray with at least two dimension.
|
|
622
574
|
def _atleast_2d(a)
|
|
623
575
|
case a.ndim
|
|
624
|
-
when 0; a[:new,:new]
|
|
625
|
-
when 1; a[:new,true]
|
|
576
|
+
when 0; a[:new, :new]
|
|
577
|
+
when 1; a[:new, true]
|
|
626
578
|
else; a
|
|
627
579
|
end
|
|
628
580
|
end
|
|
@@ -630,9 +582,9 @@ module Cumo
|
|
|
630
582
|
# Return an narray with at least three dimension.
|
|
631
583
|
def _atleast_3d(a)
|
|
632
584
|
case a.ndim
|
|
633
|
-
when 0; a[:new,:new,:new]
|
|
634
|
-
when 1; a[:new,true,:new]
|
|
635
|
-
when 2; a[true,true,:new]
|
|
585
|
+
when 0; a[:new, :new, :new]
|
|
586
|
+
when 1; a[:new, true, :new]
|
|
587
|
+
when 2; a[true, true, :new]
|
|
636
588
|
else; a
|
|
637
589
|
end
|
|
638
590
|
end
|
|
@@ -660,7 +612,7 @@ module Cumo
|
|
|
660
612
|
# # [[1, 2, 5],
|
|
661
613
|
# # [3, 4, 6]]
|
|
662
614
|
|
|
663
|
-
def concatenate(*arrays,axis:0)
|
|
615
|
+
def concatenate(*arrays, axis:0)
|
|
664
616
|
axis = check_axis(axis)
|
|
665
617
|
self_shape = shape
|
|
666
618
|
self_shape.delete_at(axis)
|
|
@@ -674,19 +626,19 @@ module Cumo
|
|
|
674
626
|
when Array
|
|
675
627
|
a = self.class.cast(a)
|
|
676
628
|
else
|
|
677
|
-
raise TypeError,"not Cumo::NArray: #{a.inspect[0..48]}"
|
|
629
|
+
raise TypeError, "not Cumo::NArray: #{a.inspect[0..48]}"
|
|
678
630
|
end
|
|
679
631
|
if a.ndim > ndim
|
|
680
|
-
raise ShapeError,"dimension mismatch"
|
|
632
|
+
raise ShapeError, "dimension mismatch"
|
|
681
633
|
end
|
|
682
634
|
a_shape = a.shape
|
|
683
|
-
sum_size += a_shape.delete_at(axis-ndim) || 1
|
|
635
|
+
sum_size += a_shape.delete_at(axis - ndim) || 1
|
|
684
636
|
if self_shape != a_shape
|
|
685
|
-
raise ShapeError,"shape mismatch"
|
|
637
|
+
raise ShapeError, "shape mismatch"
|
|
686
638
|
end
|
|
687
639
|
a
|
|
688
640
|
end
|
|
689
|
-
self_shape.insert(axis,sum_size)
|
|
641
|
+
self_shape.insert(axis, sum_size)
|
|
690
642
|
result = self.class.zeros(*self_shape)
|
|
691
643
|
lst = shape[axis]
|
|
692
644
|
refs = [true] * ndim
|
|
@@ -694,7 +646,7 @@ module Cumo
|
|
|
694
646
|
result[*refs] = self
|
|
695
647
|
arrays.each do |a|
|
|
696
648
|
fst = lst
|
|
697
|
-
lst = fst + (a.shape[axis-ndim] || 1)
|
|
649
|
+
lst = fst + (a.shape[axis - ndim] || 1)
|
|
698
650
|
refs[axis] = fst...lst
|
|
699
651
|
result[*refs] = a
|
|
700
652
|
end
|
|
@@ -735,7 +687,7 @@ module Cumo
|
|
|
735
687
|
case indices_or_sections
|
|
736
688
|
when Integer
|
|
737
689
|
div_axis, mod_axis = size_axis.divmod(indices_or_sections)
|
|
738
|
-
refs = [true]*ndim
|
|
690
|
+
refs = [true] * ndim
|
|
739
691
|
beg_idx = 0
|
|
740
692
|
mod_axis.times.map do |i|
|
|
741
693
|
end_idx = beg_idx + div_axis + 1
|
|
@@ -743,16 +695,16 @@ module Cumo
|
|
|
743
695
|
beg_idx = end_idx
|
|
744
696
|
self[*refs]
|
|
745
697
|
end +
|
|
746
|
-
(indices_or_sections-mod_axis).times.map do |i|
|
|
698
|
+
(indices_or_sections - mod_axis).times.map do |i|
|
|
747
699
|
end_idx = beg_idx + div_axis
|
|
748
700
|
refs[axis] = beg_idx ... end_idx
|
|
749
701
|
beg_idx = end_idx
|
|
750
702
|
self[*refs]
|
|
751
703
|
end
|
|
752
704
|
when NArray
|
|
753
|
-
split(indices_or_sections.to_a,axis:axis)
|
|
705
|
+
split(indices_or_sections.to_a, axis:axis)
|
|
754
706
|
when Array
|
|
755
|
-
refs = [true]*ndim
|
|
707
|
+
refs = [true] * ndim
|
|
756
708
|
fst = 0
|
|
757
709
|
(indices_or_sections + [size_axis]).map do |lst|
|
|
758
710
|
lst = size_axis if lst > size_axis
|
|
@@ -761,7 +713,7 @@ module Cumo
|
|
|
761
713
|
self[*refs]
|
|
762
714
|
end
|
|
763
715
|
else
|
|
764
|
-
raise TypeError,"argument must be Integer or Array"
|
|
716
|
+
raise TypeError, "argument must be Integer or Array"
|
|
765
717
|
end
|
|
766
718
|
end
|
|
767
719
|
|
|
@@ -859,8 +811,8 @@ module Cumo
|
|
|
859
811
|
|
|
860
812
|
def tile(*arg)
|
|
861
813
|
arg.each do |i|
|
|
862
|
-
if !i.kind_of?(Integer) || i<1
|
|
863
|
-
raise ArgumentError,"argument should be positive integer"
|
|
814
|
+
if !i.kind_of?(Integer) || i < 1
|
|
815
|
+
raise ArgumentError, "argument should be positive integer"
|
|
864
816
|
end
|
|
865
817
|
end
|
|
866
818
|
ns = arg.size
|
|
@@ -869,26 +821,26 @@ module Cumo
|
|
|
869
821
|
new_shp = []
|
|
870
822
|
src_shp = []
|
|
871
823
|
res_shp = []
|
|
872
|
-
(nd-ns).times do
|
|
824
|
+
(nd - ns).times do
|
|
873
825
|
new_shp << 1
|
|
874
826
|
new_shp << (n = shp.shift)
|
|
875
827
|
src_shp << :new
|
|
876
828
|
src_shp << true
|
|
877
829
|
res_shp << n
|
|
878
830
|
end
|
|
879
|
-
(ns-nd).times do
|
|
831
|
+
(ns - nd).times do
|
|
880
832
|
new_shp << (m = arg.shift)
|
|
881
833
|
new_shp << 1
|
|
882
834
|
src_shp << :new
|
|
883
835
|
src_shp << :new
|
|
884
836
|
res_shp << m
|
|
885
837
|
end
|
|
886
|
-
[nd,ns].min.times do
|
|
838
|
+
[nd, ns].min.times do
|
|
887
839
|
new_shp << (m = arg.shift)
|
|
888
840
|
new_shp << (n = shp.shift)
|
|
889
841
|
src_shp << :new
|
|
890
842
|
src_shp << true
|
|
891
|
-
res_shp << n*m
|
|
843
|
+
res_shp << n * m
|
|
892
844
|
end
|
|
893
845
|
self.class.new(*new_shp).store(self[*src_shp]).reshape(*res_shp)
|
|
894
846
|
end
|
|
@@ -918,7 +870,7 @@ module Cumo
|
|
|
918
870
|
# # [3, 4],
|
|
919
871
|
# # [3, 4]]
|
|
920
872
|
|
|
921
|
-
def repeat(arg,axis:nil)
|
|
873
|
+
def repeat(arg, axis:nil)
|
|
922
874
|
case axis
|
|
923
875
|
when Integer
|
|
924
876
|
axis = check_axis(axis)
|
|
@@ -927,25 +879,25 @@ module Cumo
|
|
|
927
879
|
c = self.flatten
|
|
928
880
|
axis = 0
|
|
929
881
|
else
|
|
930
|
-
raise ArgumentError,"invalid axis"
|
|
882
|
+
raise ArgumentError, "invalid axis"
|
|
931
883
|
end
|
|
932
884
|
case arg
|
|
933
885
|
when Integer
|
|
934
|
-
if !arg.kind_of?(Integer) || arg<1
|
|
935
|
-
raise ArgumentError,"argument should be positive integer"
|
|
886
|
+
if !arg.kind_of?(Integer) || arg < 1
|
|
887
|
+
raise ArgumentError, "argument should be positive integer"
|
|
936
888
|
end
|
|
937
|
-
idx = c.shape[axis].times.map{|i| [i]*arg}.flatten
|
|
889
|
+
idx = c.shape[axis].times.map { |i| [i] * arg }.flatten
|
|
938
890
|
else
|
|
939
891
|
arg = arg.to_a
|
|
940
892
|
if arg.size != c.shape[axis]
|
|
941
|
-
raise ArgumentError,"repeat size shoud be equal to size along axis"
|
|
893
|
+
raise ArgumentError, "repeat size shoud be equal to size along axis"
|
|
942
894
|
end
|
|
943
895
|
arg.each do |i|
|
|
944
|
-
if !i.kind_of?(Integer) || i<0
|
|
945
|
-
raise ArgumentError,"argument should be non-negative integer"
|
|
896
|
+
if !i.kind_of?(Integer) || i < 0
|
|
897
|
+
raise ArgumentError, "argument should be non-negative integer"
|
|
946
898
|
end
|
|
947
899
|
end
|
|
948
|
-
idx = arg.each_with_index.map{|a,i| [i]*a}.flatten
|
|
900
|
+
idx = arg.each_with_index.map { |a, i| [i] * a }.flatten
|
|
949
901
|
end
|
|
950
902
|
ref = [true] * c.ndim
|
|
951
903
|
ref[axis] = idx
|
|
@@ -980,27 +932,27 @@ module Cumo
|
|
|
980
932
|
# # Cumo::DFloat#shape=[1,4]
|
|
981
933
|
# # [[-1, 2, 0, -2]]
|
|
982
934
|
|
|
983
|
-
def diff(n=1,axis:-1)
|
|
935
|
+
def diff(n=1, axis:-1)
|
|
984
936
|
axis = check_axis(axis)
|
|
985
937
|
if n < 0 || n >= shape[axis]
|
|
986
|
-
raise ShapeError,"n=#{n} is invalid for shape[#{axis}]=#{shape[axis]}"
|
|
938
|
+
raise ShapeError, "n=#{n} is invalid for shape[#{axis}]=#{shape[axis]}"
|
|
987
939
|
end
|
|
988
940
|
# calculate polynomial coefficient
|
|
989
|
-
c = self.class[-1,1]
|
|
941
|
+
c = self.class[-1, 1]
|
|
990
942
|
2.upto(n) do |i|
|
|
991
|
-
x = self.class.zeros(i+1)
|
|
943
|
+
x = self.class.zeros(i + 1)
|
|
992
944
|
x[0..-2] = c
|
|
993
|
-
y = self.class.zeros(i+1)
|
|
945
|
+
y = self.class.zeros(i + 1)
|
|
994
946
|
y[1..-1] = c
|
|
995
947
|
c = y - x
|
|
996
948
|
end
|
|
997
|
-
s = [true]*ndim
|
|
949
|
+
s = [true] * ndim
|
|
998
950
|
s[axis] = n..-1
|
|
999
951
|
result = self[*s].dup
|
|
1000
952
|
sum = result.inplace
|
|
1001
|
-
(n-1).downto(0) do |i|
|
|
1002
|
-
s = [true]*ndim
|
|
1003
|
-
s[axis] = i..-n-1+i
|
|
953
|
+
(n - 1).downto(0) do |i|
|
|
954
|
+
s = [true] * ndim
|
|
955
|
+
s[axis] = i..-n - 1 + i
|
|
1004
956
|
sum + self[*s] * c[i] # inplace addition
|
|
1005
957
|
end
|
|
1006
958
|
return result
|
|
@@ -1020,11 +972,11 @@ module Cumo
|
|
|
1020
972
|
raise NArray::ShapeError, "must be >= 2-dimensional array"
|
|
1021
973
|
end
|
|
1022
974
|
if contiguous?
|
|
1023
|
-
*shp,m,n = shape
|
|
1024
|
-
idx = tril_indices(k-1)
|
|
1025
|
-
reshape!(*shp,m*n)
|
|
1026
|
-
self[false,idx] = 0
|
|
1027
|
-
reshape!(*shp,m,n)
|
|
975
|
+
*shp, m, n = shape
|
|
976
|
+
idx = tril_indices(k - 1)
|
|
977
|
+
reshape!(*shp, m * n)
|
|
978
|
+
self[false, idx] = 0
|
|
979
|
+
reshape!(*shp, m, n)
|
|
1028
980
|
else
|
|
1029
981
|
store(triu(k))
|
|
1030
982
|
end
|
|
@@ -1035,15 +987,15 @@ module Cumo
|
|
|
1035
987
|
if ndim < 2
|
|
1036
988
|
raise NArray::ShapeError, "must be >= 2-dimensional array"
|
|
1037
989
|
end
|
|
1038
|
-
m,n = shape[-2..-1]
|
|
1039
|
-
NArray.triu_indices(m,n,k)
|
|
990
|
+
m, n = shape[-2..-1]
|
|
991
|
+
NArray.triu_indices(m, n, k)
|
|
1040
992
|
end
|
|
1041
993
|
|
|
1042
994
|
# Return the indices for the uppler-triangle on and above the k-th diagonal.
|
|
1043
|
-
def self.triu_indices(m,n,k=0)
|
|
1044
|
-
x = Cumo::Int64.new(m,1).seq + k
|
|
1045
|
-
y = Cumo::Int64.new(1,n).seq
|
|
1046
|
-
(x<=y).where
|
|
995
|
+
def self.triu_indices(m, n, k=0)
|
|
996
|
+
x = Cumo::Int64.new(m, 1).seq + k
|
|
997
|
+
y = Cumo::Int64.new(1, n).seq
|
|
998
|
+
(x <= y).where
|
|
1047
999
|
end
|
|
1048
1000
|
|
|
1049
1001
|
# Lower triangular matrix.
|
|
@@ -1059,11 +1011,11 @@ module Cumo
|
|
|
1059
1011
|
raise NArray::ShapeError, "must be >= 2-dimensional array"
|
|
1060
1012
|
end
|
|
1061
1013
|
if contiguous?
|
|
1062
|
-
idx = triu_indices(k+1)
|
|
1063
|
-
*shp,m,n = shape
|
|
1064
|
-
reshape!(*shp,m*n)
|
|
1065
|
-
self[false,idx] = 0
|
|
1066
|
-
reshape!(*shp,m,n)
|
|
1014
|
+
idx = triu_indices(k + 1)
|
|
1015
|
+
*shp, m, n = shape
|
|
1016
|
+
reshape!(*shp, m * n)
|
|
1017
|
+
self[false, idx] = 0
|
|
1018
|
+
reshape!(*shp, m, n)
|
|
1067
1019
|
else
|
|
1068
1020
|
store(tril(k))
|
|
1069
1021
|
end
|
|
@@ -1074,15 +1026,15 @@ module Cumo
|
|
|
1074
1026
|
if ndim < 2
|
|
1075
1027
|
raise NArray::ShapeError, "must be >= 2-dimensional array"
|
|
1076
1028
|
end
|
|
1077
|
-
m,n = shape[-2..-1]
|
|
1078
|
-
NArray.tril_indices(m,n,k)
|
|
1029
|
+
m, n = shape[-2..-1]
|
|
1030
|
+
NArray.tril_indices(m, n, k)
|
|
1079
1031
|
end
|
|
1080
1032
|
|
|
1081
1033
|
# Return the indices for the lower-triangle on and below the k-th diagonal.
|
|
1082
|
-
def self.tril_indices(m,n,k=0)
|
|
1083
|
-
x = Cumo::Int64.new(m,1).seq + k
|
|
1084
|
-
y = Cumo::Int64.new(1,n).seq
|
|
1085
|
-
(x>=y).where
|
|
1034
|
+
def self.tril_indices(m, n, k=0)
|
|
1035
|
+
x = Cumo::Int64.new(m, 1).seq + k
|
|
1036
|
+
y = Cumo::Int64.new(1, n).seq
|
|
1037
|
+
(x >= y).where
|
|
1086
1038
|
end
|
|
1087
1039
|
|
|
1088
1040
|
# Return the k-th diagonal indices.
|
|
@@ -1090,22 +1042,22 @@ module Cumo
|
|
|
1090
1042
|
if ndim < 2
|
|
1091
1043
|
raise NArray::ShapeError, "must be >= 2-dimensional array"
|
|
1092
1044
|
end
|
|
1093
|
-
m,n = shape[-2..-1]
|
|
1094
|
-
NArray.diag_indices(m,n,k)
|
|
1045
|
+
m, n = shape[-2..-1]
|
|
1046
|
+
NArray.diag_indices(m, n, k)
|
|
1095
1047
|
end
|
|
1096
1048
|
|
|
1097
1049
|
# Return the k-th diagonal indices.
|
|
1098
|
-
def self.diag_indices(m,n,k=0)
|
|
1099
|
-
x = Cumo::Int64.new(m,1).seq + k
|
|
1100
|
-
y = Cumo::Int64.new(1,n).seq
|
|
1050
|
+
def self.diag_indices(m, n, k=0)
|
|
1051
|
+
x = Cumo::Int64.new(m, 1).seq + k
|
|
1052
|
+
y = Cumo::Int64.new(1, n).seq
|
|
1101
1053
|
(x.eq y).where
|
|
1102
1054
|
end
|
|
1103
1055
|
|
|
1104
1056
|
# Return a matrix whose diagonal is constructed by self along the last axis.
|
|
1105
1057
|
def diag(k=0)
|
|
1106
|
-
*shp,n = shape
|
|
1058
|
+
*shp, n = shape
|
|
1107
1059
|
n += k.abs
|
|
1108
|
-
a = self.class.zeros(*shp,n,n)
|
|
1060
|
+
a = self.class.zeros(*shp, n, n)
|
|
1109
1061
|
a.diagonal(k).store(self)
|
|
1110
1062
|
a
|
|
1111
1063
|
end
|
|
@@ -1120,8 +1072,8 @@ module Cumo
|
|
|
1120
1072
|
# @param axis [Array] (optional, default=[-2,-1]) diagonal axis
|
|
1121
1073
|
# @param nan [Bool] (optional, default=false) nan-aware algorithm, i.e., if true then it ignores nan.
|
|
1122
1074
|
|
|
1123
|
-
def trace(offset=nil,axis=nil,nan:false)
|
|
1124
|
-
diagonal(offset,axis).sum(nan:nan,axis:-1)
|
|
1075
|
+
def trace(offset=nil, axis=nil, nan:false)
|
|
1076
|
+
diagonal(offset, axis).sum(nan:nan, axis:-1)
|
|
1125
1077
|
end
|
|
1126
1078
|
|
|
1127
1079
|
|
|
@@ -1164,20 +1116,20 @@ module Cumo
|
|
|
1164
1116
|
when 0
|
|
1165
1117
|
b.mulsum(self, axis:-2)
|
|
1166
1118
|
when 1
|
|
1167
|
-
self[true,:new].mulsum(b, axis:-2)
|
|
1119
|
+
self[true, :new].mulsum(b, axis:-2)
|
|
1168
1120
|
else
|
|
1169
1121
|
unless @@warn_slow_dot
|
|
1170
1122
|
nx = 200
|
|
1171
1123
|
ns = 200000
|
|
1172
|
-
am,an = shape[-2..-1]
|
|
1173
|
-
bm,bn = b.shape[-2..-1]
|
|
1124
|
+
am, an = shape[-2..-1]
|
|
1125
|
+
bm, bn = b.shape[-2..-1]
|
|
1174
1126
|
if am > nx && an > nx && bm > nx && bn > nx &&
|
|
1175
1127
|
size > ns && b.size > ns
|
|
1176
1128
|
@@warn_slow_dot = true
|
|
1177
1129
|
warn "\nwarning: matrix dot for #{t} is slow. Consider SFloat, DFloat, SComplex, or DComplex to use cuBLAS.\n\n"
|
|
1178
1130
|
end
|
|
1179
1131
|
end
|
|
1180
|
-
self[false,:new].mulsum(b[false,:new,true,true], axis:-2)
|
|
1132
|
+
self[false, :new].mulsum(b[false, :new, true, true], axis:-2)
|
|
1181
1133
|
end
|
|
1182
1134
|
end
|
|
1183
1135
|
end
|
|
@@ -1217,17 +1169,17 @@ module Cumo
|
|
|
1217
1169
|
def outer(b, axis:nil)
|
|
1218
1170
|
b = NArray.cast(b)
|
|
1219
1171
|
if axis.nil?
|
|
1220
|
-
self[false,:new] * ((b.ndim==0) ? b : b[false,:new,true])
|
|
1172
|
+
self[false, :new] * ((b.ndim == 0) ? b : b[false, :new, true])
|
|
1221
1173
|
else
|
|
1222
|
-
md,nd = [ndim,b.ndim].minmax
|
|
1174
|
+
md, nd = [ndim, b.ndim].minmax
|
|
1223
1175
|
axis = check_axis(axis) - nd
|
|
1224
1176
|
if axis < -md
|
|
1225
|
-
raise ArgumentError,"axis=#{axis} is out of range"
|
|
1177
|
+
raise ArgumentError, "axis=#{axis} is out of range"
|
|
1226
1178
|
end
|
|
1227
|
-
adim = [true]*ndim
|
|
1228
|
-
adim[axis+ndim+1,0] = :new
|
|
1229
|
-
bdim = [true]*b.ndim
|
|
1230
|
-
bdim[axis+b.ndim,0] = :new
|
|
1179
|
+
adim = [true] * ndim
|
|
1180
|
+
adim[axis + ndim + 1, 0] = :new
|
|
1181
|
+
bdim = [true] * b.ndim
|
|
1182
|
+
bdim[axis + b.ndim, 0] = :new
|
|
1231
1183
|
self[*adim] * b[*bdim]
|
|
1232
1184
|
end
|
|
1233
1185
|
end
|
|
@@ -1259,9 +1211,9 @@ module Cumo
|
|
|
1259
1211
|
ndb = b.ndim
|
|
1260
1212
|
shpa = shape
|
|
1261
1213
|
shpb = b.shape
|
|
1262
|
-
adim = [:new]*(2*[ndb-nda,0].max) + [true,:new]*nda
|
|
1263
|
-
bdim = [:new]*(2*[nda-ndb,0].max) + [:new,true]*ndb
|
|
1264
|
-
shpr = (-[nda,ndb].max..-1).map{|i| (shpa[i]||1) * (shpb[i]||1)}
|
|
1214
|
+
adim = [:new] * (2 * [ndb - nda, 0].max) + [true, :new] * nda
|
|
1215
|
+
bdim = [:new] * (2 * [nda - ndb, 0].max) + [:new, true] * ndb
|
|
1216
|
+
shpr = (-[nda, ndb].max..-1).map { |i| (shpa[i] || 1) * (shpb[i] || 1) }
|
|
1265
1217
|
(self[*adim] * b[*bdim]).reshape(*shpr)
|
|
1266
1218
|
end
|
|
1267
1219
|
|
|
@@ -1269,7 +1221,7 @@ module Cumo
|
|
|
1269
1221
|
# under construction
|
|
1270
1222
|
def cov(y=nil, ddof:1, fweights:nil, aweights:nil)
|
|
1271
1223
|
if y
|
|
1272
|
-
m = NArray.vstack([self,y])
|
|
1224
|
+
m = NArray.vstack([self, y])
|
|
1273
1225
|
else
|
|
1274
1226
|
m = self
|
|
1275
1227
|
end
|
|
@@ -1280,7 +1232,7 @@ module Cumo
|
|
|
1280
1232
|
end
|
|
1281
1233
|
if aweights
|
|
1282
1234
|
a = aweights
|
|
1283
|
-
w = w ? w*a : a
|
|
1235
|
+
w = w ? w * a : a
|
|
1284
1236
|
end
|
|
1285
1237
|
if w
|
|
1286
1238
|
w_sum = w.sum(axis:-1, keepdims:true)
|
|
@@ -1289,23 +1241,23 @@ module Cumo
|
|
|
1289
1241
|
elsif aweights.nil?
|
|
1290
1242
|
fact = w_sum - ddof
|
|
1291
1243
|
else
|
|
1292
|
-
wa_sum = (w*a).sum(axis:-1, keepdims:true)
|
|
1244
|
+
wa_sum = (w * a).sum(axis:-1, keepdims:true)
|
|
1293
1245
|
fact = w_sum - ddof * wa_sum / w_sum
|
|
1294
1246
|
end
|
|
1295
1247
|
if (fact <= 0).any?
|
|
1296
|
-
raise StandardError,"Degrees of freedom <= 0 for slice"
|
|
1248
|
+
raise StandardError, "Degrees of freedom <= 0 for slice"
|
|
1297
1249
|
end
|
|
1298
1250
|
else
|
|
1299
1251
|
fact = m.shape[-1] - ddof
|
|
1300
1252
|
end
|
|
1301
1253
|
if w
|
|
1302
|
-
m -= (m*w).sum(axis:-1, keepdims:true) / w_sum
|
|
1303
|
-
mw = m*w
|
|
1254
|
+
m -= (m * w).sum(axis:-1, keepdims:true) / w_sum
|
|
1255
|
+
mw = m * w
|
|
1304
1256
|
else
|
|
1305
1257
|
m -= m.mean(axis:-1, keepdims:true)
|
|
1306
1258
|
mw = m
|
|
1307
1259
|
end
|
|
1308
|
-
mt = (m.ndim < 2) ? m : m.swapaxes(-2,-1)
|
|
1260
|
+
mt = (m.ndim < 2) ? m : m.swapaxes(-2, -1)
|
|
1309
1261
|
mw.dot(mt.conj) / fact
|
|
1310
1262
|
end
|
|
1311
1263
|
|
|
@@ -1313,15 +1265,15 @@ module Cumo
|
|
|
1313
1265
|
|
|
1314
1266
|
# @!visibility private
|
|
1315
1267
|
def check_axis(axis)
|
|
1316
|
-
unless Integer===axis
|
|
1317
|
-
raise ArgumentError,"axis=#{axis} must be Integer"
|
|
1268
|
+
unless Integer === axis
|
|
1269
|
+
raise ArgumentError, "axis=#{axis} must be Integer"
|
|
1318
1270
|
end
|
|
1319
1271
|
a = axis
|
|
1320
1272
|
if a < 0
|
|
1321
1273
|
a += ndim
|
|
1322
1274
|
end
|
|
1323
1275
|
if a < 0 || a >= ndim
|
|
1324
|
-
raise ArgumentError,"axis=#{axis} is invalid"
|
|
1276
|
+
raise ArgumentError, "axis=#{axis} is invalid"
|
|
1325
1277
|
end
|
|
1326
1278
|
a
|
|
1327
1279
|
end
|