cumo 0.4.3 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (106) hide show
  1. checksums.yaml +4 -4
  2. data/.gitignore +3 -0
  3. data/.rubocop.yml +15 -0
  4. data/.rubocop_todo.yml +1272 -0
  5. data/3rd_party/mkmf-cu/Gemfile +2 -0
  6. data/3rd_party/mkmf-cu/Rakefile +2 -1
  7. data/3rd_party/mkmf-cu/bin/mkmf-cu-nvcc +2 -0
  8. data/3rd_party/mkmf-cu/lib/mkmf-cu/cli.rb +36 -7
  9. data/3rd_party/mkmf-cu/lib/mkmf-cu/nvcc.rb +51 -45
  10. data/3rd_party/mkmf-cu/lib/mkmf-cu.rb +2 -0
  11. data/3rd_party/mkmf-cu/mkmf-cu.gemspec +3 -1
  12. data/3rd_party/mkmf-cu/test/test_mkmf-cu.rb +5 -3
  13. data/CHANGELOG.md +69 -0
  14. data/Gemfile +6 -1
  15. data/README.md +2 -10
  16. data/Rakefile +8 -11
  17. data/bench/broadcast_fp32.rb +28 -26
  18. data/bench/cumo_bench.rb +18 -16
  19. data/bench/numo_bench.rb +18 -16
  20. data/bench/reduction_fp32.rb +14 -12
  21. data/bin/console +1 -0
  22. data/cumo.gemspec +5 -8
  23. data/ext/cumo/cuda/cudnn.c +2 -2
  24. data/ext/cumo/cumo.c +7 -3
  25. data/ext/cumo/depend.erb +15 -13
  26. data/ext/cumo/extconf.rb +32 -46
  27. data/ext/cumo/include/cumo/cuda/cudnn.h +3 -1
  28. data/ext/cumo/include/cumo/intern.h +1 -0
  29. data/ext/cumo/include/cumo/narray.h +13 -1
  30. data/ext/cumo/include/cumo/template.h +2 -4
  31. data/ext/cumo/include/cumo/types/complex_macro.h +1 -1
  32. data/ext/cumo/include/cumo/types/float_macro.h +2 -2
  33. data/ext/cumo/include/cumo/types/xint_macro.h +3 -2
  34. data/ext/cumo/include/cumo.h +2 -2
  35. data/ext/cumo/narray/array.c +3 -3
  36. data/ext/cumo/narray/data.c +23 -2
  37. data/ext/cumo/narray/gen/cogen.rb +8 -7
  38. data/ext/cumo/narray/gen/cogen_kernel.rb +8 -7
  39. data/ext/cumo/narray/gen/def/bit.rb +3 -1
  40. data/ext/cumo/narray/gen/def/dcomplex.rb +2 -0
  41. data/ext/cumo/narray/gen/def/dfloat.rb +2 -0
  42. data/ext/cumo/narray/gen/def/int16.rb +2 -0
  43. data/ext/cumo/narray/gen/def/int32.rb +2 -0
  44. data/ext/cumo/narray/gen/def/int64.rb +2 -0
  45. data/ext/cumo/narray/gen/def/int8.rb +2 -0
  46. data/ext/cumo/narray/gen/def/robject.rb +2 -0
  47. data/ext/cumo/narray/gen/def/scomplex.rb +2 -0
  48. data/ext/cumo/narray/gen/def/sfloat.rb +2 -0
  49. data/ext/cumo/narray/gen/def/uint16.rb +2 -0
  50. data/ext/cumo/narray/gen/def/uint32.rb +2 -0
  51. data/ext/cumo/narray/gen/def/uint64.rb +2 -0
  52. data/ext/cumo/narray/gen/def/uint8.rb +2 -0
  53. data/ext/cumo/narray/gen/erbln.rb +9 -7
  54. data/ext/cumo/narray/gen/erbpp2.rb +26 -24
  55. data/ext/cumo/narray/gen/narray_def.rb +13 -11
  56. data/ext/cumo/narray/gen/spec.rb +58 -55
  57. data/ext/cumo/narray/gen/tmpl/alloc_func.c +1 -1
  58. data/ext/cumo/narray/gen/tmpl/at.c +34 -0
  59. data/ext/cumo/narray/gen/tmpl/batch_norm.c +1 -1
  60. data/ext/cumo/narray/gen/tmpl/batch_norm_backward.c +2 -2
  61. data/ext/cumo/narray/gen/tmpl/conv.c +1 -1
  62. data/ext/cumo/narray/gen/tmpl/conv_grad_w.c +3 -1
  63. data/ext/cumo/narray/gen/tmpl/conv_transpose.c +1 -1
  64. data/ext/cumo/narray/gen/tmpl/fixed_batch_norm.c +1 -1
  65. data/ext/cumo/narray/gen/tmpl/init_class.c +1 -0
  66. data/ext/cumo/narray/gen/tmpl/pooling_backward.c +1 -1
  67. data/ext/cumo/narray/gen/tmpl/pooling_forward.c +1 -1
  68. data/ext/cumo/narray/gen/tmpl/qsort.c +1 -5
  69. data/ext/cumo/narray/gen/tmpl/sort.c +1 -1
  70. data/ext/cumo/narray/gen/tmpl_bit/binary.c +42 -14
  71. data/ext/cumo/narray/gen/tmpl_bit/bit_count.c +5 -0
  72. data/ext/cumo/narray/gen/tmpl_bit/bit_reduce.c +5 -0
  73. data/ext/cumo/narray/gen/tmpl_bit/mask.c +27 -7
  74. data/ext/cumo/narray/gen/tmpl_bit/store_bit.c +21 -7
  75. data/ext/cumo/narray/gen/tmpl_bit/unary.c +21 -7
  76. data/ext/cumo/narray/index.c +243 -39
  77. data/ext/cumo/narray/index_kernel.cu +84 -0
  78. data/ext/cumo/narray/narray.c +38 -1
  79. data/ext/cumo/narray/ndloop.c +1 -1
  80. data/ext/cumo/narray/struct.c +1 -1
  81. data/lib/cumo/cuda/compile_error.rb +1 -1
  82. data/lib/cumo/cuda/compiler.rb +23 -22
  83. data/lib/cumo/cuda/cudnn.rb +1 -1
  84. data/lib/cumo/cuda/device.rb +1 -1
  85. data/lib/cumo/cuda/link_state.rb +2 -2
  86. data/lib/cumo/cuda/module.rb +1 -2
  87. data/lib/cumo/cuda/nvrtc_program.rb +3 -2
  88. data/lib/cumo/cuda.rb +2 -0
  89. data/lib/cumo/linalg.rb +2 -0
  90. data/lib/cumo/narray/extra.rb +137 -185
  91. data/lib/cumo/narray.rb +2 -0
  92. data/lib/cumo.rb +3 -1
  93. data/test/bit_test.rb +157 -0
  94. data/test/cuda/compiler_test.rb +69 -0
  95. data/test/cuda/device_test.rb +30 -0
  96. data/test/cuda/memory_pool_test.rb +45 -0
  97. data/test/cuda/nvrtc_test.rb +51 -0
  98. data/test/cuda/runtime_test.rb +28 -0
  99. data/test/cudnn_test.rb +498 -0
  100. data/test/cumo_test.rb +27 -0
  101. data/test/narray_test.rb +745 -0
  102. data/test/ractor_test.rb +52 -0
  103. data/test/test_helper.rb +31 -0
  104. metadata +31 -54
  105. data/.travis.yml +0 -5
  106. data/numo-narray-version +0 -1
@@ -1,3 +1,5 @@
1
+ # frozen_string_literal: true
2
+
1
3
  module Cumo
2
4
  class NArray
3
5
 
@@ -23,12 +25,12 @@ module Cumo
23
25
 
24
26
  # Convert angles from radians to degrees.
25
27
  def rad2deg
26
- self * (180/Math::PI)
28
+ self * (180 / Math::PI)
27
29
  end
28
30
 
29
31
  # Convert angles from degrees to radians.
30
32
  def deg2rad
31
- self * (Math::PI/180)
33
+ self * (Math::PI / 180)
32
34
  end
33
35
 
34
36
  # Flip each row in the left/right direction.
@@ -43,56 +45,6 @@ module Cumo
43
45
  reverse(0)
44
46
  end
45
47
 
46
- # Multi-dimensional array indexing.
47
- # Same as [] for one-dimensional NArray.
48
- # Similar to numpy's tuple indexing, i.e., `a[[1,2,..],[3,4,..]]`
49
- # (This method will be rewritten in C)
50
- # @return [Cumo::NArray] one-dimensional view of self.
51
- # @example
52
- # p x = Cumo::DFloat.new(3,3,3).seq
53
- # # Cumo::DFloat#shape=[3,3,3]
54
- # # [[[0, 1, 2],
55
- # # [3, 4, 5],
56
- # # [6, 7, 8]],
57
- # # [[9, 10, 11],
58
- # # [12, 13, 14],
59
- # # [15, 16, 17]],
60
- # # [[18, 19, 20],
61
- # # [21, 22, 23],
62
- # # [24, 25, 26]]]
63
- #
64
- # p x.at([0,1,2],[0,1,2],[-1,-2,-3])
65
- # # Cumo::DFloat(view)#shape=[3]
66
- # # [2, 13, 24]
67
- def at(*indices)
68
- if indices.size != ndim
69
- raise DimensionError, "argument length does not match dimension size"
70
- end
71
- idx = nil
72
- stride = 1
73
- (indices.size-1).downto(0) do |i|
74
- ix = Int64.cast(indices[i])
75
- if ix.ndim != 1
76
- raise DimensionError, "index array is not one-dimensional"
77
- end
78
- ix[ix < 0] += shape[i]
79
- if ((ix < 0) & (ix >= shape[i])).any?
80
- raise IndexError, "index array is out of range"
81
- end
82
- if idx
83
- if idx.size != ix.size
84
- raise ShapeError, "index array sizes mismatch"
85
- end
86
- idx += ix * stride
87
- stride *= shape[i]
88
- else
89
- idx = ix
90
- stride = shape[i]
91
- end
92
- end
93
- self[idx]
94
- end
95
-
96
48
  # Rotate in the plane specified by axes.
97
49
  # @example
98
50
  # p a = Cumo::Int32.new(2,2).seq
@@ -114,7 +66,7 @@ module Cumo
114
66
  # # Cumo::Int32(view)#shape=[2,2]
115
67
  # # [[2, 0],
116
68
  # # [3, 1]]
117
- def rot90(k=1,axes=[0,1])
69
+ def rot90(k=1, axes=[0, 1])
118
70
  case k % 4
119
71
  when 0
120
72
  view
@@ -128,7 +80,7 @@ module Cumo
128
80
  end
129
81
 
130
82
  def to_i
131
- if size==1
83
+ if size == 1
132
84
  self.extract_cpu.to_i
133
85
  else
134
86
  # convert to Int?
@@ -137,7 +89,7 @@ module Cumo
137
89
  end
138
90
 
139
91
  def to_f
140
- if size==1
92
+ if size == 1
141
93
  self.extract_cpu.to_f
142
94
  else
143
95
  # convert to DFloat?
@@ -146,7 +98,7 @@ module Cumo
146
98
  end
147
99
 
148
100
  def to_c
149
- if size==1
101
+ if size == 1
150
102
  Complex(self.extract_cpu)
151
103
  else
152
104
  # convert to DComplex?
@@ -163,7 +115,7 @@ module Cumo
163
115
  case a
164
116
  when NArray
165
117
  (a.ndim == 0) ? a[:new] : a
166
- when Numeric,Range
118
+ when Numeric, Range
167
119
  self[a]
168
120
  else
169
121
  cast(a)
@@ -201,7 +153,7 @@ module Cumo
201
153
  end
202
154
  a << b if !b.empty?
203
155
  end
204
- if a.size==1
156
+ if a.size == 1
205
157
  self.cast(a[0])
206
158
  else
207
159
  self.cast(a)
@@ -237,18 +189,18 @@ module Cumo
237
189
 
238
190
  def each_over_axis(axis=0)
239
191
  unless block_given?
240
- return to_enum(:each_over_axis,axis)
192
+ return to_enum(:each_over_axis, axis)
241
193
  end
242
194
  if ndim == 0
243
195
  if axis != 0
244
- raise ArgumentError,"axis=#{axis} is invalid"
196
+ raise ArgumentError, "axis=#{axis} is invalid"
245
197
  end
246
198
  niter = 1
247
199
  else
248
200
  axis = check_axis(axis)
249
201
  niter = shape[axis]
250
202
  end
251
- idx = [true]*ndim
203
+ idx = [true] * ndim
252
204
  niter.times do |i|
253
205
  idx[axis] = i
254
206
  yield(self[*idx])
@@ -275,15 +227,15 @@ module Cumo
275
227
  # p a.append([7, 8, 9], axis:0)
276
228
  # # in `append': dimension mismatch (Cumo::NArray::DimensionError)
277
229
 
278
- def append(other,axis:nil)
230
+ def append(other, axis:nil)
279
231
  other = self.class.cast(other)
280
232
  if axis
281
233
  if ndim != other.ndim
282
234
  raise DimensionError, "dimension mismatch"
283
235
  end
284
- return concatenate(other,axis:axis)
236
+ return concatenate(other, axis:axis)
285
237
  else
286
- a = self.class.zeros(size+other.size)
238
+ a = self.class.zeros(size + other.size)
287
239
  a[0...size] = self[true]
288
240
  a[size..-1] = other[true]
289
241
  return a
@@ -310,11 +262,11 @@ module Cumo
310
262
  # # Cumo::DFloat(view)#shape=[9]
311
263
  # # [1, 3, 5, 7, 8, 9, 10, 11, 12]
312
264
 
313
- def delete(indice,axis=nil)
265
+ def delete(indice, axis=nil)
314
266
  if axis
315
267
  bit = Bit.ones(shape[axis])
316
268
  bit[indice] = 0
317
- idx = [true]*ndim
269
+ idx = [true] * ndim
318
270
  idx[axis] = bit.where
319
271
  return self[*idx].copy
320
272
  else
@@ -395,14 +347,14 @@ module Cumo
395
347
  # # [[0, 999, 1, 2, 999, 3],
396
348
  # # [4, 999, 5, 6, 999, 7]]
397
349
 
398
- def insert(indice,values,axis:nil)
350
+ def insert(indice, values, axis:nil)
399
351
  if axis
400
352
  values = self.class.asarray(values)
401
353
  nd = values.ndim
402
- midx = [:new]*(ndim-nd) + [true]*nd
354
+ midx = [:new] * (ndim - nd) + [true] * nd
403
355
  case indice
404
356
  when Numeric
405
- midx[-nd-1] = true
357
+ midx[-nd - 1] = true
406
358
  midx[axis] = :new
407
359
  end
408
360
  values = values[*midx]
@@ -412,27 +364,27 @@ module Cumo
412
364
  idx = Int64.asarray(indice)
413
365
  nidx = idx.size
414
366
  if nidx == 1
415
- nidx = values.shape[axis||0]
367
+ nidx = values.shape[axis || 0]
416
368
  idx = idx + Int64.new(nidx).seq
417
369
  else
418
370
  sidx = idx.sort_index
419
371
  idx[sidx] += Int64.new(nidx).seq
420
372
  end
421
373
  if axis
422
- bit = Bit.ones(shape[axis]+nidx)
374
+ bit = Bit.ones(shape[axis] + nidx)
423
375
  bit[idx] = 0
424
376
  new_shape = shape
425
377
  new_shape[axis] += nidx
426
378
  a = self.class.zeros(new_shape)
427
- mdidx = [true]*ndim
379
+ mdidx = [true] * ndim
428
380
  mdidx[axis] = bit.where
429
381
  a[*mdidx] = self
430
382
  mdidx[axis] = idx
431
383
  a[*mdidx] = values
432
384
  else
433
- bit = Bit.ones(size+nidx)
385
+ bit = Bit.ones(size + nidx)
434
386
  bit[idx] = 0
435
- a = self.class.zeros(size+nidx)
387
+ a = self.class.zeros(size + nidx)
436
388
  a[bit.where] = self.flatten
437
389
  a[idx] = values
438
390
  end
@@ -461,8 +413,8 @@ module Cumo
461
413
  # # [[1, 2, 5],
462
414
  # # [3, 4, 6]]
463
415
 
464
- def concatenate(arrays,axis:0)
465
- klass = (self==NArray) ? NArray.array_type(arrays) : self
416
+ def concatenate(arrays, axis:0)
417
+ klass = (self == NArray) ? NArray.array_type(arrays) : self
466
418
  nd = 0
467
419
  arrays = arrays.map do |a|
468
420
  case a
@@ -473,7 +425,7 @@ module Cumo
473
425
  when Array
474
426
  a = klass.cast(a)
475
427
  else
476
- raise TypeError,"not Cumo::NArray: #{a.inspect[0..48]}"
428
+ raise TypeError, "not Cumo::NArray: #{a.inspect[0..48]}"
477
429
  end
478
430
  if a.ndim > nd
479
431
  nd = a.ndim
@@ -484,31 +436,31 @@ module Cumo
484
436
  axis += nd
485
437
  end
486
438
  if axis < 0 || axis >= nd
487
- raise ArgumentError,"axis is out of range"
439
+ raise ArgumentError, "axis is out of range"
488
440
  end
489
441
  new_shape = nil
490
442
  sum_size = 0
491
443
  arrays.each do |a|
492
444
  a_shape = a.shape
493
445
  if nd != a_shape.size
494
- a_shape = [1]*(nd-a_shape.size) + a_shape
446
+ a_shape = [1] * (nd - a_shape.size) + a_shape
495
447
  end
496
448
  sum_size += a_shape.delete_at(axis)
497
449
  if new_shape
498
450
  if new_shape != a_shape
499
- raise ShapeError,"shape mismatch"
451
+ raise ShapeError, "shape mismatch"
500
452
  end
501
453
  else
502
454
  new_shape = a_shape
503
455
  end
504
456
  end
505
- new_shape.insert(axis,sum_size)
457
+ new_shape.insert(axis, sum_size)
506
458
  result = klass.zeros(*new_shape)
507
459
  lst = 0
508
460
  refs = [true] * nd
509
461
  arrays.each do |a|
510
462
  fst = lst
511
- lst = fst + (a.shape[axis-nd]||1)
463
+ lst = fst + (a.shape[axis - nd] || 1)
512
464
  refs[axis] = fst...lst
513
465
  result[*refs] = a
514
466
  end
@@ -539,7 +491,7 @@ module Cumo
539
491
  arys = arrays.map do |a|
540
492
  _atleast_2d(cast(a))
541
493
  end
542
- concatenate(arys,axis:0)
494
+ concatenate(arys, axis:0)
543
495
  end
544
496
 
545
497
  # Stack arrays horizontally (column wise).
@@ -559,7 +511,7 @@ module Cumo
559
511
  # # [3, 4]]
560
512
 
561
513
  def hstack(arrays)
562
- klass = (self==NArray) ? NArray.array_type(arrays) : self
514
+ klass = (self == NArray) ? NArray.array_type(arrays) : self
563
515
  nd = 0
564
516
  arys = arrays.map do |a|
565
517
  a = klass.cast(a)
@@ -567,7 +519,7 @@ module Cumo
567
519
  a
568
520
  end
569
521
  dim = (nd >= 2) ? 1 : 0
570
- concatenate(arys,axis:dim)
522
+ concatenate(arys, axis:dim)
571
523
  end
572
524
 
573
525
  # Stack arrays in depth wise (along third axis).
@@ -592,7 +544,7 @@ module Cumo
592
544
  arys = arrays.map do |a|
593
545
  _atleast_3d(cast(a))
594
546
  end
595
- concatenate(arys,axis:2)
547
+ concatenate(arys, axis:2)
596
548
  end
597
549
 
598
550
  # Stack 1-d arrays into columns of a 2-d array.
@@ -609,20 +561,20 @@ module Cumo
609
561
  arys = arrays.map do |a|
610
562
  a = cast(a)
611
563
  case a.ndim
612
- when 0; a[:new,:new]
613
- when 1; a[true,:new]
564
+ when 0; a[:new, :new]
565
+ when 1; a[true, :new]
614
566
  else; a
615
567
  end
616
568
  end
617
- concatenate(arys,axis:1)
569
+ concatenate(arys, axis:1)
618
570
  end
619
571
 
620
572
  private
621
573
  # Return an narray with at least two dimension.
622
574
  def _atleast_2d(a)
623
575
  case a.ndim
624
- when 0; a[:new,:new]
625
- when 1; a[:new,true]
576
+ when 0; a[:new, :new]
577
+ when 1; a[:new, true]
626
578
  else; a
627
579
  end
628
580
  end
@@ -630,9 +582,9 @@ module Cumo
630
582
  # Return an narray with at least three dimension.
631
583
  def _atleast_3d(a)
632
584
  case a.ndim
633
- when 0; a[:new,:new,:new]
634
- when 1; a[:new,true,:new]
635
- when 2; a[true,true,:new]
585
+ when 0; a[:new, :new, :new]
586
+ when 1; a[:new, true, :new]
587
+ when 2; a[true, true, :new]
636
588
  else; a
637
589
  end
638
590
  end
@@ -660,7 +612,7 @@ module Cumo
660
612
  # # [[1, 2, 5],
661
613
  # # [3, 4, 6]]
662
614
 
663
- def concatenate(*arrays,axis:0)
615
+ def concatenate(*arrays, axis:0)
664
616
  axis = check_axis(axis)
665
617
  self_shape = shape
666
618
  self_shape.delete_at(axis)
@@ -674,19 +626,19 @@ module Cumo
674
626
  when Array
675
627
  a = self.class.cast(a)
676
628
  else
677
- raise TypeError,"not Cumo::NArray: #{a.inspect[0..48]}"
629
+ raise TypeError, "not Cumo::NArray: #{a.inspect[0..48]}"
678
630
  end
679
631
  if a.ndim > ndim
680
- raise ShapeError,"dimension mismatch"
632
+ raise ShapeError, "dimension mismatch"
681
633
  end
682
634
  a_shape = a.shape
683
- sum_size += a_shape.delete_at(axis-ndim) || 1
635
+ sum_size += a_shape.delete_at(axis - ndim) || 1
684
636
  if self_shape != a_shape
685
- raise ShapeError,"shape mismatch"
637
+ raise ShapeError, "shape mismatch"
686
638
  end
687
639
  a
688
640
  end
689
- self_shape.insert(axis,sum_size)
641
+ self_shape.insert(axis, sum_size)
690
642
  result = self.class.zeros(*self_shape)
691
643
  lst = shape[axis]
692
644
  refs = [true] * ndim
@@ -694,7 +646,7 @@ module Cumo
694
646
  result[*refs] = self
695
647
  arrays.each do |a|
696
648
  fst = lst
697
- lst = fst + (a.shape[axis-ndim] || 1)
649
+ lst = fst + (a.shape[axis - ndim] || 1)
698
650
  refs[axis] = fst...lst
699
651
  result[*refs] = a
700
652
  end
@@ -735,7 +687,7 @@ module Cumo
735
687
  case indices_or_sections
736
688
  when Integer
737
689
  div_axis, mod_axis = size_axis.divmod(indices_or_sections)
738
- refs = [true]*ndim
690
+ refs = [true] * ndim
739
691
  beg_idx = 0
740
692
  mod_axis.times.map do |i|
741
693
  end_idx = beg_idx + div_axis + 1
@@ -743,16 +695,16 @@ module Cumo
743
695
  beg_idx = end_idx
744
696
  self[*refs]
745
697
  end +
746
- (indices_or_sections-mod_axis).times.map do |i|
698
+ (indices_or_sections - mod_axis).times.map do |i|
747
699
  end_idx = beg_idx + div_axis
748
700
  refs[axis] = beg_idx ... end_idx
749
701
  beg_idx = end_idx
750
702
  self[*refs]
751
703
  end
752
704
  when NArray
753
- split(indices_or_sections.to_a,axis:axis)
705
+ split(indices_or_sections.to_a, axis:axis)
754
706
  when Array
755
- refs = [true]*ndim
707
+ refs = [true] * ndim
756
708
  fst = 0
757
709
  (indices_or_sections + [size_axis]).map do |lst|
758
710
  lst = size_axis if lst > size_axis
@@ -761,7 +713,7 @@ module Cumo
761
713
  self[*refs]
762
714
  end
763
715
  else
764
- raise TypeError,"argument must be Integer or Array"
716
+ raise TypeError, "argument must be Integer or Array"
765
717
  end
766
718
  end
767
719
 
@@ -859,8 +811,8 @@ module Cumo
859
811
 
860
812
  def tile(*arg)
861
813
  arg.each do |i|
862
- if !i.kind_of?(Integer) || i<1
863
- raise ArgumentError,"argument should be positive integer"
814
+ if !i.kind_of?(Integer) || i < 1
815
+ raise ArgumentError, "argument should be positive integer"
864
816
  end
865
817
  end
866
818
  ns = arg.size
@@ -869,26 +821,26 @@ module Cumo
869
821
  new_shp = []
870
822
  src_shp = []
871
823
  res_shp = []
872
- (nd-ns).times do
824
+ (nd - ns).times do
873
825
  new_shp << 1
874
826
  new_shp << (n = shp.shift)
875
827
  src_shp << :new
876
828
  src_shp << true
877
829
  res_shp << n
878
830
  end
879
- (ns-nd).times do
831
+ (ns - nd).times do
880
832
  new_shp << (m = arg.shift)
881
833
  new_shp << 1
882
834
  src_shp << :new
883
835
  src_shp << :new
884
836
  res_shp << m
885
837
  end
886
- [nd,ns].min.times do
838
+ [nd, ns].min.times do
887
839
  new_shp << (m = arg.shift)
888
840
  new_shp << (n = shp.shift)
889
841
  src_shp << :new
890
842
  src_shp << true
891
- res_shp << n*m
843
+ res_shp << n * m
892
844
  end
893
845
  self.class.new(*new_shp).store(self[*src_shp]).reshape(*res_shp)
894
846
  end
@@ -918,7 +870,7 @@ module Cumo
918
870
  # # [3, 4],
919
871
  # # [3, 4]]
920
872
 
921
- def repeat(arg,axis:nil)
873
+ def repeat(arg, axis:nil)
922
874
  case axis
923
875
  when Integer
924
876
  axis = check_axis(axis)
@@ -927,25 +879,25 @@ module Cumo
927
879
  c = self.flatten
928
880
  axis = 0
929
881
  else
930
- raise ArgumentError,"invalid axis"
882
+ raise ArgumentError, "invalid axis"
931
883
  end
932
884
  case arg
933
885
  when Integer
934
- if !arg.kind_of?(Integer) || arg<1
935
- raise ArgumentError,"argument should be positive integer"
886
+ if !arg.kind_of?(Integer) || arg < 1
887
+ raise ArgumentError, "argument should be positive integer"
936
888
  end
937
- idx = c.shape[axis].times.map{|i| [i]*arg}.flatten
889
+ idx = c.shape[axis].times.map { |i| [i] * arg }.flatten
938
890
  else
939
891
  arg = arg.to_a
940
892
  if arg.size != c.shape[axis]
941
- raise ArgumentError,"repeat size shoud be equal to size along axis"
893
+ raise ArgumentError, "repeat size shoud be equal to size along axis"
942
894
  end
943
895
  arg.each do |i|
944
- if !i.kind_of?(Integer) || i<0
945
- raise ArgumentError,"argument should be non-negative integer"
896
+ if !i.kind_of?(Integer) || i < 0
897
+ raise ArgumentError, "argument should be non-negative integer"
946
898
  end
947
899
  end
948
- idx = arg.each_with_index.map{|a,i| [i]*a}.flatten
900
+ idx = arg.each_with_index.map { |a, i| [i] * a }.flatten
949
901
  end
950
902
  ref = [true] * c.ndim
951
903
  ref[axis] = idx
@@ -980,27 +932,27 @@ module Cumo
980
932
  # # Cumo::DFloat#shape=[1,4]
981
933
  # # [[-1, 2, 0, -2]]
982
934
 
983
- def diff(n=1,axis:-1)
935
+ def diff(n=1, axis:-1)
984
936
  axis = check_axis(axis)
985
937
  if n < 0 || n >= shape[axis]
986
- raise ShapeError,"n=#{n} is invalid for shape[#{axis}]=#{shape[axis]}"
938
+ raise ShapeError, "n=#{n} is invalid for shape[#{axis}]=#{shape[axis]}"
987
939
  end
988
940
  # calculate polynomial coefficient
989
- c = self.class[-1,1]
941
+ c = self.class[-1, 1]
990
942
  2.upto(n) do |i|
991
- x = self.class.zeros(i+1)
943
+ x = self.class.zeros(i + 1)
992
944
  x[0..-2] = c
993
- y = self.class.zeros(i+1)
945
+ y = self.class.zeros(i + 1)
994
946
  y[1..-1] = c
995
947
  c = y - x
996
948
  end
997
- s = [true]*ndim
949
+ s = [true] * ndim
998
950
  s[axis] = n..-1
999
951
  result = self[*s].dup
1000
952
  sum = result.inplace
1001
- (n-1).downto(0) do |i|
1002
- s = [true]*ndim
1003
- s[axis] = i..-n-1+i
953
+ (n - 1).downto(0) do |i|
954
+ s = [true] * ndim
955
+ s[axis] = i..-n - 1 + i
1004
956
  sum + self[*s] * c[i] # inplace addition
1005
957
  end
1006
958
  return result
@@ -1020,11 +972,11 @@ module Cumo
1020
972
  raise NArray::ShapeError, "must be >= 2-dimensional array"
1021
973
  end
1022
974
  if contiguous?
1023
- *shp,m,n = shape
1024
- idx = tril_indices(k-1)
1025
- reshape!(*shp,m*n)
1026
- self[false,idx] = 0
1027
- reshape!(*shp,m,n)
975
+ *shp, m, n = shape
976
+ idx = tril_indices(k - 1)
977
+ reshape!(*shp, m * n)
978
+ self[false, idx] = 0
979
+ reshape!(*shp, m, n)
1028
980
  else
1029
981
  store(triu(k))
1030
982
  end
@@ -1035,15 +987,15 @@ module Cumo
1035
987
  if ndim < 2
1036
988
  raise NArray::ShapeError, "must be >= 2-dimensional array"
1037
989
  end
1038
- m,n = shape[-2..-1]
1039
- NArray.triu_indices(m,n,k)
990
+ m, n = shape[-2..-1]
991
+ NArray.triu_indices(m, n, k)
1040
992
  end
1041
993
 
1042
994
  # Return the indices for the uppler-triangle on and above the k-th diagonal.
1043
- def self.triu_indices(m,n,k=0)
1044
- x = Cumo::Int64.new(m,1).seq + k
1045
- y = Cumo::Int64.new(1,n).seq
1046
- (x<=y).where
995
+ def self.triu_indices(m, n, k=0)
996
+ x = Cumo::Int64.new(m, 1).seq + k
997
+ y = Cumo::Int64.new(1, n).seq
998
+ (x <= y).where
1047
999
  end
1048
1000
 
1049
1001
  # Lower triangular matrix.
@@ -1059,11 +1011,11 @@ module Cumo
1059
1011
  raise NArray::ShapeError, "must be >= 2-dimensional array"
1060
1012
  end
1061
1013
  if contiguous?
1062
- idx = triu_indices(k+1)
1063
- *shp,m,n = shape
1064
- reshape!(*shp,m*n)
1065
- self[false,idx] = 0
1066
- reshape!(*shp,m,n)
1014
+ idx = triu_indices(k + 1)
1015
+ *shp, m, n = shape
1016
+ reshape!(*shp, m * n)
1017
+ self[false, idx] = 0
1018
+ reshape!(*shp, m, n)
1067
1019
  else
1068
1020
  store(tril(k))
1069
1021
  end
@@ -1074,15 +1026,15 @@ module Cumo
1074
1026
  if ndim < 2
1075
1027
  raise NArray::ShapeError, "must be >= 2-dimensional array"
1076
1028
  end
1077
- m,n = shape[-2..-1]
1078
- NArray.tril_indices(m,n,k)
1029
+ m, n = shape[-2..-1]
1030
+ NArray.tril_indices(m, n, k)
1079
1031
  end
1080
1032
 
1081
1033
  # Return the indices for the lower-triangle on and below the k-th diagonal.
1082
- def self.tril_indices(m,n,k=0)
1083
- x = Cumo::Int64.new(m,1).seq + k
1084
- y = Cumo::Int64.new(1,n).seq
1085
- (x>=y).where
1034
+ def self.tril_indices(m, n, k=0)
1035
+ x = Cumo::Int64.new(m, 1).seq + k
1036
+ y = Cumo::Int64.new(1, n).seq
1037
+ (x >= y).where
1086
1038
  end
1087
1039
 
1088
1040
  # Return the k-th diagonal indices.
@@ -1090,22 +1042,22 @@ module Cumo
1090
1042
  if ndim < 2
1091
1043
  raise NArray::ShapeError, "must be >= 2-dimensional array"
1092
1044
  end
1093
- m,n = shape[-2..-1]
1094
- NArray.diag_indices(m,n,k)
1045
+ m, n = shape[-2..-1]
1046
+ NArray.diag_indices(m, n, k)
1095
1047
  end
1096
1048
 
1097
1049
  # Return the k-th diagonal indices.
1098
- def self.diag_indices(m,n,k=0)
1099
- x = Cumo::Int64.new(m,1).seq + k
1100
- y = Cumo::Int64.new(1,n).seq
1050
+ def self.diag_indices(m, n, k=0)
1051
+ x = Cumo::Int64.new(m, 1).seq + k
1052
+ y = Cumo::Int64.new(1, n).seq
1101
1053
  (x.eq y).where
1102
1054
  end
1103
1055
 
1104
1056
  # Return a matrix whose diagonal is constructed by self along the last axis.
1105
1057
  def diag(k=0)
1106
- *shp,n = shape
1058
+ *shp, n = shape
1107
1059
  n += k.abs
1108
- a = self.class.zeros(*shp,n,n)
1060
+ a = self.class.zeros(*shp, n, n)
1109
1061
  a.diagonal(k).store(self)
1110
1062
  a
1111
1063
  end
@@ -1120,8 +1072,8 @@ module Cumo
1120
1072
  # @param axis [Array] (optional, default=[-2,-1]) diagonal axis
1121
1073
  # @param nan [Bool] (optional, default=false) nan-aware algorithm, i.e., if true then it ignores nan.
1122
1074
 
1123
- def trace(offset=nil,axis=nil,nan:false)
1124
- diagonal(offset,axis).sum(nan:nan,axis:-1)
1075
+ def trace(offset=nil, axis=nil, nan:false)
1076
+ diagonal(offset, axis).sum(nan:nan, axis:-1)
1125
1077
  end
1126
1078
 
1127
1079
 
@@ -1164,20 +1116,20 @@ module Cumo
1164
1116
  when 0
1165
1117
  b.mulsum(self, axis:-2)
1166
1118
  when 1
1167
- self[true,:new].mulsum(b, axis:-2)
1119
+ self[true, :new].mulsum(b, axis:-2)
1168
1120
  else
1169
1121
  unless @@warn_slow_dot
1170
1122
  nx = 200
1171
1123
  ns = 200000
1172
- am,an = shape[-2..-1]
1173
- bm,bn = b.shape[-2..-1]
1124
+ am, an = shape[-2..-1]
1125
+ bm, bn = b.shape[-2..-1]
1174
1126
  if am > nx && an > nx && bm > nx && bn > nx &&
1175
1127
  size > ns && b.size > ns
1176
1128
  @@warn_slow_dot = true
1177
1129
  warn "\nwarning: matrix dot for #{t} is slow. Consider SFloat, DFloat, SComplex, or DComplex to use cuBLAS.\n\n"
1178
1130
  end
1179
1131
  end
1180
- self[false,:new].mulsum(b[false,:new,true,true], axis:-2)
1132
+ self[false, :new].mulsum(b[false, :new, true, true], axis:-2)
1181
1133
  end
1182
1134
  end
1183
1135
  end
@@ -1217,17 +1169,17 @@ module Cumo
1217
1169
  def outer(b, axis:nil)
1218
1170
  b = NArray.cast(b)
1219
1171
  if axis.nil?
1220
- self[false,:new] * ((b.ndim==0) ? b : b[false,:new,true])
1172
+ self[false, :new] * ((b.ndim == 0) ? b : b[false, :new, true])
1221
1173
  else
1222
- md,nd = [ndim,b.ndim].minmax
1174
+ md, nd = [ndim, b.ndim].minmax
1223
1175
  axis = check_axis(axis) - nd
1224
1176
  if axis < -md
1225
- raise ArgumentError,"axis=#{axis} is out of range"
1177
+ raise ArgumentError, "axis=#{axis} is out of range"
1226
1178
  end
1227
- adim = [true]*ndim
1228
- adim[axis+ndim+1,0] = :new
1229
- bdim = [true]*b.ndim
1230
- bdim[axis+b.ndim,0] = :new
1179
+ adim = [true] * ndim
1180
+ adim[axis + ndim + 1, 0] = :new
1181
+ bdim = [true] * b.ndim
1182
+ bdim[axis + b.ndim, 0] = :new
1231
1183
  self[*adim] * b[*bdim]
1232
1184
  end
1233
1185
  end
@@ -1259,9 +1211,9 @@ module Cumo
1259
1211
  ndb = b.ndim
1260
1212
  shpa = shape
1261
1213
  shpb = b.shape
1262
- adim = [:new]*(2*[ndb-nda,0].max) + [true,:new]*nda
1263
- bdim = [:new]*(2*[nda-ndb,0].max) + [:new,true]*ndb
1264
- shpr = (-[nda,ndb].max..-1).map{|i| (shpa[i]||1) * (shpb[i]||1)}
1214
+ adim = [:new] * (2 * [ndb - nda, 0].max) + [true, :new] * nda
1215
+ bdim = [:new] * (2 * [nda - ndb, 0].max) + [:new, true] * ndb
1216
+ shpr = (-[nda, ndb].max..-1).map { |i| (shpa[i] || 1) * (shpb[i] || 1) }
1265
1217
  (self[*adim] * b[*bdim]).reshape(*shpr)
1266
1218
  end
1267
1219
 
@@ -1269,7 +1221,7 @@ module Cumo
1269
1221
  # under construction
1270
1222
  def cov(y=nil, ddof:1, fweights:nil, aweights:nil)
1271
1223
  if y
1272
- m = NArray.vstack([self,y])
1224
+ m = NArray.vstack([self, y])
1273
1225
  else
1274
1226
  m = self
1275
1227
  end
@@ -1280,7 +1232,7 @@ module Cumo
1280
1232
  end
1281
1233
  if aweights
1282
1234
  a = aweights
1283
- w = w ? w*a : a
1235
+ w = w ? w * a : a
1284
1236
  end
1285
1237
  if w
1286
1238
  w_sum = w.sum(axis:-1, keepdims:true)
@@ -1289,23 +1241,23 @@ module Cumo
1289
1241
  elsif aweights.nil?
1290
1242
  fact = w_sum - ddof
1291
1243
  else
1292
- wa_sum = (w*a).sum(axis:-1, keepdims:true)
1244
+ wa_sum = (w * a).sum(axis:-1, keepdims:true)
1293
1245
  fact = w_sum - ddof * wa_sum / w_sum
1294
1246
  end
1295
1247
  if (fact <= 0).any?
1296
- raise StandardError,"Degrees of freedom <= 0 for slice"
1248
+ raise StandardError, "Degrees of freedom <= 0 for slice"
1297
1249
  end
1298
1250
  else
1299
1251
  fact = m.shape[-1] - ddof
1300
1252
  end
1301
1253
  if w
1302
- m -= (m*w).sum(axis:-1, keepdims:true) / w_sum
1303
- mw = m*w
1254
+ m -= (m * w).sum(axis:-1, keepdims:true) / w_sum
1255
+ mw = m * w
1304
1256
  else
1305
1257
  m -= m.mean(axis:-1, keepdims:true)
1306
1258
  mw = m
1307
1259
  end
1308
- mt = (m.ndim < 2) ? m : m.swapaxes(-2,-1)
1260
+ mt = (m.ndim < 2) ? m : m.swapaxes(-2, -1)
1309
1261
  mw.dot(mt.conj) / fact
1310
1262
  end
1311
1263
 
@@ -1313,15 +1265,15 @@ module Cumo
1313
1265
 
1314
1266
  # @!visibility private
1315
1267
  def check_axis(axis)
1316
- unless Integer===axis
1317
- raise ArgumentError,"axis=#{axis} must be Integer"
1268
+ unless Integer === axis
1269
+ raise ArgumentError, "axis=#{axis} must be Integer"
1318
1270
  end
1319
1271
  a = axis
1320
1272
  if a < 0
1321
1273
  a += ndim
1322
1274
  end
1323
1275
  if a < 0 || a >= ndim
1324
- raise ArgumentError,"axis=#{axis} is invalid"
1276
+ raise ArgumentError, "axis=#{axis} is invalid"
1325
1277
  end
1326
1278
  a
1327
1279
  end