cumo 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (266) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +27 -0
  3. data/.travis.yml +5 -0
  4. data/3rd_party/mkmf-cu/.gitignore +36 -0
  5. data/3rd_party/mkmf-cu/Gemfile +3 -0
  6. data/3rd_party/mkmf-cu/LICENSE +21 -0
  7. data/3rd_party/mkmf-cu/README.md +36 -0
  8. data/3rd_party/mkmf-cu/Rakefile +11 -0
  9. data/3rd_party/mkmf-cu/bin/mkmf-cu-nvcc +4 -0
  10. data/3rd_party/mkmf-cu/lib/mkmf-cu.rb +32 -0
  11. data/3rd_party/mkmf-cu/lib/mkmf-cu/cli.rb +80 -0
  12. data/3rd_party/mkmf-cu/lib/mkmf-cu/nvcc.rb +157 -0
  13. data/3rd_party/mkmf-cu/mkmf-cu.gemspec +16 -0
  14. data/3rd_party/mkmf-cu/test/test_mkmf-cu.rb +67 -0
  15. data/CODE_OF_CONDUCT.md +46 -0
  16. data/Gemfile +8 -0
  17. data/LICENSE.txt +82 -0
  18. data/README.md +252 -0
  19. data/Rakefile +43 -0
  20. data/bench/broadcast_fp32.rb +138 -0
  21. data/bench/cumo_bench.rb +193 -0
  22. data/bench/numo_bench.rb +138 -0
  23. data/bench/reduction_fp32.rb +117 -0
  24. data/bin/console +14 -0
  25. data/bin/setup +8 -0
  26. data/cumo.gemspec +32 -0
  27. data/ext/cumo/cuda/cublas.c +278 -0
  28. data/ext/cumo/cuda/driver.c +421 -0
  29. data/ext/cumo/cuda/memory_pool.cpp +185 -0
  30. data/ext/cumo/cuda/memory_pool_impl.cpp +308 -0
  31. data/ext/cumo/cuda/memory_pool_impl.hpp +370 -0
  32. data/ext/cumo/cuda/memory_pool_impl_test.cpp +554 -0
  33. data/ext/cumo/cuda/nvrtc.c +207 -0
  34. data/ext/cumo/cuda/runtime.c +167 -0
  35. data/ext/cumo/cumo.c +148 -0
  36. data/ext/cumo/depend.erb +58 -0
  37. data/ext/cumo/extconf.rb +179 -0
  38. data/ext/cumo/include/cumo.h +25 -0
  39. data/ext/cumo/include/cumo/compat.h +23 -0
  40. data/ext/cumo/include/cumo/cuda/cublas.h +153 -0
  41. data/ext/cumo/include/cumo/cuda/cumo_thrust.hpp +187 -0
  42. data/ext/cumo/include/cumo/cuda/cumo_thrust_complex.hpp +79 -0
  43. data/ext/cumo/include/cumo/cuda/driver.h +22 -0
  44. data/ext/cumo/include/cumo/cuda/memory_pool.h +28 -0
  45. data/ext/cumo/include/cumo/cuda/nvrtc.h +22 -0
  46. data/ext/cumo/include/cumo/cuda/runtime.h +40 -0
  47. data/ext/cumo/include/cumo/indexer.h +238 -0
  48. data/ext/cumo/include/cumo/intern.h +142 -0
  49. data/ext/cumo/include/cumo/intern_fwd.h +38 -0
  50. data/ext/cumo/include/cumo/intern_kernel.h +6 -0
  51. data/ext/cumo/include/cumo/narray.h +429 -0
  52. data/ext/cumo/include/cumo/narray_kernel.h +149 -0
  53. data/ext/cumo/include/cumo/ndloop.h +95 -0
  54. data/ext/cumo/include/cumo/reduce_kernel.h +126 -0
  55. data/ext/cumo/include/cumo/template.h +158 -0
  56. data/ext/cumo/include/cumo/template_kernel.h +77 -0
  57. data/ext/cumo/include/cumo/types/bit.h +40 -0
  58. data/ext/cumo/include/cumo/types/bit_kernel.h +34 -0
  59. data/ext/cumo/include/cumo/types/complex.h +402 -0
  60. data/ext/cumo/include/cumo/types/complex_kernel.h +414 -0
  61. data/ext/cumo/include/cumo/types/complex_macro.h +382 -0
  62. data/ext/cumo/include/cumo/types/complex_macro_kernel.h +186 -0
  63. data/ext/cumo/include/cumo/types/dcomplex.h +46 -0
  64. data/ext/cumo/include/cumo/types/dcomplex_kernel.h +13 -0
  65. data/ext/cumo/include/cumo/types/dfloat.h +47 -0
  66. data/ext/cumo/include/cumo/types/dfloat_kernel.h +14 -0
  67. data/ext/cumo/include/cumo/types/float_def.h +34 -0
  68. data/ext/cumo/include/cumo/types/float_def_kernel.h +39 -0
  69. data/ext/cumo/include/cumo/types/float_macro.h +191 -0
  70. data/ext/cumo/include/cumo/types/float_macro_kernel.h +158 -0
  71. data/ext/cumo/include/cumo/types/int16.h +24 -0
  72. data/ext/cumo/include/cumo/types/int16_kernel.h +23 -0
  73. data/ext/cumo/include/cumo/types/int32.h +24 -0
  74. data/ext/cumo/include/cumo/types/int32_kernel.h +19 -0
  75. data/ext/cumo/include/cumo/types/int64.h +24 -0
  76. data/ext/cumo/include/cumo/types/int64_kernel.h +19 -0
  77. data/ext/cumo/include/cumo/types/int8.h +24 -0
  78. data/ext/cumo/include/cumo/types/int8_kernel.h +19 -0
  79. data/ext/cumo/include/cumo/types/int_macro.h +67 -0
  80. data/ext/cumo/include/cumo/types/int_macro_kernel.h +48 -0
  81. data/ext/cumo/include/cumo/types/real_accum.h +486 -0
  82. data/ext/cumo/include/cumo/types/real_accum_kernel.h +101 -0
  83. data/ext/cumo/include/cumo/types/robj_macro.h +80 -0
  84. data/ext/cumo/include/cumo/types/robj_macro_kernel.h +0 -0
  85. data/ext/cumo/include/cumo/types/robject.h +27 -0
  86. data/ext/cumo/include/cumo/types/robject_kernel.h +7 -0
  87. data/ext/cumo/include/cumo/types/scomplex.h +46 -0
  88. data/ext/cumo/include/cumo/types/scomplex_kernel.h +13 -0
  89. data/ext/cumo/include/cumo/types/sfloat.h +48 -0
  90. data/ext/cumo/include/cumo/types/sfloat_kernel.h +14 -0
  91. data/ext/cumo/include/cumo/types/uint16.h +25 -0
  92. data/ext/cumo/include/cumo/types/uint16_kernel.h +20 -0
  93. data/ext/cumo/include/cumo/types/uint32.h +25 -0
  94. data/ext/cumo/include/cumo/types/uint32_kernel.h +20 -0
  95. data/ext/cumo/include/cumo/types/uint64.h +25 -0
  96. data/ext/cumo/include/cumo/types/uint64_kernel.h +20 -0
  97. data/ext/cumo/include/cumo/types/uint8.h +25 -0
  98. data/ext/cumo/include/cumo/types/uint8_kernel.h +20 -0
  99. data/ext/cumo/include/cumo/types/uint_macro.h +58 -0
  100. data/ext/cumo/include/cumo/types/uint_macro_kernel.h +38 -0
  101. data/ext/cumo/include/cumo/types/xint_macro.h +169 -0
  102. data/ext/cumo/include/cumo/types/xint_macro_kernel.h +88 -0
  103. data/ext/cumo/narray/SFMT-params.h +97 -0
  104. data/ext/cumo/narray/SFMT-params19937.h +46 -0
  105. data/ext/cumo/narray/SFMT.c +620 -0
  106. data/ext/cumo/narray/SFMT.h +167 -0
  107. data/ext/cumo/narray/array.c +638 -0
  108. data/ext/cumo/narray/data.c +961 -0
  109. data/ext/cumo/narray/gen/cogen.rb +56 -0
  110. data/ext/cumo/narray/gen/cogen_kernel.rb +58 -0
  111. data/ext/cumo/narray/gen/def/bit.rb +37 -0
  112. data/ext/cumo/narray/gen/def/dcomplex.rb +39 -0
  113. data/ext/cumo/narray/gen/def/dfloat.rb +37 -0
  114. data/ext/cumo/narray/gen/def/int16.rb +36 -0
  115. data/ext/cumo/narray/gen/def/int32.rb +36 -0
  116. data/ext/cumo/narray/gen/def/int64.rb +36 -0
  117. data/ext/cumo/narray/gen/def/int8.rb +36 -0
  118. data/ext/cumo/narray/gen/def/robject.rb +37 -0
  119. data/ext/cumo/narray/gen/def/scomplex.rb +39 -0
  120. data/ext/cumo/narray/gen/def/sfloat.rb +37 -0
  121. data/ext/cumo/narray/gen/def/uint16.rb +36 -0
  122. data/ext/cumo/narray/gen/def/uint32.rb +36 -0
  123. data/ext/cumo/narray/gen/def/uint64.rb +36 -0
  124. data/ext/cumo/narray/gen/def/uint8.rb +36 -0
  125. data/ext/cumo/narray/gen/erbpp2.rb +346 -0
  126. data/ext/cumo/narray/gen/narray_def.rb +268 -0
  127. data/ext/cumo/narray/gen/spec.rb +425 -0
  128. data/ext/cumo/narray/gen/tmpl/accum.c +86 -0
  129. data/ext/cumo/narray/gen/tmpl/accum_binary.c +121 -0
  130. data/ext/cumo/narray/gen/tmpl/accum_binary_kernel.cu +61 -0
  131. data/ext/cumo/narray/gen/tmpl/accum_index.c +119 -0
  132. data/ext/cumo/narray/gen/tmpl/accum_index_kernel.cu +66 -0
  133. data/ext/cumo/narray/gen/tmpl/accum_kernel.cu +12 -0
  134. data/ext/cumo/narray/gen/tmpl/alloc_func.c +107 -0
  135. data/ext/cumo/narray/gen/tmpl/allocate.c +37 -0
  136. data/ext/cumo/narray/gen/tmpl/aref.c +66 -0
  137. data/ext/cumo/narray/gen/tmpl/aref_cpu.c +50 -0
  138. data/ext/cumo/narray/gen/tmpl/aset.c +56 -0
  139. data/ext/cumo/narray/gen/tmpl/binary.c +162 -0
  140. data/ext/cumo/narray/gen/tmpl/binary2.c +70 -0
  141. data/ext/cumo/narray/gen/tmpl/binary2_kernel.cu +15 -0
  142. data/ext/cumo/narray/gen/tmpl/binary_kernel.cu +31 -0
  143. data/ext/cumo/narray/gen/tmpl/binary_s.c +45 -0
  144. data/ext/cumo/narray/gen/tmpl/binary_s_kernel.cu +15 -0
  145. data/ext/cumo/narray/gen/tmpl/bincount.c +181 -0
  146. data/ext/cumo/narray/gen/tmpl/cast.c +44 -0
  147. data/ext/cumo/narray/gen/tmpl/cast_array.c +13 -0
  148. data/ext/cumo/narray/gen/tmpl/class.c +9 -0
  149. data/ext/cumo/narray/gen/tmpl/class_kernel.cu +6 -0
  150. data/ext/cumo/narray/gen/tmpl/clip.c +121 -0
  151. data/ext/cumo/narray/gen/tmpl/coerce_cast.c +10 -0
  152. data/ext/cumo/narray/gen/tmpl/complex_accum_kernel.cu +129 -0
  153. data/ext/cumo/narray/gen/tmpl/cond_binary.c +68 -0
  154. data/ext/cumo/narray/gen/tmpl/cond_binary_kernel.cu +18 -0
  155. data/ext/cumo/narray/gen/tmpl/cond_unary.c +46 -0
  156. data/ext/cumo/narray/gen/tmpl/cum.c +50 -0
  157. data/ext/cumo/narray/gen/tmpl/each.c +47 -0
  158. data/ext/cumo/narray/gen/tmpl/each_with_index.c +70 -0
  159. data/ext/cumo/narray/gen/tmpl/ewcomp.c +79 -0
  160. data/ext/cumo/narray/gen/tmpl/ewcomp_kernel.cu +19 -0
  161. data/ext/cumo/narray/gen/tmpl/extract.c +22 -0
  162. data/ext/cumo/narray/gen/tmpl/extract_cpu.c +26 -0
  163. data/ext/cumo/narray/gen/tmpl/extract_data.c +53 -0
  164. data/ext/cumo/narray/gen/tmpl/eye.c +105 -0
  165. data/ext/cumo/narray/gen/tmpl/eye_kernel.cu +19 -0
  166. data/ext/cumo/narray/gen/tmpl/fill.c +52 -0
  167. data/ext/cumo/narray/gen/tmpl/fill_kernel.cu +29 -0
  168. data/ext/cumo/narray/gen/tmpl/float_accum_kernel.cu +106 -0
  169. data/ext/cumo/narray/gen/tmpl/format.c +62 -0
  170. data/ext/cumo/narray/gen/tmpl/format_to_a.c +49 -0
  171. data/ext/cumo/narray/gen/tmpl/frexp.c +38 -0
  172. data/ext/cumo/narray/gen/tmpl/gemm.c +203 -0
  173. data/ext/cumo/narray/gen/tmpl/init_class.c +20 -0
  174. data/ext/cumo/narray/gen/tmpl/init_module.c +12 -0
  175. data/ext/cumo/narray/gen/tmpl/inspect.c +21 -0
  176. data/ext/cumo/narray/gen/tmpl/lib.c +50 -0
  177. data/ext/cumo/narray/gen/tmpl/lib_kernel.cu +24 -0
  178. data/ext/cumo/narray/gen/tmpl/logseq.c +102 -0
  179. data/ext/cumo/narray/gen/tmpl/logseq_kernel.cu +31 -0
  180. data/ext/cumo/narray/gen/tmpl/map_with_index.c +98 -0
  181. data/ext/cumo/narray/gen/tmpl/median.c +66 -0
  182. data/ext/cumo/narray/gen/tmpl/minmax.c +47 -0
  183. data/ext/cumo/narray/gen/tmpl/module.c +9 -0
  184. data/ext/cumo/narray/gen/tmpl/module_kernel.cu +1 -0
  185. data/ext/cumo/narray/gen/tmpl/new_dim0.c +15 -0
  186. data/ext/cumo/narray/gen/tmpl/new_dim0_kernel.cu +8 -0
  187. data/ext/cumo/narray/gen/tmpl/poly.c +50 -0
  188. data/ext/cumo/narray/gen/tmpl/pow.c +97 -0
  189. data/ext/cumo/narray/gen/tmpl/pow_kernel.cu +29 -0
  190. data/ext/cumo/narray/gen/tmpl/powint.c +17 -0
  191. data/ext/cumo/narray/gen/tmpl/qsort.c +212 -0
  192. data/ext/cumo/narray/gen/tmpl/rand.c +168 -0
  193. data/ext/cumo/narray/gen/tmpl/rand_norm.c +121 -0
  194. data/ext/cumo/narray/gen/tmpl/real_accum_kernel.cu +75 -0
  195. data/ext/cumo/narray/gen/tmpl/seq.c +112 -0
  196. data/ext/cumo/narray/gen/tmpl/seq_kernel.cu +43 -0
  197. data/ext/cumo/narray/gen/tmpl/set2.c +57 -0
  198. data/ext/cumo/narray/gen/tmpl/sort.c +48 -0
  199. data/ext/cumo/narray/gen/tmpl/sort_index.c +111 -0
  200. data/ext/cumo/narray/gen/tmpl/store.c +41 -0
  201. data/ext/cumo/narray/gen/tmpl/store_array.c +187 -0
  202. data/ext/cumo/narray/gen/tmpl/store_array_kernel.cu +58 -0
  203. data/ext/cumo/narray/gen/tmpl/store_bit.c +86 -0
  204. data/ext/cumo/narray/gen/tmpl/store_bit_kernel.cu +66 -0
  205. data/ext/cumo/narray/gen/tmpl/store_from.c +81 -0
  206. data/ext/cumo/narray/gen/tmpl/store_from_kernel.cu +58 -0
  207. data/ext/cumo/narray/gen/tmpl/store_kernel.cu +3 -0
  208. data/ext/cumo/narray/gen/tmpl/store_numeric.c +9 -0
  209. data/ext/cumo/narray/gen/tmpl/to_a.c +43 -0
  210. data/ext/cumo/narray/gen/tmpl/unary.c +132 -0
  211. data/ext/cumo/narray/gen/tmpl/unary2.c +60 -0
  212. data/ext/cumo/narray/gen/tmpl/unary_kernel.cu +72 -0
  213. data/ext/cumo/narray/gen/tmpl/unary_ret2.c +34 -0
  214. data/ext/cumo/narray/gen/tmpl/unary_s.c +86 -0
  215. data/ext/cumo/narray/gen/tmpl/unary_s_kernel.cu +58 -0
  216. data/ext/cumo/narray/gen/tmpl_bit/allocate.c +24 -0
  217. data/ext/cumo/narray/gen/tmpl_bit/aref.c +54 -0
  218. data/ext/cumo/narray/gen/tmpl_bit/aref_cpu.c +57 -0
  219. data/ext/cumo/narray/gen/tmpl_bit/aset.c +56 -0
  220. data/ext/cumo/narray/gen/tmpl_bit/binary.c +98 -0
  221. data/ext/cumo/narray/gen/tmpl_bit/bit_count.c +64 -0
  222. data/ext/cumo/narray/gen/tmpl_bit/bit_count_cpu.c +88 -0
  223. data/ext/cumo/narray/gen/tmpl_bit/bit_count_kernel.cu +76 -0
  224. data/ext/cumo/narray/gen/tmpl_bit/bit_reduce.c +133 -0
  225. data/ext/cumo/narray/gen/tmpl_bit/each.c +48 -0
  226. data/ext/cumo/narray/gen/tmpl_bit/each_with_index.c +70 -0
  227. data/ext/cumo/narray/gen/tmpl_bit/extract.c +30 -0
  228. data/ext/cumo/narray/gen/tmpl_bit/extract_cpu.c +29 -0
  229. data/ext/cumo/narray/gen/tmpl_bit/fill.c +69 -0
  230. data/ext/cumo/narray/gen/tmpl_bit/format.c +64 -0
  231. data/ext/cumo/narray/gen/tmpl_bit/format_to_a.c +51 -0
  232. data/ext/cumo/narray/gen/tmpl_bit/inspect.c +21 -0
  233. data/ext/cumo/narray/gen/tmpl_bit/mask.c +136 -0
  234. data/ext/cumo/narray/gen/tmpl_bit/none_p.c +14 -0
  235. data/ext/cumo/narray/gen/tmpl_bit/store_array.c +108 -0
  236. data/ext/cumo/narray/gen/tmpl_bit/store_bit.c +70 -0
  237. data/ext/cumo/narray/gen/tmpl_bit/store_from.c +60 -0
  238. data/ext/cumo/narray/gen/tmpl_bit/to_a.c +47 -0
  239. data/ext/cumo/narray/gen/tmpl_bit/unary.c +81 -0
  240. data/ext/cumo/narray/gen/tmpl_bit/where.c +90 -0
  241. data/ext/cumo/narray/gen/tmpl_bit/where2.c +95 -0
  242. data/ext/cumo/narray/index.c +880 -0
  243. data/ext/cumo/narray/kwargs.c +153 -0
  244. data/ext/cumo/narray/math.c +142 -0
  245. data/ext/cumo/narray/narray.c +1948 -0
  246. data/ext/cumo/narray/ndloop.c +2105 -0
  247. data/ext/cumo/narray/rand.c +45 -0
  248. data/ext/cumo/narray/step.c +474 -0
  249. data/ext/cumo/narray/struct.c +886 -0
  250. data/lib/cumo.rb +3 -0
  251. data/lib/cumo/cuda.rb +11 -0
  252. data/lib/cumo/cuda/compile_error.rb +36 -0
  253. data/lib/cumo/cuda/compiler.rb +161 -0
  254. data/lib/cumo/cuda/device.rb +47 -0
  255. data/lib/cumo/cuda/link_state.rb +31 -0
  256. data/lib/cumo/cuda/module.rb +40 -0
  257. data/lib/cumo/cuda/nvrtc_program.rb +27 -0
  258. data/lib/cumo/linalg.rb +12 -0
  259. data/lib/cumo/narray.rb +2 -0
  260. data/lib/cumo/narray/extra.rb +1278 -0
  261. data/lib/erbpp.rb +294 -0
  262. data/lib/erbpp/line_number.rb +137 -0
  263. data/lib/erbpp/narray_def.rb +381 -0
  264. data/numo-narray-version +1 -0
  265. data/run.gdb +7 -0
  266. metadata +353 -0
@@ -0,0 +1,3 @@
1
+ require_relative File.join(__dir__, '../ext/cumo/cumo')
2
+ require_relative 'cumo/cuda'
3
+ require_relative 'cumo/narray/extra'
@@ -0,0 +1,11 @@
1
# Top-level namespace of the Cumo library.
module Cumo
  # Namespace for the CUDA helper classes; the files required right after
  # this declaration reopen it to add their classes.
  module CUDA; end
end
5
+
6
+ require_relative 'cuda/compile_error'
7
+ require_relative 'cuda/compiler'
8
+ require_relative 'cuda/device'
9
+ require_relative 'cuda/module'
10
+ require_relative 'cuda/link_state'
11
+ require_relative 'cuda/nvrtc_program'
@@ -0,0 +1,36 @@
1
+ require_relative '../cuda'
2
+
3
module Cumo
  module CUDA
    # Error describing a failed CUDA source compilation.
    #
    # Besides the error text it keeps the source, the program name and the
    # compiler options, so that the failing program can be printed with #dump.
    class CompileError < StandardError
      # @param msg [String] error text (returned by #message and #to_s)
      # @param source [String] CUDA source that failed to compile
      # @param name [String] name of the program
      # @param options [Array<String>] options that were passed to the compiler
      def initialize(msg, source, name, options)
        # Initialise the StandardError part as well so the exception is fully
        # set up even for code that bypasses the overridden accessors below.
        super(msg)
        @msg = msg
        @source = source
        @name = name
        @options = options
      end

      # @return [String] the error text given to the constructor
      def message
        @msg
      end

      # @return [String] the error text given to the constructor
      def to_s
        @msg
      end

      # Writes the error text, program name, options and the line-numbered
      # source to +io+, then flushes it.
      #
      # @param io [#puts, #flush] destination stream (e.g. $stderr)
      def dump(io)
        lines = @source.to_s.split("\n")
        # Width of the largest line number; the string length also works for
        # an empty source, where Math.log10(0) would be -Infinity.
        digits = lines.size.to_s.size
        linum_fmt = "%0#{digits}d "
        io.puts("NVRTC compilation error: #{@msg}")
        io.puts("-----")
        io.puts("Name: #{@name}")
        io.puts("Options: #{Array(@options).join(' ')}")
        io.puts("CUDA source:")
        lines.each_with_index do |line, i|
          # Kernel#format, not String#sprintf: sprintf is a private Kernel
          # method and cannot be called with an explicit receiver.
          io.puts(format(linum_fmt, i + 1) << line.rstrip)
        end
        io.puts("-----")
        io.flush
      end
    end
  end
end
@@ -0,0 +1,161 @@
1
+ require 'tmpdir'
2
+ require 'tempfile'
3
+ require 'fileutils'
4
+ require 'digest/md5'
5
+ require_relative '../cuda'
6
+
7
module Cumo
  module CUDA
    # Compiles CUDA source with NVRTC, links the PTX into a cubin and keeps
    # the cubin in an on-disk cache keyed by an MD5 of the source and the
    # build environment (arch, options, NVRTC version).
    class Compiler
      VALID_KERNEL_NAME = /\A[a-zA-Z_][a-zA-Z_0-9]*\z/
      DEFAULT_CACHE_DIR = File.expand_path('~/.cumo/kernel_cache')

      # Result of compiling an empty source, memoised per [arch, options,
      # nvrtc version]; it takes part in the cache key.
      @@empty_file_preprocess_cache ||= {}

      # @param name [String]
      # @return [Boolean] whether +name+ matches VALID_KERNEL_NAME
      def self.valid_kernel_name?(name)
        VALID_KERNEL_NAME.match?(name)
      end

      # Compiles +source+ with NVRTC.
      #
      # @param source [String] CUDA source
      # @param options [Array<String>] compiler options ("-arch=..." is appended)
      # @param arch [String, nil] target arch; defaults to the current device's
      # @return the value returned by NVRTCProgram#compile (the PTX)
      # @raise [CompileError] when compilation fails
      def compile_using_nvrtc(source, options: [], arch: nil)
        arch ||= get_arch
        options += ["-arch=#{arch}"]

        Dir.mktmpdir do |root_dir|
          cu_path = File.join(root_dir, 'kern.cu')
          File.write(cu_path, source)
          run_nvrtc(source, cu_path, options)
        end
      end

      # Returns a loaded Cumo::CUDA::Module for +source+, reusing a cached
      # cubin when one with a matching checksum exists.
      #
      # @param source [String] CUDA source
      # @param options [Array<String>] compiler options ("-ftz=true" is appended)
      # @param arch [String, nil] target arch; defaults to the current device's
      # @param cache_dir [String, nil] cache directory; defaults to
      #   ENV['CUMO_CACHE_DIR'] or DEFAULT_CACHE_DIR
      # @param extra_source [String, nil] only used as part of the cache key
      # @return [Cumo::CUDA::Module]
      def compile_with_cache(source, options: [], arch: nil, cache_dir: nil, extra_source: nil)
        # NVRTC does not use extra_source. extra_source is used for cache key.
        cache_dir ||= get_cache_dir
        arch ||= get_arch

        options += ['-ftz=true']

        env = [arch, options, get_nvrtc_version]
        base = @@empty_file_preprocess_cache[env]
        if base.nil?
          # This is checking of NVRTC compiler internal version
          base = preprocess('', options, arch)
          @@empty_file_preprocess_cache[env] = base
        end
        key_src = "#{env} #{base} #{source} #{extra_source}"

        key_src.encode!('utf-8')
        digest = Digest::MD5.hexdigest(key_src)
        name = "#{digest}_2.cubin"

        # mkdir_p is a no-op for an existing directory.
        FileUtils.mkdir_p(cache_dir)

        # TODO(sonots): thread-safe?
        path = File.join(cache_dir, name)
        cubin = load_cache(path)
        return load_module(cubin) if cubin

        ptx = compile_using_nvrtc(source, options: options, arch: arch)
        cubin = nil
        cubin_hash = nil
        LinkState.new do |ls|
          ls.add_ptr_data(ptx, 'cumo.ptx')
          cubin = ls.complete
          cubin_hash = Digest::MD5.hexdigest(cubin)
        end

        save_cache(path, cubin_hash, cubin)

        # Save .cu source file along with .cubin
        if get_bool_env_variable('CUMO_CACHE_SAVE_CUDA_SOURCE', false)
          File.write("#{path}.cu", source)
        end

        load_module(cubin)
      end

      private

      # Wraps +cubin+ in a freshly loaded Cumo::CUDA::Module.
      def load_module(cubin)
        mod = Cumo::CUDA::Module.new
        mod.load(cubin)
        mod
      end

      # Writes "<md5 hex of cubin><cubin>" to +path+.
      #
      # The data goes to a temporary file in the destination directory first
      # and is renamed into place after it has been closed, so the file at
      # +path+ is complete as soon as it appears and the rename never has to
      # cross filesystems.
      def save_cache(path, cubin_hash, cubin)
        tf = Tempfile.create(['cumo_cache', '.tmp'], File.dirname(path))
        begin
          tf.binmode
          tf.write(cubin_hash)
          tf.write(cubin)
          tf.close
          File.rename(tf.path, path)
        rescue StandardError
          tf.close unless tf.closed?
          File.unlink(tf.path) if File.exist?(tf.path)
          raise
        end
      end

      # Reads a cache file written by #save_cache.
      #
      # @return [String, nil] the cubin, or nil when the file is missing, too
      #   short, or its leading MD5 does not match the payload
      def load_cache(path)
        return nil unless File.exist?(path)
        data = File.binread(path)
        return nil unless data.size >= 32
        hash = data[0...32]
        cubin = data[32..-1]
        return nil unless hash == Digest::MD5.hexdigest(cubin)
        cubin
      rescue Errno::ENOENT
        # The file vanished between the existence check and the read.
        nil
      end

      def get_cache_dir
        ENV.fetch('CUMO_CACHE_DIR', DEFAULT_CACHE_DIR)
      end

      def get_nvrtc_version
        @@nvrtc_version ||= NVRTC.nvrtcVersion
      end

      # @return [String] "compute_<major><minor>" for the current device
      def get_arch
        cc = Device.new.compute_capability
        "compute_#{cc}"
      end

      # Reads a boolean flag from the environment: unset/empty gives
      # +default+, the integer 1 gives true, anything else gives false.
      def get_bool_env_variable(name, default)
        val = ENV[name]
        return default if val.nil? || val.empty?
        begin
          Integer(val) == 1
        rescue StandardError
          false
        end
      end

      # Compiles +source+ under an empty program name with "-arch" appended.
      def preprocess(source, options, arch)
        options += ["-arch=#{arch}"]
        run_nvrtc(source, '', options)
      end

      # Compiles +source+ as NVRTC program +name+ and always destroys the
      # program afterwards. On failure the error is dumped to $stderr when
      # CUMO_DUMP_CUDA_SOURCE_ON_ERROR=1 and then re-raised.
      def run_nvrtc(source, name, options)
        prog = NVRTCProgram.new(source, name: name)
        begin
          prog.compile(options: options)
        rescue CompileError => e
          e.dump($stderr) if get_bool_env_variable('CUMO_DUMP_CUDA_SOURCE_ON_ERROR', false)
          raise
        ensure
          prog.destroy
        end
      end
    end
  end
end
@@ -0,0 +1,47 @@
1
+ require_relative '../cuda'
2
+
3
module Cumo
  module CUDA
    # Handle for a CUDA device identified by its integer id.
    class Device
      # Attribute codes passed to Runtime.cudaDeviceGetAttributes; 75 and 76
      # are the compute-capability major/minor entries of the CUDA runtime's
      # cudaDeviceAttr enum.
      ATTR_COMPUTE_CAPABILITY_MAJOR = 75
      ATTR_COMPUTE_CAPABILITY_MINOR = 76

      attr_reader :id

      # @return the id reported by Runtime.cudaGetDevice
      # (Historical misspelling, kept for compatibility; see get_current_id.)
      def self.get_currend_id
        Runtime.cudaGetDevice
      end

      class << self
        # Correctly spelled name for get_currend_id.
        alias_method :get_current_id, :get_currend_id
      end

      # @param device_id [Integer, nil] device id; defaults to the id
      #   reported by Runtime.cudaGetDevice
      def initialize(device_id = nil)
        @id = device_id || Runtime.cudaGetDevice
        @_device_stack = []
      end

      # Makes this device the current one.
      def use
        Runtime.cudaSetDevice(@id)
      end

      # Runs the block with this device selected and restores the previously
      # selected device afterwards, also when the block raises.
      #
      # @raise [RuntimeError] when called without a block
      def with
        raise "block is required" unless block_given?
        prev_id = Runtime.cudaGetDevice
        @_device_stack << prev_id
        begin
          # Switch only when another device is selected right now.
          Runtime.cudaSetDevice(@id) if prev_id != @id
          yield
        ensure
          Runtime.cudaSetDevice(@_device_stack.pop)
        end
      end

      def synchronize
        Runtime.cudaDeviceSynchronize
      end

      # @return [String] major and minor compute capability concatenated
      def compute_capability
        major = Runtime.cudaDeviceGetAttributes(ATTR_COMPUTE_CAPABILITY_MAJOR, @id)
        minor = Runtime.cudaDeviceGetAttributes(ATTR_COMPUTE_CAPABILITY_MINOR, @id)
        "#{major}#{minor}"
      end
    end
  end
end
@@ -0,0 +1,31 @@
1
+ require_relative '../cuda'
2
+
3
module Cumo
  module CUDA
    # CUDA link state: owns the handle returned by Driver.cuLinkCreate.
    class LinkState
      # Creates the link state. With a block, yields self and destroys the
      # link state when the block finishes (normally or by raising).
      def initialize
        @ptr = Driver.cuLinkCreate
        return unless block_given?

        begin
          yield self
        ensure
          destroy
        end
      end

      # Destroys the handle once; later calls do nothing.
      def destroy
        if @ptr
          Driver.cuLinkDestroy(@ptr)
          @ptr = nil
        end
      end

      # Adds +data+ under +name+ as PTX input.
      def add_ptr_data(data, name)
        input_type = Driver::CU_JIT_INPUT_PTX
        Driver.cuLinkAddData(@ptr, input_type, data, name)
      end

      # @return the value of Driver.cuLinkComplete for this link state
      def complete
        Driver.cuLinkComplete(@ptr)
      end
    end
  end
end
@@ -0,0 +1,40 @@
1
+ require_relative '../cuda'
2
+
3
module Cumo
  module CUDA
    # CUDA kernel module: holds the handle returned by the Driver load calls.
    class Module
      # Starts without a loaded module. With a block, yields self and unloads
      # when the block finishes (normally or by raising).
      def initialize
        @ptr = nil
        return unless block_given?

        begin
          yield self
        ensure
          unload
        end
      end

      # Unloads the module if one is loaded; otherwise does nothing.
      def unload
        if @ptr
          Driver.cuModuleUnload(@ptr)
          @ptr = nil
        end
      end

      # Loads a module from the file +fname+.
      def load_file(fname)
        @ptr = Driver.cuModuleLoad(fname)
      end

      # Loads a module from in-memory +cubin+ data.
      def load(cubin)
        @ptr = Driver.cuModuleLoadData(cubin)
      end

      # @return the value of Driver.cuModuleGetGlobal for +name+
      def get_global_var(name)
        Driver.cuModuleGetGlobal(@ptr, name)
      end

      # Not implemented yet: always returns nil.
      def get_function(name)
        # Function(name)
      end
    end
  end
end
40
+
@@ -0,0 +1,27 @@
1
+ require_relative '../cuda'
2
+ require_relative 'compile_error'
3
+
4
module Cumo
  module CUDA
    # Wrapper around an NVRTC program handle.
    class NVRTCProgram
      # @param src [String] CUDA source (should be UTF-8)
      # @param name [String] program name (should be UTF-8)
      # @param headers [Array] passed through to NVRTC.nvrtcCreateProgram
      # @param include_names [Array] passed through to NVRTC.nvrtcCreateProgram
      def initialize(src, name: "default_program", headers: [], include_names: [])
        @ptr = nil
        @src = src # should be UTF-8
        @name = name # should be UTF-8
        @ptr = NVRTC.nvrtcCreateProgram(src, name, headers, include_names)
      end

      # Destroys the program handle. The handle is cleared afterwards so a
      # repeated call is a no-op instead of destroying the same handle twice
      # (the same contract as LinkState#destroy and Module#unload).
      def destroy
        return unless @ptr
        NVRTC.nvrtcDestroyProgram(@ptr)
        @ptr = nil
      end

      # Compiles the program and returns its PTX.
      #
      # @param options [Array<String>] compiler options
      # @return the value of NVRTC.nvrtcGetPTX
      # @raise [CompileError] carrying the program log, source, name and
      #   options when NVRTC raises NVRTCError
      def compile(options: [])
        NVRTC.nvrtcCompileProgram(@ptr, options)
        NVRTC.nvrtcGetPTX(@ptr)
      rescue NVRTCError
        log = NVRTC.nvrtcGetProgramLog(@ptr)
        raise CompileError.new(log, @src, @name, options)
      end
    end
  end
end
@@ -0,0 +1,12 @@
1
+ require 'cumo'
2
+
3
+ # Provide compatibility layers with numo/linalg
4
module Cumo
  # Module-level BLAS entry points that forward to the array methods.
  module Blas
    # Forwards to +a.gemm+ with the remaining positional and keyword
    # arguments and returns its result.
    def self.gemm(a, *args, **kwargs)
      a.gemm(*args, **kwargs)
    end
  end
end
@@ -0,0 +1,2 @@
1
+ # This file is for compatibility with require 'numo/narray'
2
+ require_relative '../cumo'
@@ -0,0 +1,1278 @@
1
+ module Cumo
2
+ class NArray
3
+
4
+ # Return an unallocated array with the same shape and type as self.
5
def new_narray
  # Same class, same shape; built with the class's plain constructor.
  klass = self.class
  klass.new(*shape)
end
8
+
9
+ # Return an array of zeros with the same shape and type as self.
10
def new_zeros
  # Delegates to the class-level zeros constructor with our shape.
  klass = self.class
  klass.zeros(*shape)
end
13
+
14
+ # Return an array of ones with the same shape and type as self.
15
def new_ones
  # Delegates to the class-level ones constructor with our shape.
  klass = self.class
  klass.ones(*shape)
end
18
+
19
+ # Return an array filled with value with the same shape and type as self.
20
def new_fill(value)
  # Build a same-shaped array first, then fill it; the result of #fill is
  # what gets returned.
  fresh = self.class.new(*shape)
  fresh.fill(value)
end
23
+
24
+ # Convert angles from radians to degrees.
25
def rad2deg
  # Degrees per radian.
  factor = 180/Math::PI
  self * factor
end
28
+
29
+ # Convert angles from degrees to radians.
30
def deg2rad
  # Radians per degree.
  factor = Math::PI/180
  self * factor
end
33
+
34
+ # Flip each row in the left/right direction.
35
+ # Same as `a[true, (-1..0).step(-1), ...]`.
36
def fliplr
  # Left/right flip = reversal along axis 1.
  axis = 1
  reverse(axis)
end
39
+
40
+ # Flip each column in the up/down direction.
41
+ # Same as `a[(-1..0).step(-1), ...]`.
42
def flipud
  # Up/down flip = reversal along axis 0.
  axis = 0
  reverse(axis)
end
45
+
46
+ # Multi-dimensional array indexing.
47
+ # Same as [] for one-dimensional NArray.
48
+ # Similar to numpy's tuple indexing, i.e., `a[[1,2,..],[3,4,..]]`
49
+ # (This method will be rewritten in C)
50
+ # @return [Cumo::NArray] one-dimensional view of self.
51
+ # @example
52
+ # p x = Cumo::DFloat.new(3,3,3).seq
53
+ # # Cumo::DFloat#shape=[3,3,3]
54
+ # # [[[0, 1, 2],
55
+ # # [3, 4, 5],
56
+ # # [6, 7, 8]],
57
+ # # [[9, 10, 11],
58
+ # # [12, 13, 14],
59
+ # # [15, 16, 17]],
60
+ # # [[18, 19, 20],
61
+ # # [21, 22, 23],
62
+ # # [24, 25, 26]]]
63
+ #
64
+ # p x.at([0,1,2],[0,1,2],[-1,-2,-3])
65
+ # # Cumo::DFloat(view)#shape=[3]
66
+ # # [2, 13, 24]
67
def at(*indices)
  if indices.size != ndim
    raise DimensionError, "argument length does not match dimension size"
  end
  dims = shape
  idx = nil    # accumulated flat index
  stride = 1   # element stride of the dimension being processed
  # Walk from the last dimension to the first, so the stride grows by the
  # extent of each dimension already folded into idx.
  (indices.size-1).downto(0) do |i|
    ix = Int64.cast(indices[i])
    if ix.ndim != 1
      raise DimensionError, "index array is not one-dimensional"
    end
    # Negative indices count from the end of the dimension.
    ix[ix < 0] += dims[i]
    # An index is invalid when it is still negative OR past the end. The two
    # conditions can never hold together, so they must be OR-ed; AND-ing them
    # made this check unreachable.
    if ((ix < 0) | (ix >= dims[i])).any?
      raise IndexError, "index array is out of range"
    end
    if idx
      if idx.size != ix.size
        raise ShapeError, "index array sizes mismatch"
      end
      idx += ix * stride
      stride *= dims[i]
    else
      idx = ix
      stride = dims[i]
    end
  end
  self[idx]
end
95
+
96
+ # Rotate in the plane specified by axes.
97
+ # @example
98
+ # p a = Cumo::Int32.new(2,2).seq
99
+ # # Cumo::Int32#shape=[2,2]
100
+ # # [[0, 1],
101
+ # # [2, 3]]
102
+ #
103
+ # p a.rot90
104
+ # # Cumo::Int32(view)#shape=[2,2]
105
+ # # [[1, 3],
106
+ # # [0, 2]]
107
+ #
108
+ # p a.rot90(2)
109
+ # # Cumo::Int32(view)#shape=[2,2]
110
+ # # [[3, 2],
111
+ # # [1, 0]]
112
+ #
113
+ # p a.rot90(3)
114
+ # # Cumo::Int32(view)#shape=[2,2]
115
+ # # [[2, 0],
116
+ # # [3, 1]]
117
def rot90(k=1,axes=[0,1])
  # Only the number of quarter turns modulo 4 matters.
  quarter_turns = k % 4
  if quarter_turns == 0
    view
  elsif quarter_turns == 1
    swapaxes(*axes).reverse(axes[0])
  elsif quarter_turns == 2
    reverse(*axes)
  elsif quarter_turns == 3
    swapaxes(*axes).reverse(axes[1])
  end
end
129
+
130
# Integer value of a single-element array.
# Raises TypeError when the array does not hold exactly one element.
def to_i
  unless size==1
    # convert to Int?
    raise TypeError, "can't convert #{self.class} into Integer"
  end
  self.extract_cpu.to_i
end
138
+
139
# Float value of a single-element array.
# Raises TypeError when the array does not hold exactly one element.
def to_f
  unless size==1
    # convert to DFloat?
    raise TypeError, "can't convert #{self.class} into Float"
  end
  self.extract_cpu.to_f
end
147
+
148
# Complex value of a single-element array.
# Raises TypeError when the array does not hold exactly one element.
def to_c
  unless size==1
    # convert to DComplex?
    raise TypeError, "can't convert #{self.class} into Complex"
  end
  Complex(self.extract_cpu)
end
156
+
157
+ # Convert the argument to an narray if not an narray.
158
def self.cast(a)
  # Existing narrays pass through untouched; anything else is converted by
  # the class that NArray.array_type picks for it.
  if a.kind_of?(NArray)
    a
  else
    target = NArray.array_type(a)
    target.cast(a)
  end
end
161
+
162
# Returns +a+ as an array with at least one dimension: a 0-dimensional
# narray gets a new axis, numbers and ranges go through self[], and
# everything else goes through cast.
def self.asarray(a)
  if NArray === a
    a.ndim == 0 ? a[:new] : a
  elsif Numeric === a || Range === a
    self[a]
  else
    cast(a)
  end
end
172
+
173
+ # parse matrix like matlab, octave
174
+ # @example
175
+ # a = Cumo::DFloat.parse %[
176
+ # 2 -3 5
177
+ # 4 9 7
178
+ # 2 -1 6
179
+ # ]
180
+ # => Cumo::DFloat#shape=[3,3]
181
+ # [[2, -3, 5],
182
+ # [4, 9, 7],
183
+ # [2, -1, 6]]
184
+
185
def self.parse(str, split1d:/\s+/, split2d:/;?$|;/,
               split3d:/\s*\n(\s*\n)+/m)
  # split3d separates 2-D blocks, split2d separates rows inside a block and
  # split1d separates the items of a row. Empty rows and blocks are dropped.
  #
  # NOTE(review): every item is passed to eval, so the text is executed as
  # Ruby code. Only feed this method trusted input.
  blocks = str.split(split3d).map do |chunk|
    rows = chunk.split(split2d).map do |row|
      row.strip.split(split1d).reject(&:empty?).map { |token| eval(token.strip) }
    end
    rows.reject(&:empty?)
  end
  blocks = blocks.reject(&:empty?)
  # A single block yields a 2-D result; several blocks are stacked.
  if blocks.size==1
    self.cast(blocks[0])
  else
    self.cast(blocks)
  end
end
210
+
211
+ # Append values to the end of an narray.
212
+ # @example
213
+ # a = Cumo::DFloat[1, 2, 3]
214
+ # p a.append([[4, 5, 6], [7, 8, 9]])
215
+ # # Cumo::DFloat#shape=[9]
216
+ # # [1, 2, 3, 4, 5, 6, 7, 8, 9]
217
+ #
218
+ # a = Cumo::DFloat[[1, 2, 3]]
219
+ # p a.append([[4, 5, 6], [7, 8, 9]],axis:0)
220
+ # # Cumo::DFloat#shape=[3,3]
221
+ # # [[1, 2, 3],
222
+ # # [4, 5, 6],
223
+ # # [7, 8, 9]]
224
+ #
225
+ # a = Cumo::DFloat[[1, 2, 3], [4, 5, 6]]
226
+ # p a.append([7, 8, 9], axis:0)
227
+ # # in `append': dimension mismatch (Cumo::NArray::DimensionError)
228
+
229
def append(other,axis:nil)
  other = self.class.cast(other)
  unless axis
    # No axis given: both operands are laid out one after the other in a
    # new one-dimensional array.
    joined = self.class.zeros(size+other.size)
    joined[0...size] = self[true]
    joined[size..-1] = other[true]
    return joined
  end
  # With an axis, both operands need the same number of dimensions.
  raise DimensionError, "dimension mismatch" if ndim != other.ndim
  concatenate(other,axis:axis)
end
243
+
244
+ # Return a new array with sub-arrays along an axis deleted.
245
+ # If axis is not given, obj is applied to the flattened array.
246
+
247
+ # @example
248
+ # a = Cumo::DFloat[[1,2,3,4], [5,6,7,8], [9,10,11,12]]
249
+ # p a.delete(1,0)
250
+ # # Cumo::DFloat(view)#shape=[2,4]
251
+ # # [[1, 2, 3, 4],
252
+ # # [9, 10, 11, 12]]
253
+ #
254
+ # p a.delete((0..-1).step(2),1)
255
+ # # Cumo::DFloat(view)#shape=[3,2]
256
+ # # [[2, 4],
257
+ # # [6, 8],
258
+ # # [10, 12]]
259
+ #
260
+ # p a.delete([1,3,5])
261
+ # # Cumo::DFloat(view)#shape=[9]
262
+ # # [1, 3, 5, 7, 8, 9, 10, 11, 12]
263
+
264
def delete(indice,axis=nil)
  # Start from an all-ones mask over the axis (or over all elements when no
  # axis is given), clear the entries to drop, and copy what is left.
  extent = axis ? shape[axis] : size
  keep = Bit.ones(extent)
  keep[indice] = 0
  return self[keep.where].copy unless axis

  selector = [true]*ndim
  selector[axis] = keep.where
  self[*selector].copy
end
277
+
278
+ # Insert values along the axis before the indices.
279
+ # @example
280
+ # p a = Cumo::DFloat[[1, 2], [3, 4]]
281
+ # a = Cumo::Int32[[1, 1], [2, 2], [3, 3]]
282
+ #
283
+ # p a.insert(1,5)
284
+ # # Cumo::Int32#shape=[7]
285
+ # # [1, 5, 1, 2, 2, 3, 3]
286
+ #
287
+ # p a.insert(1, 5, axis:1)
288
+ # # Cumo::Int32#shape=[3,3]
289
+ # # [[1, 5, 1],
290
+ # # [2, 5, 2],
291
+ # # [3, 5, 3]]
292
+ #
293
+ # p a.insert([1], [[11],[12],[13]], axis:1)
294
+ # # Cumo::Int32#shape=[3,3]
295
+ # # [[1, 11, 1],
296
+ # # [2, 12, 2],
297
+ # # [3, 13, 3]]
298
+ #
299
+ # p a.insert(1, [11, 12, 13], axis:1)
300
+ # # Cumo::Int32#shape=[3,3]
301
+ # # [[1, 11, 1],
302
+ # # [2, 12, 2],
303
+ # # [3, 13, 3]]
304
+ #
305
+ # p a.insert([1], [11, 12, 13], axis:1)
306
+ # # Cumo::Int32#shape=[3,5]
307
+ # # [[1, 11, 12, 13, 1],
308
+ # # [2, 11, 12, 13, 2],
309
+ # # [3, 11, 12, 13, 3]]
310
+ #
311
+ # p b = a.flatten
312
+ # # Cumo::Int32(view)#shape=[6]
313
+ # # [1, 1, 2, 2, 3, 3]
314
+ #
315
+ # p b.insert(2,[15,16])
316
+ # # Cumo::Int32#shape=[8]
317
+ # # [1, 1, 15, 16, 2, 2, 3, 3]
318
+ #
319
+ # p b.insert([2,2],[15,16])
320
+ # # Cumo::Int32#shape=[8]
321
+ # # [1, 1, 15, 16, 2, 2, 3, 3]
322
+ #
323
+ # p b.insert([2,1],[15,16])
324
+ # # Cumo::Int32#shape=[8]
325
+ # # [1, 16, 1, 15, 2, 2, 3, 3]
326
+ #
327
+ # p b.insert([2,0,1],[15,16,17])
328
+ # # Cumo::Int32#shape=[9]
329
+ # # [16, 1, 17, 1, 15, 2, 2, 3, 3]
330
+ #
331
+ # p b.insert(2..3, [15, 16])
332
+ # # Cumo::Int32#shape=[8]
333
+ # # [1, 1, 15, 2, 16, 2, 3, 3]
334
+ #
335
+ # p b.insert(2, [7.13, 0.5])
336
+ # # Cumo::Int32#shape=[8]
337
+ # # [1, 1, 7, 0, 2, 2, 3, 3]
338
+ #
339
+ # p x = Cumo::DFloat.new(2,4).seq
340
+ # # Cumo::DFloat#shape=[2,4]
341
+ # # [[0, 1, 2, 3],
342
+ # # [4, 5, 6, 7]]
343
+ #
344
+ # p x.insert([1,3],999,axis:1)
345
+ # # Cumo::DFloat#shape=[2,6]
346
+ # # [[0, 999, 1, 2, 999, 3],
347
+ # # [4, 999, 5, 6, 999, 7]]
348
+
349
def insert(indice,values,axis:nil)
  values = self.class.asarray(values)
  if axis
    # Pad values with new leading axes up to our number of dimensions; for
    # a single numeric position the new axis is placed at +axis+ instead.
    vdim = values.ndim
    selector = [:new]*(ndim-vdim) + [true]*vdim
    if Numeric === indice
      selector[-vdim-1] = true
      selector[axis] = :new
    end
    values = values[*selector]
  else
    values = values.flatten
  end

  # Translate the requested positions into positions in the result.
  positions = Int64.asarray(indice)
  count = positions.size
  if count == 1
    # One position: the inserted values occupy consecutive slots from it.
    count = values.shape[axis||0]
    positions = positions + Int64.new(count).seq
  else
    # Several positions: each one shifts by the number of insertions that
    # sort before it.
    order = positions.sort_index
    positions[order] += Int64.new(count).seq
  end

  if axis
    # Mask of the slots that keep original data.
    keep = Bit.ones(shape[axis]+count)
    keep[positions] = 0
    dims = shape
    dims[axis] += count
    result = self.class.zeros(dims)
    where = [true]*ndim
    where[axis] = keep.where
    result[*where] = self
    where[axis] = positions
    result[*where] = values
  else
    keep = Bit.ones(size+count)
    keep[positions] = 0
    result = self.class.zeros(size+count)
    result[keep.where] = self.flatten
    result[positions] = values
  end
  result
end
392
+
393
+ class << self
394
+ # @example
395
+ # p a = Cumo::DFloat[[1, 2], [3, 4]]
396
+ # # Cumo::DFloat#shape=[2,2]
397
+ # # [[1, 2],
398
+ # # [3, 4]]
399
+ #
400
+ # p b = Cumo::DFloat[[5, 6]]
401
+ # # Cumo::DFloat#shape=[1,2]
402
+ # # [[5, 6]]
403
+ #
404
+ # p Cumo::NArray.concatenate([a,b],axis:0)
405
+ # # Cumo::DFloat#shape=[3,2]
406
+ # # [[1, 2],
407
+ # # [3, 4],
408
+ # # [5, 6]]
409
+ #
410
+ # p Cumo::NArray.concatenate([a,b.transpose], axis:1)
411
+ # # Cumo::DFloat#shape=[2,3]
412
+ # # [[1, 2, 5],
413
+ # # [3, 4, 6]]
414
+
415
# Join +arrays+ along +axis+ into a freshly allocated array.
# @param arrays [Array] NArray, Numeric or Array elements
# @param axis [Integer] joining axis; negative values count from the last axis
# @return [Cumo::NArray]
# @raise [TypeError] if an element is not an NArray, Numeric or Array
# @raise [ArgumentError] if +axis+ is out of range
# @raise [ShapeError] if the shapes disagree off the joining axis
def concatenate(arrays,axis:0)
  klass = (self==NArray) ? NArray.array_type(arrays) : self
  arrays = arrays.map do |item|
    case item
    when NArray then item
    when Numeric then klass[item]
    when Array then klass.cast(item)
    else
      raise TypeError,"not Cumo::NArray: #{item.inspect[0..48]}"
    end
  end
  # Largest rank among the inputs (0 when the list is empty).
  nd = arrays.map { |arr| arr.ndim }.max || 0

  axis += nd if axis < 0
  raise ArgumentError,"axis is out of range" if axis < 0 || axis >= nd

  base_shape = nil
  axis_total = 0
  arrays.each do |arr|
    dims = arr.shape
    # Lower-rank inputs are treated as if padded with leading 1s.
    dims = [1]*(nd-dims.size) + dims if nd != dims.size
    axis_total += dims.delete_at(axis)
    base_shape ||= dims
    raise ShapeError,"shape mismatch" if base_shape != dims
  end
  base_shape.insert(axis,axis_total)

  result = klass.zeros(*base_shape)
  offset = 0
  refs = [true] * nd
  arrays.each do |arr|
    stop = offset + (arr.shape[axis-nd]||1)
    refs[axis] = offset...stop
    result[*refs] = arr
    offset = stop
  end
  result
end
468
+
469
+ # Stack arrays vertically (row wise).
470
+ # @example
471
+ # a = Cumo::Int32[1,2,3]
472
+ # b = Cumo::Int32[2,3,4]
473
+ # p Cumo::NArray.vstack([a,b])
474
+ # # Cumo::Int32#shape=[2,3]
475
+ # # [[1, 2, 3],
476
+ # # [2, 3, 4]]
477
+ #
478
+ # a = Cumo::Int32[[1],[2],[3]]
479
+ # b = Cumo::Int32[[2],[3],[4]]
480
+ # p Cumo::NArray.vstack([a,b])
481
+ # # Cumo::Int32#shape=[6,1]
482
+ # # [[1],
483
+ # # [2],
484
+ # # [3],
485
+ # # [2],
486
+ # # [3],
487
+ # # [4]]
488
+
489
# Cast every element, promote it to at least 2-d, then join along axis 0.
# @param arrays [Array]
# @return [Cumo::NArray]
def vstack(arrays)
  rows = arrays.map { |item| _atleast_2d(cast(item)) }
  concatenate(rows, axis: 0)
end
495
+
496
+ # Stack arrays horizontally (column wise).
497
+ # @example
498
+ # a = Cumo::Int32[1,2,3]
499
+ # b = Cumo::Int32[2,3,4]
500
+ # p Cumo::NArray.hstack([a,b])
501
+ # # Cumo::Int32#shape=[6]
502
+ # # [1, 2, 3, 2, 3, 4]
503
+ #
504
+ # a = Cumo::Int32[[1],[2],[3]]
505
+ # b = Cumo::Int32[[2],[3],[4]]
506
+ # p Cumo::NArray.hstack([a,b])
507
+ # # Cumo::Int32#shape=[3,2]
508
+ # # [[1, 2],
509
+ # # [2, 3],
510
+ # # [3, 4]]
511
+
512
# Cast every element and join along axis 1 when any input has two or
# more dimensions, otherwise along axis 0.
# @param arrays [Array]
# @return [Cumo::NArray]
def hstack(arrays)
  klass = (self==NArray) ? NArray.array_type(arrays) : self
  casted = arrays.map { |item| klass.cast(item) }
  dim = casted.any? { |arr| arr.ndim >= 2 } ? 1 : 0
  concatenate(casted, axis: dim)
end
523
+
524
+ # Stack arrays depth wise (along the third axis).
525
+ # @example
526
+ # a = Cumo::Int32[1,2,3]
527
+ # b = Cumo::Int32[2,3,4]
528
+ # p Cumo::NArray.dstack([a,b])
529
+ # # Cumo::Int32#shape=[1,3,2]
530
+ # # [[[1, 2],
531
+ # # [2, 3],
532
+ # # [3, 4]]]
533
+ #
534
+ # a = Cumo::Int32[[1],[2],[3]]
535
+ # b = Cumo::Int32[[2],[3],[4]]
536
+ # p Cumo::NArray.dstack([a,b])
537
+ # # Cumo::Int32#shape=[3,1,2]
538
+ # # [[[1, 2]],
539
+ # # [[2, 3]],
540
+ # # [[3, 4]]]
541
+
542
# Cast every element, promote it to at least 3-d, then join along axis 2.
# @param arrays [Array]
# @return [Cumo::NArray]
def dstack(arrays)
  slabs = arrays.map { |item| _atleast_3d(cast(item)) }
  concatenate(slabs, axis: 2)
end
548
+
549
+ # Stack 1-d arrays into columns of a 2-d array.
550
+ # @example
551
+ # x = Cumo::Int32[1,2,3]
552
+ # y = Cumo::Int32[2,3,4]
553
+ # p Cumo::NArray.column_stack([x,y])
554
+ # # Cumo::Int32#shape=[3,2]
555
+ # # [[1, 2],
556
+ # # [2, 3],
557
+ # # [3, 4]]
558
+
559
# Cast every element, turn 0-d and 1-d inputs into columns, then join
# along axis 1. Inputs with two or more dimensions are used unchanged.
# @param arrays [Array]
# @return [Cumo::NArray]
def column_stack(arrays)
  columns = arrays.map do |item|
    arr = cast(item)
    rank = arr.ndim
    if rank == 0
      arr[:new,:new]
    elsif rank == 1
      arr[true,:new]
    else
      arr
    end
  end
  concatenate(columns, axis: 1)
end
570
+
571
+ private
572
# Return an narray with at least two dimensions.
# 0-d input gains two new axes, 1-d input gains a leading axis, and
# anything else is returned unchanged.
def _atleast_2d(a)
  rank = a.ndim
  if rank == 0
    a[:new,:new]
  elsif rank == 1
    a[:new,true]
  else
    a
  end
end
580
+
581
# Return an narray with at least three dimensions.
# 0-d input gains three new axes, 1-d input is wrapped as [:new,true,:new],
# 2-d input gains a trailing axis, and anything else is returned unchanged.
def _atleast_3d(a)
  rank = a.ndim
  if rank == 0
    a[:new,:new,:new]
  elsif rank == 1
    a[:new,true,:new]
  elsif rank == 2
    a[true,true,:new]
  else
    a
  end
end
590
+
591
+ end # class << self
592
+
593
+ # @example
594
+ # p a = Cumo::DFloat[[1, 2], [3, 4]]
595
+ # # Cumo::DFloat#shape=[2,2]
596
+ # # [[1, 2],
597
+ # # [3, 4]]
598
+ #
599
+ # p b = Cumo::DFloat[[5, 6]]
600
+ # # Cumo::DFloat#shape=[1,2]
601
+ # # [[5, 6]]
602
+ #
603
+ # p a.concatenate(b,axis:0)
604
+ # # Cumo::DFloat#shape=[3,2]
605
+ # # [[1, 2],
606
+ # # [3, 4],
607
+ # # [5, 6]]
608
+ #
609
+ # p a.concatenate(b.transpose, axis:1)
610
+ # # Cumo::DFloat#shape=[2,3]
611
+ # # [[1, 2, 5],
612
+ # # [3, 4, 6]]
613
+
614
# Append +arrays+ to the receiver along +axis+ and return a new array.
# @param arrays [Array] NArray, Numeric or Array elements
# @param axis [Integer] joining axis (validated by check_axis)
# @return [Cumo::NArray] array of the receiver's class
# @raise [TypeError] if an element is not an NArray, Numeric or Array
# @raise [ShapeError] if an element has more dimensions than the receiver
#   or its remaining shape differs from the receiver's
def concatenate(*arrays,axis:0)
  axis = check_axis(axis)
  # Receiver shape with the joining axis removed; every input must match it.
  rest_shape = shape
  rest_shape.delete_at(axis)
  axis_total = shape[axis]
  arrays = arrays.map do |item|
    arr =
      case item
      when NArray then item
      when Numeric then self.class.new(1).store(item)
      when Array then self.class.cast(item)
      else
        raise TypeError,"not Cumo::NArray: #{item.inspect[0..48]}"
      end
    raise ShapeError,"dimension mismatch" if arr.ndim > ndim
    dims = arr.shape
    # A missing axis (lower-rank input) counts as length 1.
    axis_total += (dims.delete_at(axis-ndim) || 1)
    raise ShapeError,"shape mismatch" if rest_shape != dims
    arr
  end
  rest_shape.insert(axis,axis_total)

  result = self.class.zeros(*rest_shape)
  offset = shape[axis]
  refs = [true] * ndim
  refs[axis] = 0...offset
  result[*refs] = self
  arrays.each do |arr|
    stop = offset + (arr.shape[axis-ndim] || 1)
    refs[axis] = offset...stop
    result[*refs] = arr
    offset = stop
  end
  result
end
654
+
655
+ # @example
656
+ # p x = Cumo::DFloat.new(9).seq
657
+ # # Cumo::DFloat#shape=[9]
658
+ # # [0, 1, 2, 3, 4, 5, 6, 7, 8]
659
+ #
660
+ # pp x.split(3)
661
+ # # [Cumo::DFloat(view)#shape=[3]
662
+ # # [0, 1, 2],
663
+ # # Cumo::DFloat(view)#shape=[3]
664
+ # # [3, 4, 5],
665
+ # # Cumo::DFloat(view)#shape=[3]
666
+ # # [6, 7, 8]]
667
+ #
668
+ # p x = Cumo::DFloat.new(8).seq
669
+ # # Cumo::DFloat#shape=[8]
670
+ # # [0, 1, 2, 3, 4, 5, 6, 7]
671
+ #
672
+ # pp x.split([3, 5, 6, 10])
673
+ # # [Cumo::DFloat(view)#shape=[3]
674
+ # # [0, 1, 2],
675
+ # # Cumo::DFloat(view)#shape=[2]
676
+ # # [3, 4],
677
+ # # Cumo::DFloat(view)#shape=[1]
678
+ # # [5],
679
+ # # Cumo::DFloat(view)#shape=[2]
680
+ # # [6, 7],
681
+ # # Cumo::DFloat(view)#shape=[0][]]
682
+
683
# Cut the receiver into pieces along +axis+.
# @param indices_or_sections [Integer,Array,NArray] an Integer asks for that
#   many sections (the leading ones get one extra element when the axis does
#   not divide evenly); an Array/NArray lists the cut positions
# @param axis [Integer] axis to split along (validated by check_axis)
# @return [Array] results of indexing the receiver with each range
# @raise [TypeError] for any other argument type
def split(indices_or_sections, axis:0)
  axis = check_axis(axis)
  size_axis = shape[axis]
  case indices_or_sections
  when Integer
    quotient, remainder = size_axis.divmod(indices_or_sections)
    refs = [true]*ndim
    cursor = 0
    # Slice the next +width+ elements and advance the cursor.
    take = lambda do |width|
      stop = cursor + width
      refs[axis] = cursor ... stop
      cursor = stop
      self[*refs]
    end
    longer = remainder.times.map { take.call(quotient + 1) }
    shorter = (indices_or_sections-remainder).times.map { take.call(quotient) }
    longer + shorter
  when NArray
    split(indices_or_sections.to_a,axis:axis)
  when Array
    refs = [true]*ndim
    start = 0
    (indices_or_sections + [size_axis]).map do |cut|
      stop = (cut > size_axis) ? size_axis : cut
      # Once the start has reached the end of the axis, use the -1...-1 range.
      refs[axis] = (start < size_axis) ? start...stop : -1...-1
      start = stop
      self[*refs]
    end
  else
    raise TypeError,"argument must be Integer or Array"
  end
end
718
+
719
+ # @example
720
+ # p x = Cumo::DFloat.new(4,4).seq
721
+ # # Cumo::DFloat#shape=[4,4]
722
+ # # [[0, 1, 2, 3],
723
+ # # [4, 5, 6, 7],
724
+ # # [8, 9, 10, 11],
725
+ # # [12, 13, 14, 15]]
726
+ #
727
+ # pp x.hsplit(2)
728
+ # # [Cumo::DFloat(view)#shape=[4,2]
729
+ # # [[0, 1],
730
+ # # [4, 5],
731
+ # # [8, 9],
732
+ # # [12, 13]],
733
+ # # Cumo::DFloat(view)#shape=[4,2]
734
+ # # [[2, 3],
735
+ # # [6, 7],
736
+ # # [10, 11],
737
+ # # [14, 15]]]
738
+ #
739
+ # pp x.hsplit([3, 6])
740
+ # # [Cumo::DFloat(view)#shape=[4,3]
741
+ # # [[0, 1, 2],
742
+ # # [4, 5, 6],
743
+ # # [8, 9, 10],
744
+ # # [12, 13, 14]],
745
+ # # Cumo::DFloat(view)#shape=[4,1]
746
+ # # [[3],
747
+ # # [7],
748
+ # # [11],
749
+ # # [15]],
750
+ # # Cumo::DFloat(view)#shape=[4,0][]]
751
+
752
# Split along axis 0; shorthand for split(indices_or_sections, axis: 0).
# @param indices_or_sections [Integer,Array,NArray]
# @return [Array]
def vsplit(indices_or_sections)
  split(indices_or_sections, axis: 0)
end
755
+
756
# Split along axis 1; shorthand for split(indices_or_sections, axis: 1).
# @param indices_or_sections [Integer,Array,NArray]
# @return [Array]
def hsplit(indices_or_sections)
  split(indices_or_sections, axis: 1)
end
759
+
760
# Split along axis 2; shorthand for split(indices_or_sections, axis: 2).
# @param indices_or_sections [Integer,Array,NArray]
# @return [Array]
def dsplit(indices_or_sections)
  split(indices_or_sections, axis: 2)
end
763
+
764
+ # @example
765
+ # p a = Cumo::NArray[0,1,2]
766
+ # # Cumo::Int32#shape=[3]
767
+ # # [0, 1, 2]
768
+ #
769
+ # p a.tile(2)
770
+ # # Cumo::Int32#shape=[6]
771
+ # # [0, 1, 2, 0, 1, 2]
772
+ #
773
+ # p a.tile(2,2)
774
+ # # Cumo::Int32#shape=[2,6]
775
+ # # [[0, 1, 2, 0, 1, 2],
776
+ # # [0, 1, 2, 0, 1, 2]]
777
+ #
778
+ # p a.tile(2,1,2)
779
+ # # Cumo::Int32#shape=[2,1,6]
780
+ # # [[[0, 1, 2, 0, 1, 2]],
781
+ # # [[0, 1, 2, 0, 1, 2]]]
782
+ #
783
+ # p b = Cumo::NArray[[1, 2], [3, 4]]
784
+ # # Cumo::Int32#shape=[2,2]
785
+ # # [[1, 2],
786
+ # # [3, 4]]
787
+ #
788
+ # p b.tile(2)
789
+ # # Cumo::Int32#shape=[2,4]
790
+ # # [[1, 2, 1, 2],
791
+ # # [3, 4, 3, 4]]
792
+ #
793
+ # p b.tile(2,1)
794
+ # # Cumo::Int32#shape=[4,2]
795
+ # # [[1, 2],
796
+ # # [3, 4],
797
+ # # [1, 2],
798
+ # # [3, 4]]
799
+ #
800
+ # p c = Cumo::NArray[1,2,3,4]
801
+ # # Cumo::Int32#shape=[4]
802
+ # # [1, 2, 3, 4]
803
+ #
804
+ # p c.tile(4,1)
805
+ # # Cumo::Int32#shape=[4,4]
806
+ # # [[1, 2, 3, 4],
807
+ # # [1, 2, 3, 4],
808
+ # # [1, 2, 3, 4],
809
+ # # [1, 2, 3, 4]]
810
+
811
# Repeat the whole array; the repeat counts are aligned with the trailing axes.
# @param arg [Array<Integer>] positive repeat counts
# @return [Cumo::NArray] new array of the receiver's class
# @raise [ArgumentError] if any count is not an Integer >= 1
def tile(*arg)
  unless arg.all? { |rep| rep.kind_of?(Integer) && rep >= 1 }
    raise ArgumentError,"argument should be positive integer"
  end
  ns = arg.size
  nd = self.ndim
  # Right-align counts and dimensions: a missing count is treated as 1,
  # a missing dimension is marked with nil.
  reps = [1]*[nd-ns,0].max + arg
  dims = [nil]*[ns-nd,0].max + self.shape
  new_shp = []
  src_shp = []
  res_shp = []
  reps.zip(dims).each do |m,n|
    if n
      new_shp.push(m, n)
      src_shp.push(:new, true)
      res_shp << n*m
    else
      new_shp.push(m, 1)
      src_shp.push(:new, :new)
      res_shp << m
    end
  end
  # Store the source into the interleaved (count, dim) layout, then merge each pair.
  self.class.new(*new_shp).store(self[*src_shp]).reshape(*res_shp)
end
846
+
847
+ # @example
848
+ # p Cumo::NArray[3].repeat(4)
849
+ # # Cumo::Int32#shape=[4]
850
+ # # [3, 3, 3, 3]
851
+ #
852
+ # p x = Cumo::NArray[[1,2],[3,4]]
853
+ # # Cumo::Int32#shape=[2,2]
854
+ # # [[1, 2],
855
+ # # [3, 4]]
856
+ #
857
+ # p x.repeat(2)
858
+ # # Cumo::Int32#shape=[8]
859
+ # # [1, 1, 2, 2, 3, 3, 4, 4]
860
+ #
861
+ # p x.repeat(3,axis:1)
862
+ # # Cumo::Int32#shape=[2,6]
863
+ # # [[1, 1, 1, 2, 2, 2],
864
+ # # [3, 3, 3, 4, 4, 4]]
865
+ #
866
+ # p x.repeat([1,2],axis:0)
867
+ # # Cumo::Int32#shape=[3,2]
868
+ # # [[1, 2],
869
+ # # [3, 4],
870
+ # # [3, 4]]
871
+
872
# Repeat elements of the array.
# @param arg [Integer,#to_a] a positive Integer repeats every entry that many
#   times; otherwise one non-negative Integer count per entry along the axis
# @param axis [Integer,nil] axis to repeat along; nil flattens the receiver first
# @return [Cumo::NArray] copy of the indexed result
# @raise [ArgumentError] on an invalid axis, a non-positive Integer count,
#   a count list of the wrong length, or a negative / non-Integer list entry
def repeat(arg,axis:nil)
  case axis
  when Integer
    axis = check_axis(axis)
    c = self
  when NilClass
    c = self.flatten
    axis = 0
  else
    raise ArgumentError,"invalid axis"
  end
  case arg
  when Integer
    # arg is already known to be an Integer here; only the range needs checking.
    if arg < 1
      raise ArgumentError,"argument should be positive integer"
    end
    idx = c.shape[axis].times.flat_map { |i| [i]*arg }
  else
    arg = arg.to_a
    if arg.size != c.shape[axis]
      raise ArgumentError,"repeat size should be equal to size along axis"
    end
    arg.each do |i|
      if !i.kind_of?(Integer) || i<0
        raise ArgumentError,"argument should be non-negative integer"
      end
    end
    idx = arg.each_with_index.flat_map { |a,i| [i]*a }
  end
  # Index the chosen axis with the expanded index list and copy the result.
  ref = [true] * c.ndim
  ref[axis] = idx
  c[*ref].copy
end
905
+
906
+ # Calculate the n-th discrete difference along given axis.
907
+ # @example
908
+ # p x = Cumo::DFloat[1, 2, 4, 7, 0]
909
+ # # Cumo::DFloat#shape=[5]
910
+ # # [1, 2, 4, 7, 0]
911
+ #
912
+ # p x.diff
913
+ # # Cumo::DFloat#shape=[4]
914
+ # # [1, 2, 3, -7]
915
+ #
916
+ # p x.diff(2)
917
+ # # Cumo::DFloat#shape=[3]
918
+ # # [1, 1, -10]
919
+ #
920
+ # p x = Cumo::DFloat[[1, 3, 6, 10], [0, 5, 6, 8]]
921
+ # # Cumo::DFloat#shape=[2,4]
922
+ # # [[1, 3, 6, 10],
923
+ # # [0, 5, 6, 8]]
924
+ #
925
+ # p x.diff
926
+ # # Cumo::DFloat#shape=[2,3]
927
+ # # [[2, 3, 4],
928
+ # # [5, 1, 2]]
929
+ #
930
+ # p x.diff(axis:0)
931
+ # # Cumo::DFloat#shape=[1,4]
932
+ # # [[-1, 2, 0, -2]]
933
+
934
# n-th discrete difference along +axis+.
# @param n [Integer] order of the difference
# @param axis [Integer] axis to difference along (validated by check_axis)
# @return [Cumo::NArray] new array, n elements shorter along +axis+
# @raise [ShapeError] if n is negative or not smaller than the axis length
def diff(n=1,axis:-1)
  axis = check_axis(axis)
  if n < 0 || n >= shape[axis]
    raise ShapeError,"n=#{n} is invalid for shape[#{axis}]=#{shape[axis]}"
  end

  # Build the coefficients by repeatedly differencing [-1, 1].
  coef = self.class[-1,1]
  2.upto(n) do |order|
    lower = self.class.zeros(order+1)
    lower[0..-2] = coef
    upper = self.class.zeros(order+1)
    upper[1..-1] = coef
    coef = upper - lower
  end

  window = [true]*ndim
  window[axis] = n..-1
  result = self[*window].dup
  acc = result.inplace
  (n-1).downto(0) do |i|
    window = [true]*ndim
    window[axis] = i..(i-n-1)
    acc + self[*window] * coef[i] # inplace addition into result
  end
  result
end
959
+
960
+
961
# Upper triangular matrix.
# Works on a duplicate of the receiver and applies triu! to it.
# @param k [Integer] diagonal offset passed to triu!
# @return whatever triu! returns for the duplicate
def triu(k=0)
  copy = dup
  copy.triu!(k)
end
966
+
967
# Upper triangular matrix, in place.
# Sets to zero the entries addressed by tril_indices(k-1), i.e. those
# below the k-th diagonal of the last two axes.
# @param k [Integer] diagonal offset
# @raise [NArray::ShapeError] if the array has fewer than two dimensions
def triu!(k=0)
  raise NArray::ShapeError, "must be >= 2-dimensional array" if ndim < 2
  # Non-contiguous receiver: compute on a copy and store it back.
  return store(triu(k)) unless contiguous?

  *lead,rows,cols = shape
  below = tril_indices(k-1)
  # Flatten the last two axes so the flat indices can be assigned directly.
  reshape!(*lead,rows*cols)
  self[false,below] = 0
  reshape!(*lead,rows,cols)
end
983
+
984
# Return the indices for the upper-triangle on and above the k-th diagonal
# of the last two axes.
# @param k [Integer] diagonal offset, forwarded to NArray.triu_indices
# @return the result of NArray.triu_indices(m, n, k) for the last two axis lengths
# @raise [NArray::ShapeError] if the array has fewer than two dimensions
def triu_indices(k=0)
  if ndim < 2
    raise NArray::ShapeError, "must be >= 2-dimensional array"
  end
  m,n = shape[-2..-1]
  # Forward the caller's k; writing `k=0` here would reset it to 0 on every call.
  NArray.triu_indices(m,n,k)
end
992
+
993
# Return the indices for the upper-triangle on and above the k-th diagonal
# of an m-by-n matrix.
# @param m [Integer] number of rows
# @param n [Integer] number of columns
# @param k [Integer] diagonal offset
# @return result of `where` on the mask (row + k <= column)
def self.triu_indices(m,n,k=0)
  rows = Cumo::Int64.new(m,1).seq
  cols = Cumo::Int64.new(1,n).seq
  mask = (rows + k) <= cols
  mask.where
end
999
+
1000
# Lower triangular matrix.
# Works on a duplicate of the receiver and applies tril! to it.
# @param k [Integer] diagonal offset passed to tril!
# @return whatever tril! returns for the duplicate
def tril(k=0)
  copy = dup
  copy.tril!(k)
end
1005
+
1006
# Lower triangular matrix, in place.
# Sets to zero the entries addressed by triu_indices(k+1); the intent is to
# clear everything above the k-th diagonal of the last two axes.
# @param k [Integer] diagonal offset
# @raise [NArray::ShapeError] if the array has fewer than two dimensions
def tril!(k=0)
  raise NArray::ShapeError, "must be >= 2-dimensional array" if ndim < 2
  # Non-contiguous receiver: compute on a copy and store it back.
  return store(tril(k)) unless contiguous?

  above = triu_indices(k+1)
  *lead,rows,cols = shape
  # Flatten the last two axes so the flat indices can be assigned directly.
  reshape!(*lead,rows*cols)
  self[false,above] = 0
  reshape!(*lead,rows,cols)
end
1022
+
1023
# Return the indices for the lower-triangle on and below the k-th diagonal
# of the last two axes.
# @param k [Integer] diagonal offset, forwarded to NArray.tril_indices
# @raise [NArray::ShapeError] if the array has fewer than two dimensions
def tril_indices(k=0)
  raise NArray::ShapeError, "must be >= 2-dimensional array" if ndim < 2
  rows, cols = shape[-2..-1]
  NArray.tril_indices(rows, cols, k)
end
1031
+
1032
# Return the indices for the lower-triangle on and below the k-th diagonal
# of an m-by-n matrix.
# @param m [Integer] number of rows
# @param n [Integer] number of columns
# @param k [Integer] diagonal offset
# @return result of `where` on the mask (row + k >= column)
def self.tril_indices(m,n,k=0)
  rows = Cumo::Int64.new(m,1).seq
  cols = Cumo::Int64.new(1,n).seq
  mask = (rows + k) >= cols
  mask.where
end
1038
+
1039
# Return the k-th diagonal indices of the last two axes.
# @param k [Integer] diagonal offset, forwarded to NArray.diag_indices
# @raise [NArray::ShapeError] if the array has fewer than two dimensions
def diag_indices(k=0)
  raise NArray::ShapeError, "must be >= 2-dimensional array" if ndim < 2
  rows, cols = shape[-2..-1]
  NArray.diag_indices(rows, cols, k)
end
1047
+
1048
# Return the k-th diagonal indices of an m-by-n matrix.
# @param m [Integer] number of rows
# @param n [Integer] number of columns
# @param k [Integer] diagonal offset
# @return result of `where` on the mask (row + k == column)
def self.diag_indices(m,n,k=0)
  rows = Cumo::Int64.new(m,1).seq
  cols = Cumo::Int64.new(1,n).seq
  mask = (rows + k).eq(cols)
  mask.where
end
1054
+
1055
# Return a matrix whose diagonal is constructed by self along the last axis.
# The last axis of length n becomes two axes of length n + |k|; the receiver
# is stored into the k-th diagonal of a zero-filled array.
# @param k [Integer] diagonal offset
# @return [Cumo::NArray] new array of the receiver's class
def diag(k=0)
  *lead, len = shape
  side = len + k.abs
  out = self.class.zeros(*lead, side, side)
  out.diagonal(k).store(self)
  out
end
1063
+
1064
+ # Return the sum along diagonals of the array.
1065
+ #
1066
+ # If 2-D array, computes the summation along its diagonal with the
1067
+ # given offset, i.e., sum of `a[i,i+offset]`.
1068
+ # If more than 2-D array, the diagonal is determined from the axes
1069
+ # specified by axis argument. The default is axis=[-2,-1].
1070
+ # @param offset [Integer] (optional, default=0) diagonal offset
1071
+ # @param axis [Array] (optional, default=[-2,-1]) diagonal axis
1072
+ # @param nan [Bool] (optional, default=false) nan-aware algorithm, i.e., if true then it ignores nan.
1073
+
1074
# Sum the diagonal selected by +offset+ and +axis+ over its last axis.
# Both positional arguments are passed straight to `diagonal`.
def trace(offset=nil,axis=nil,nan:false)
  diag_view = diagonal(offset, axis)
  diag_view.sum(nan: nan, axis: -1)
end
1077
+
1078
+
1079
+ @@warn_slow_dot = false
1080
+
1081
+ # Dot product of two arrays.
1082
+ # @param b [Cumo::NArray]
1083
+ # @return [Cumo::NArray] return dot product
1084
+
1085
# Dot product with +b+.
# When the upcast type is SFloat, DFloat, SComplex or DComplex the product
# goes through gemm (mulsum when both operands are 1-d); every other type
# uses mulsum. A one-time warning is printed for large matrices of those
# other types.
# @param b [Cumo::NArray] right operand (converted with self.class.asarray)
# @return [Cumo::NArray]
def dot(b)
  t = self.class::UPCAST[b.class]
  use_gemm = [SFloat, DFloat, SComplex, DComplex].include?(t)
  b = self.class.asarray(b)

  if use_gemm
    if self.ndim == 1
      return self.mulsum(b, axis:-1) if b.ndim == 1
      return self[:new, false].gemm(b).flatten
    end
    return self.gemm(b[false, :new]).flatten if b.ndim == 1
    return self.gemm(b)
  end

  return mulsum(b, axis:-1) if b.ndim == 1

  rank = ndim
  return b.mulsum(self, axis:-2) if rank == 0
  return self[true,:new].mulsum(b, axis:-2) if rank == 1

  unless @@warn_slow_dot
    nx = 200
    ns = 200000
    am,an = shape[-2..-1]
    bm,bn = b.shape[-2..-1]
    big_dims = am > nx && an > nx && bm > nx && bn > nx
    if big_dims && size > ns && b.size > ns
      # Remember that the warning was shown so it is emitted only once.
      @@warn_slow_dot = true
      warn "\nwarning: matrix dot for #{t} is slow. Consider SFloat, DFloat, SComplex, or DComplex to use cuBLAS.\n\n"
    end
  end
  self[false,:new].mulsum(b[false,:new,true,true], axis:-2)
end
1133
+
1134
+ # Inner product of two arrays.
1135
+ # Same as `(a*b).sum(axis:-1)`.
1136
+ # @param b [Cumo::NArray]
1137
+ # @param axis [Integer] applied axis
1138
+ # @return [Cumo::NArray] return (a*b).sum(axis:axis)
1139
+
1140
# Inner product: delegates to mulsum(b, axis: axis).
def inner(b, axis:-1)
  mulsum(b, axis: axis)
end
1143
+
1144
+ # Outer product of two arrays.
1145
+ # Same as `self[false,:new] * b[false,:new,true]`.
1146
+ #
1147
+ # @param b [Cumo::NArray]
1148
+ # @param axis [Integer] applied axis (default=-1)
1149
+ # @return [Cumo::NArray] return outer product
1150
+ # @example
1151
+ # a = Cumo::DFloat.ones(5)
1152
+ # => Cumo::DFloat#shape=[5]
1153
+ # [1, 1, 1, 1, 1]
1154
+ # b = Cumo::DFloat.linspace(-2,2,5)
1155
+ # => Cumo::DFloat#shape=[5]
1156
+ # [-2, -1, 0, 1, 2]
1157
+ # a.outer(b)
1158
+ # => Cumo::DFloat#shape=[5,5]
1159
+ # [[-2, -1, 0, 1, 2],
1160
+ # [-2, -1, 0, 1, 2],
1161
+ # [-2, -1, 0, 1, 2],
1162
+ # [-2, -1, 0, 1, 2],
1163
+ # [-2, -1, 0, 1, 2]]
1164
+
1165
# Outer product with +b+ (cast through NArray.cast).
# With axis nil the receiver gets a trailing new axis and +b+ gets a new
# axis before its last one (a 0-d +b+ is used as is). With an axis given,
# the new axes are inserted relative to that axis instead.
# @raise [ArgumentError] if the axis is out of range for the lower-rank operand
def outer(b, axis:nil)
  b = NArray.cast(b)
  if axis.nil?
    lhs = self[false,:new]
    rhs = (b.ndim==0) ? b : b[false,:new,true]
    return lhs * rhs
  end

  low_rank, high_rank = [ndim,b.ndim].minmax
  # Express the axis as a negative offset from the end of the higher-rank operand.
  axis = check_axis(axis) - high_rank
  raise ArgumentError,"axis=#{axis} is out of range" if axis < -low_rank

  self_sel = [true]*ndim
  self_sel[axis+ndim+1,0] = :new    # zero-length slice assignment inserts
  other_sel = [true]*b.ndim
  other_sel[axis+b.ndim,0] = :new
  self[*self_sel] * b[*other_sel]
end
1182
+
1183
+ # Kronecker product of two arrays.
1184
+ #
1185
+ # kron(a,b)[k_0, k_1, ...] = a[i_0, i_1, ...] * b[j_0, j_1, ...]
1186
+ # where: k_n = i_n * b.shape[n] + j_n
1187
+ #
1188
+ # @param b [Cumo::NArray]
1189
+ # @return [Cumo::NArray] return Kronecker product
1190
+ # @example
1191
+ # Cumo::DFloat[1,10,100].kron([5,6,7])
1192
+ # => Cumo::DFloat#shape=[9]
1193
+ # [5, 6, 7, 50, 60, 70, 500, 600, 700]
1194
+ # Cumo::DFloat[5,6,7].kron([1,10,100])
1195
+ # => Cumo::DFloat#shape=[9]
1196
+ # [5, 50, 500, 6, 60, 600, 7, 70, 700]
1197
+ # Cumo::DFloat.eye(2).kron(Cumo::DFloat.ones(2,2))
1198
+ # => Cumo::DFloat#shape=[4,4]
1199
+ # [[1, 1, 0, 0],
1200
+ # [1, 1, 0, 0],
1201
+ # [0, 0, 1, 1],
1202
+ # [0, 0, 1, 1]]
1203
+
1204
# Kronecker product with +b+ (cast through NArray.cast).
# Both operands are expanded with interleaved new axes, multiplied, and the
# product is reshaped so each result axis has length
# (self.shape[i] or 1) * (b.shape[i] or 1), counting axes from the end.
def kron(b)
  b = NArray.cast(b)
  rank_a = ndim
  rank_b = b.ndim
  dims_a = shape
  dims_b = b.shape
  # Left-pad the lower-rank operand, then interleave kept and new axes.
  sel_a = [:new]*(2*[rank_b-rank_a,0].max) + [true,:new]*rank_a
  sel_b = [:new]*(2*[rank_a-rank_b,0].max) + [:new,true]*rank_b
  top = [rank_a,rank_b].max
  out_shape = (-top..-1).map { |i| (dims_a[i]||1) * (dims_b[i]||1) }
  product = self[*sel_a] * b[*sel_b]
  product.reshape(*out_shape)
end
1215
+
1216
+
1217
# under construction
# Covariance-style product of the (optionally weighted) mean-removed data.
# @param y [Object,nil] extra data stacked under the receiver with NArray.vstack
# @param ddof [Numeric] subtracted from the normalisation factor
# @param fweights [Object,nil] weights multiplied into the combined weight
# @param aweights [Object,nil] weights multiplied into the combined weight
# @raise [StandardError] if any weighted normalisation factor is <= 0
def cov(y=nil, ddof:1, fweights:nil, aweights:nil)
  m = y ? NArray.vstack([self,y]) : self

  aw = aweights ? aweights : nil
  w = fweights ? fweights : nil
  w = (w ? w*aw : aw) if aweights

  if w
    w_sum = w.sum(axis:-1, keepdims:true)
    fact =
      if ddof == 0
        w_sum
      elsif aweights.nil?
        w_sum - ddof
      else
        wa_sum = (w*aw).sum(axis:-1, keepdims:true)
        w_sum - ddof * wa_sum / w_sum
      end
    raise StandardError,"Degrees of freedom <= 0 for slice" if (fact <= 0).any?
    # Remove the weighted mean along the last axis.
    m = m - (m*w).sum(axis:-1, keepdims:true) / w_sum
    mw = m*w
  else
    fact = m.shape[-1] - ddof
    m = m - m.mean(axis:-1, keepdims:true)
    mw = m
  end

  mt = (m.ndim < 2) ? m : m.swapaxes(-2,-1)
  mw.dot(mt.conj) / fact
end
1259
+
1260
+ private
1261
+
1262
# @!visibility private
# Validate +axis+ against ndim and return it as a non-negative index.
# Negative values count from the last axis.
# @raise [ArgumentError] if +axis+ is not an Integer or lies outside the valid range
def check_axis(axis)
  raise ArgumentError,"axis=#{axis} must be Integer" unless axis.is_a?(Integer)
  rank = ndim
  resolved = axis < 0 ? axis + rank : axis
  raise ArgumentError,"axis=#{axis} is invalid" if resolved < 0 || resolved >= rank
  resolved
end
1276
+
1277
+ end
1278
+ end