cumo 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (266) hide show
  1. checksums.yaml +7 -0
  2. data/.gitignore +27 -0
  3. data/.travis.yml +5 -0
  4. data/3rd_party/mkmf-cu/.gitignore +36 -0
  5. data/3rd_party/mkmf-cu/Gemfile +3 -0
  6. data/3rd_party/mkmf-cu/LICENSE +21 -0
  7. data/3rd_party/mkmf-cu/README.md +36 -0
  8. data/3rd_party/mkmf-cu/Rakefile +11 -0
  9. data/3rd_party/mkmf-cu/bin/mkmf-cu-nvcc +4 -0
  10. data/3rd_party/mkmf-cu/lib/mkmf-cu.rb +32 -0
  11. data/3rd_party/mkmf-cu/lib/mkmf-cu/cli.rb +80 -0
  12. data/3rd_party/mkmf-cu/lib/mkmf-cu/nvcc.rb +157 -0
  13. data/3rd_party/mkmf-cu/mkmf-cu.gemspec +16 -0
  14. data/3rd_party/mkmf-cu/test/test_mkmf-cu.rb +67 -0
  15. data/CODE_OF_CONDUCT.md +46 -0
  16. data/Gemfile +8 -0
  17. data/LICENSE.txt +82 -0
  18. data/README.md +252 -0
  19. data/Rakefile +43 -0
  20. data/bench/broadcast_fp32.rb +138 -0
  21. data/bench/cumo_bench.rb +193 -0
  22. data/bench/numo_bench.rb +138 -0
  23. data/bench/reduction_fp32.rb +117 -0
  24. data/bin/console +14 -0
  25. data/bin/setup +8 -0
  26. data/cumo.gemspec +32 -0
  27. data/ext/cumo/cuda/cublas.c +278 -0
  28. data/ext/cumo/cuda/driver.c +421 -0
  29. data/ext/cumo/cuda/memory_pool.cpp +185 -0
  30. data/ext/cumo/cuda/memory_pool_impl.cpp +308 -0
  31. data/ext/cumo/cuda/memory_pool_impl.hpp +370 -0
  32. data/ext/cumo/cuda/memory_pool_impl_test.cpp +554 -0
  33. data/ext/cumo/cuda/nvrtc.c +207 -0
  34. data/ext/cumo/cuda/runtime.c +167 -0
  35. data/ext/cumo/cumo.c +148 -0
  36. data/ext/cumo/depend.erb +58 -0
  37. data/ext/cumo/extconf.rb +179 -0
  38. data/ext/cumo/include/cumo.h +25 -0
  39. data/ext/cumo/include/cumo/compat.h +23 -0
  40. data/ext/cumo/include/cumo/cuda/cublas.h +153 -0
  41. data/ext/cumo/include/cumo/cuda/cumo_thrust.hpp +187 -0
  42. data/ext/cumo/include/cumo/cuda/cumo_thrust_complex.hpp +79 -0
  43. data/ext/cumo/include/cumo/cuda/driver.h +22 -0
  44. data/ext/cumo/include/cumo/cuda/memory_pool.h +28 -0
  45. data/ext/cumo/include/cumo/cuda/nvrtc.h +22 -0
  46. data/ext/cumo/include/cumo/cuda/runtime.h +40 -0
  47. data/ext/cumo/include/cumo/indexer.h +238 -0
  48. data/ext/cumo/include/cumo/intern.h +142 -0
  49. data/ext/cumo/include/cumo/intern_fwd.h +38 -0
  50. data/ext/cumo/include/cumo/intern_kernel.h +6 -0
  51. data/ext/cumo/include/cumo/narray.h +429 -0
  52. data/ext/cumo/include/cumo/narray_kernel.h +149 -0
  53. data/ext/cumo/include/cumo/ndloop.h +95 -0
  54. data/ext/cumo/include/cumo/reduce_kernel.h +126 -0
  55. data/ext/cumo/include/cumo/template.h +158 -0
  56. data/ext/cumo/include/cumo/template_kernel.h +77 -0
  57. data/ext/cumo/include/cumo/types/bit.h +40 -0
  58. data/ext/cumo/include/cumo/types/bit_kernel.h +34 -0
  59. data/ext/cumo/include/cumo/types/complex.h +402 -0
  60. data/ext/cumo/include/cumo/types/complex_kernel.h +414 -0
  61. data/ext/cumo/include/cumo/types/complex_macro.h +382 -0
  62. data/ext/cumo/include/cumo/types/complex_macro_kernel.h +186 -0
  63. data/ext/cumo/include/cumo/types/dcomplex.h +46 -0
  64. data/ext/cumo/include/cumo/types/dcomplex_kernel.h +13 -0
  65. data/ext/cumo/include/cumo/types/dfloat.h +47 -0
  66. data/ext/cumo/include/cumo/types/dfloat_kernel.h +14 -0
  67. data/ext/cumo/include/cumo/types/float_def.h +34 -0
  68. data/ext/cumo/include/cumo/types/float_def_kernel.h +39 -0
  69. data/ext/cumo/include/cumo/types/float_macro.h +191 -0
  70. data/ext/cumo/include/cumo/types/float_macro_kernel.h +158 -0
  71. data/ext/cumo/include/cumo/types/int16.h +24 -0
  72. data/ext/cumo/include/cumo/types/int16_kernel.h +23 -0
  73. data/ext/cumo/include/cumo/types/int32.h +24 -0
  74. data/ext/cumo/include/cumo/types/int32_kernel.h +19 -0
  75. data/ext/cumo/include/cumo/types/int64.h +24 -0
  76. data/ext/cumo/include/cumo/types/int64_kernel.h +19 -0
  77. data/ext/cumo/include/cumo/types/int8.h +24 -0
  78. data/ext/cumo/include/cumo/types/int8_kernel.h +19 -0
  79. data/ext/cumo/include/cumo/types/int_macro.h +67 -0
  80. data/ext/cumo/include/cumo/types/int_macro_kernel.h +48 -0
  81. data/ext/cumo/include/cumo/types/real_accum.h +486 -0
  82. data/ext/cumo/include/cumo/types/real_accum_kernel.h +101 -0
  83. data/ext/cumo/include/cumo/types/robj_macro.h +80 -0
  84. data/ext/cumo/include/cumo/types/robj_macro_kernel.h +0 -0
  85. data/ext/cumo/include/cumo/types/robject.h +27 -0
  86. data/ext/cumo/include/cumo/types/robject_kernel.h +7 -0
  87. data/ext/cumo/include/cumo/types/scomplex.h +46 -0
  88. data/ext/cumo/include/cumo/types/scomplex_kernel.h +13 -0
  89. data/ext/cumo/include/cumo/types/sfloat.h +48 -0
  90. data/ext/cumo/include/cumo/types/sfloat_kernel.h +14 -0
  91. data/ext/cumo/include/cumo/types/uint16.h +25 -0
  92. data/ext/cumo/include/cumo/types/uint16_kernel.h +20 -0
  93. data/ext/cumo/include/cumo/types/uint32.h +25 -0
  94. data/ext/cumo/include/cumo/types/uint32_kernel.h +20 -0
  95. data/ext/cumo/include/cumo/types/uint64.h +25 -0
  96. data/ext/cumo/include/cumo/types/uint64_kernel.h +20 -0
  97. data/ext/cumo/include/cumo/types/uint8.h +25 -0
  98. data/ext/cumo/include/cumo/types/uint8_kernel.h +20 -0
  99. data/ext/cumo/include/cumo/types/uint_macro.h +58 -0
  100. data/ext/cumo/include/cumo/types/uint_macro_kernel.h +38 -0
  101. data/ext/cumo/include/cumo/types/xint_macro.h +169 -0
  102. data/ext/cumo/include/cumo/types/xint_macro_kernel.h +88 -0
  103. data/ext/cumo/narray/SFMT-params.h +97 -0
  104. data/ext/cumo/narray/SFMT-params19937.h +46 -0
  105. data/ext/cumo/narray/SFMT.c +620 -0
  106. data/ext/cumo/narray/SFMT.h +167 -0
  107. data/ext/cumo/narray/array.c +638 -0
  108. data/ext/cumo/narray/data.c +961 -0
  109. data/ext/cumo/narray/gen/cogen.rb +56 -0
  110. data/ext/cumo/narray/gen/cogen_kernel.rb +58 -0
  111. data/ext/cumo/narray/gen/def/bit.rb +37 -0
  112. data/ext/cumo/narray/gen/def/dcomplex.rb +39 -0
  113. data/ext/cumo/narray/gen/def/dfloat.rb +37 -0
  114. data/ext/cumo/narray/gen/def/int16.rb +36 -0
  115. data/ext/cumo/narray/gen/def/int32.rb +36 -0
  116. data/ext/cumo/narray/gen/def/int64.rb +36 -0
  117. data/ext/cumo/narray/gen/def/int8.rb +36 -0
  118. data/ext/cumo/narray/gen/def/robject.rb +37 -0
  119. data/ext/cumo/narray/gen/def/scomplex.rb +39 -0
  120. data/ext/cumo/narray/gen/def/sfloat.rb +37 -0
  121. data/ext/cumo/narray/gen/def/uint16.rb +36 -0
  122. data/ext/cumo/narray/gen/def/uint32.rb +36 -0
  123. data/ext/cumo/narray/gen/def/uint64.rb +36 -0
  124. data/ext/cumo/narray/gen/def/uint8.rb +36 -0
  125. data/ext/cumo/narray/gen/erbpp2.rb +346 -0
  126. data/ext/cumo/narray/gen/narray_def.rb +268 -0
  127. data/ext/cumo/narray/gen/spec.rb +425 -0
  128. data/ext/cumo/narray/gen/tmpl/accum.c +86 -0
  129. data/ext/cumo/narray/gen/tmpl/accum_binary.c +121 -0
  130. data/ext/cumo/narray/gen/tmpl/accum_binary_kernel.cu +61 -0
  131. data/ext/cumo/narray/gen/tmpl/accum_index.c +119 -0
  132. data/ext/cumo/narray/gen/tmpl/accum_index_kernel.cu +66 -0
  133. data/ext/cumo/narray/gen/tmpl/accum_kernel.cu +12 -0
  134. data/ext/cumo/narray/gen/tmpl/alloc_func.c +107 -0
  135. data/ext/cumo/narray/gen/tmpl/allocate.c +37 -0
  136. data/ext/cumo/narray/gen/tmpl/aref.c +66 -0
  137. data/ext/cumo/narray/gen/tmpl/aref_cpu.c +50 -0
  138. data/ext/cumo/narray/gen/tmpl/aset.c +56 -0
  139. data/ext/cumo/narray/gen/tmpl/binary.c +162 -0
  140. data/ext/cumo/narray/gen/tmpl/binary2.c +70 -0
  141. data/ext/cumo/narray/gen/tmpl/binary2_kernel.cu +15 -0
  142. data/ext/cumo/narray/gen/tmpl/binary_kernel.cu +31 -0
  143. data/ext/cumo/narray/gen/tmpl/binary_s.c +45 -0
  144. data/ext/cumo/narray/gen/tmpl/binary_s_kernel.cu +15 -0
  145. data/ext/cumo/narray/gen/tmpl/bincount.c +181 -0
  146. data/ext/cumo/narray/gen/tmpl/cast.c +44 -0
  147. data/ext/cumo/narray/gen/tmpl/cast_array.c +13 -0
  148. data/ext/cumo/narray/gen/tmpl/class.c +9 -0
  149. data/ext/cumo/narray/gen/tmpl/class_kernel.cu +6 -0
  150. data/ext/cumo/narray/gen/tmpl/clip.c +121 -0
  151. data/ext/cumo/narray/gen/tmpl/coerce_cast.c +10 -0
  152. data/ext/cumo/narray/gen/tmpl/complex_accum_kernel.cu +129 -0
  153. data/ext/cumo/narray/gen/tmpl/cond_binary.c +68 -0
  154. data/ext/cumo/narray/gen/tmpl/cond_binary_kernel.cu +18 -0
  155. data/ext/cumo/narray/gen/tmpl/cond_unary.c +46 -0
  156. data/ext/cumo/narray/gen/tmpl/cum.c +50 -0
  157. data/ext/cumo/narray/gen/tmpl/each.c +47 -0
  158. data/ext/cumo/narray/gen/tmpl/each_with_index.c +70 -0
  159. data/ext/cumo/narray/gen/tmpl/ewcomp.c +79 -0
  160. data/ext/cumo/narray/gen/tmpl/ewcomp_kernel.cu +19 -0
  161. data/ext/cumo/narray/gen/tmpl/extract.c +22 -0
  162. data/ext/cumo/narray/gen/tmpl/extract_cpu.c +26 -0
  163. data/ext/cumo/narray/gen/tmpl/extract_data.c +53 -0
  164. data/ext/cumo/narray/gen/tmpl/eye.c +105 -0
  165. data/ext/cumo/narray/gen/tmpl/eye_kernel.cu +19 -0
  166. data/ext/cumo/narray/gen/tmpl/fill.c +52 -0
  167. data/ext/cumo/narray/gen/tmpl/fill_kernel.cu +29 -0
  168. data/ext/cumo/narray/gen/tmpl/float_accum_kernel.cu +106 -0
  169. data/ext/cumo/narray/gen/tmpl/format.c +62 -0
  170. data/ext/cumo/narray/gen/tmpl/format_to_a.c +49 -0
  171. data/ext/cumo/narray/gen/tmpl/frexp.c +38 -0
  172. data/ext/cumo/narray/gen/tmpl/gemm.c +203 -0
  173. data/ext/cumo/narray/gen/tmpl/init_class.c +20 -0
  174. data/ext/cumo/narray/gen/tmpl/init_module.c +12 -0
  175. data/ext/cumo/narray/gen/tmpl/inspect.c +21 -0
  176. data/ext/cumo/narray/gen/tmpl/lib.c +50 -0
  177. data/ext/cumo/narray/gen/tmpl/lib_kernel.cu +24 -0
  178. data/ext/cumo/narray/gen/tmpl/logseq.c +102 -0
  179. data/ext/cumo/narray/gen/tmpl/logseq_kernel.cu +31 -0
  180. data/ext/cumo/narray/gen/tmpl/map_with_index.c +98 -0
  181. data/ext/cumo/narray/gen/tmpl/median.c +66 -0
  182. data/ext/cumo/narray/gen/tmpl/minmax.c +47 -0
  183. data/ext/cumo/narray/gen/tmpl/module.c +9 -0
  184. data/ext/cumo/narray/gen/tmpl/module_kernel.cu +1 -0
  185. data/ext/cumo/narray/gen/tmpl/new_dim0.c +15 -0
  186. data/ext/cumo/narray/gen/tmpl/new_dim0_kernel.cu +8 -0
  187. data/ext/cumo/narray/gen/tmpl/poly.c +50 -0
  188. data/ext/cumo/narray/gen/tmpl/pow.c +97 -0
  189. data/ext/cumo/narray/gen/tmpl/pow_kernel.cu +29 -0
  190. data/ext/cumo/narray/gen/tmpl/powint.c +17 -0
  191. data/ext/cumo/narray/gen/tmpl/qsort.c +212 -0
  192. data/ext/cumo/narray/gen/tmpl/rand.c +168 -0
  193. data/ext/cumo/narray/gen/tmpl/rand_norm.c +121 -0
  194. data/ext/cumo/narray/gen/tmpl/real_accum_kernel.cu +75 -0
  195. data/ext/cumo/narray/gen/tmpl/seq.c +112 -0
  196. data/ext/cumo/narray/gen/tmpl/seq_kernel.cu +43 -0
  197. data/ext/cumo/narray/gen/tmpl/set2.c +57 -0
  198. data/ext/cumo/narray/gen/tmpl/sort.c +48 -0
  199. data/ext/cumo/narray/gen/tmpl/sort_index.c +111 -0
  200. data/ext/cumo/narray/gen/tmpl/store.c +41 -0
  201. data/ext/cumo/narray/gen/tmpl/store_array.c +187 -0
  202. data/ext/cumo/narray/gen/tmpl/store_array_kernel.cu +58 -0
  203. data/ext/cumo/narray/gen/tmpl/store_bit.c +86 -0
  204. data/ext/cumo/narray/gen/tmpl/store_bit_kernel.cu +66 -0
  205. data/ext/cumo/narray/gen/tmpl/store_from.c +81 -0
  206. data/ext/cumo/narray/gen/tmpl/store_from_kernel.cu +58 -0
  207. data/ext/cumo/narray/gen/tmpl/store_kernel.cu +3 -0
  208. data/ext/cumo/narray/gen/tmpl/store_numeric.c +9 -0
  209. data/ext/cumo/narray/gen/tmpl/to_a.c +43 -0
  210. data/ext/cumo/narray/gen/tmpl/unary.c +132 -0
  211. data/ext/cumo/narray/gen/tmpl/unary2.c +60 -0
  212. data/ext/cumo/narray/gen/tmpl/unary_kernel.cu +72 -0
  213. data/ext/cumo/narray/gen/tmpl/unary_ret2.c +34 -0
  214. data/ext/cumo/narray/gen/tmpl/unary_s.c +86 -0
  215. data/ext/cumo/narray/gen/tmpl/unary_s_kernel.cu +58 -0
  216. data/ext/cumo/narray/gen/tmpl_bit/allocate.c +24 -0
  217. data/ext/cumo/narray/gen/tmpl_bit/aref.c +54 -0
  218. data/ext/cumo/narray/gen/tmpl_bit/aref_cpu.c +57 -0
  219. data/ext/cumo/narray/gen/tmpl_bit/aset.c +56 -0
  220. data/ext/cumo/narray/gen/tmpl_bit/binary.c +98 -0
  221. data/ext/cumo/narray/gen/tmpl_bit/bit_count.c +64 -0
  222. data/ext/cumo/narray/gen/tmpl_bit/bit_count_cpu.c +88 -0
  223. data/ext/cumo/narray/gen/tmpl_bit/bit_count_kernel.cu +76 -0
  224. data/ext/cumo/narray/gen/tmpl_bit/bit_reduce.c +133 -0
  225. data/ext/cumo/narray/gen/tmpl_bit/each.c +48 -0
  226. data/ext/cumo/narray/gen/tmpl_bit/each_with_index.c +70 -0
  227. data/ext/cumo/narray/gen/tmpl_bit/extract.c +30 -0
  228. data/ext/cumo/narray/gen/tmpl_bit/extract_cpu.c +29 -0
  229. data/ext/cumo/narray/gen/tmpl_bit/fill.c +69 -0
  230. data/ext/cumo/narray/gen/tmpl_bit/format.c +64 -0
  231. data/ext/cumo/narray/gen/tmpl_bit/format_to_a.c +51 -0
  232. data/ext/cumo/narray/gen/tmpl_bit/inspect.c +21 -0
  233. data/ext/cumo/narray/gen/tmpl_bit/mask.c +136 -0
  234. data/ext/cumo/narray/gen/tmpl_bit/none_p.c +14 -0
  235. data/ext/cumo/narray/gen/tmpl_bit/store_array.c +108 -0
  236. data/ext/cumo/narray/gen/tmpl_bit/store_bit.c +70 -0
  237. data/ext/cumo/narray/gen/tmpl_bit/store_from.c +60 -0
  238. data/ext/cumo/narray/gen/tmpl_bit/to_a.c +47 -0
  239. data/ext/cumo/narray/gen/tmpl_bit/unary.c +81 -0
  240. data/ext/cumo/narray/gen/tmpl_bit/where.c +90 -0
  241. data/ext/cumo/narray/gen/tmpl_bit/where2.c +95 -0
  242. data/ext/cumo/narray/index.c +880 -0
  243. data/ext/cumo/narray/kwargs.c +153 -0
  244. data/ext/cumo/narray/math.c +142 -0
  245. data/ext/cumo/narray/narray.c +1948 -0
  246. data/ext/cumo/narray/ndloop.c +2105 -0
  247. data/ext/cumo/narray/rand.c +45 -0
  248. data/ext/cumo/narray/step.c +474 -0
  249. data/ext/cumo/narray/struct.c +886 -0
  250. data/lib/cumo.rb +3 -0
  251. data/lib/cumo/cuda.rb +11 -0
  252. data/lib/cumo/cuda/compile_error.rb +36 -0
  253. data/lib/cumo/cuda/compiler.rb +161 -0
  254. data/lib/cumo/cuda/device.rb +47 -0
  255. data/lib/cumo/cuda/link_state.rb +31 -0
  256. data/lib/cumo/cuda/module.rb +40 -0
  257. data/lib/cumo/cuda/nvrtc_program.rb +27 -0
  258. data/lib/cumo/linalg.rb +12 -0
  259. data/lib/cumo/narray.rb +2 -0
  260. data/lib/cumo/narray/extra.rb +1278 -0
  261. data/lib/erbpp.rb +294 -0
  262. data/lib/erbpp/line_number.rb +137 -0
  263. data/lib/erbpp/narray_def.rb +381 -0
  264. data/numo-narray-version +1 -0
  265. data/run.gdb +7 -0
  266. metadata +353 -0
@@ -0,0 +1,3 @@
1
+ require_relative File.join(__dir__, '../ext/cumo/cumo')
2
+ require_relative 'cumo/cuda'
3
+ require_relative 'cumo/narray/extra'
@@ -0,0 +1,11 @@
1
# Root namespace of the gem; CUDA holds the helper classes required below
# (compile errors, compiler, device, module, link state, NVRTC program).
module Cumo
  module CUDA; end
end
5
+
6
+ require_relative 'cuda/compile_error'
7
+ require_relative 'cuda/compiler'
8
+ require_relative 'cuda/device'
9
+ require_relative 'cuda/module'
10
+ require_relative 'cuda/link_state'
11
+ require_relative 'cuda/nvrtc_program'
@@ -0,0 +1,36 @@
1
+ require_relative '../cuda'
2
+
3
module Cumo
  module CUDA
    # Raised when NVRTC fails to compile a CUDA source. It keeps the compiler
    # log (used as the message), the source, the program name and the options
    # so the failure can be reported in full with #dump.
    class CompileError < StandardError
      # @param msg [String] NVRTC program log; returned by #message and #to_s.
      # @param source [String] CUDA source that failed to compile.
      # @param name [String] program name that was given to NVRTC.
      # @param options [Array<String>] compiler options that were used.
      def initialize(msg, source, name, options)
        # Let StandardError know the message too, so default exception
        # machinery (inspect, full_message, cause chains) sees it.
        super(msg)
        @msg = msg
        @source = source
        @name = name
        @options = options
      end

      def message
        @msg
      end

      def to_s
        @msg
      end

      # Writes a report (message, name, options and the source with
      # zero-padded line numbers) to +io+, then flushes +io+.
      #
      # @param io [IO, StringIO] any object responding to #puts and #flush.
      def dump(io)
        lines = @source.split("\n")
        # Width of the largest line number. Using the decimal string length
        # (instead of log10) keeps this at 1 for an empty source, where
        # Math.log10(0) would be -Infinity and #floor would raise.
        digits = lines.size.to_s.length
        linum_fmt = "%0#{digits}d "
        io.puts("NVRTC compilation error: #{@msg}")
        io.puts("-----")
        io.puts("Name: #{@name}")
        io.puts("Options: #{@options.join(' ')}")
        io.puts("CUDA source:")
        lines.each_with_index do |line, i|
          # Kernel#format, not String#sprintf (which does not exist publicly).
          io.puts(format(linum_fmt, i + 1) + line.rstrip)
        end
        io.puts("-----")
        io.flush
      end
    end
  end
end
@@ -0,0 +1,161 @@
1
+ require 'tmpdir'
2
+ require 'tempfile'
3
+ require 'fileutils'
4
+ require 'digest/md5'
5
+ require_relative '../cuda'
6
+
7
module Cumo
  module CUDA
    # Compiles CUDA source to PTX with NVRTC, links the PTX into a cubin and
    # keeps linked cubins in an on-disk cache keyed by an MD5 digest.
    class Compiler
      VALID_KERNEL_NAME = /\A[a-zA-Z_][a-zA-Z_0-9]*\z/
      DEFAULT_CACHE_DIR = File.expand_path('~/.cumo/kernel_cache')

      # Memoizes, per [arch, options, NVRTC version], the PTX NVRTC emits for
      # an empty source; it is folded into the cache key so that a change of
      # the compiler's output invalidates cached cubins.
      @@empty_file_preprocess_cache ||= {}

      # @param name [String]
      # @return [Boolean] true when +name+ matches VALID_KERNEL_NAME.
      def self.valid_kernel_name?(name)
        VALID_KERNEL_NAME.match?(name)
      end

      # Compiles +source+ to PTX with NVRTC.
      #
      # @param source [String] CUDA source code.
      # @param options [Array<String>] extra NVRTC options (not mutated).
      # @param arch [String, nil] target such as "compute_70"; defaults to the
      #   compute capability of the current device.
      # @return the PTX returned by NVRTCProgram#compile.
      # @raise [CompileError] when NVRTC rejects the source; the report is
      #   written to $stderr first if CUMO_DUMP_CUDA_SOURCE_ON_ERROR=1.
      def compile_using_nvrtc(source, options: [], arch: nil)
        arch ||= get_arch
        options += ["-arch=#{arch}"]

        Dir.mktmpdir do |root_dir|
          path = File.join(root_dir, 'kern')
          cu_path = "#{path}.cu"

          File.open(cu_path, 'w') do |cu_file|
            cu_file.write(source)
          end

          prog = NVRTCProgram.new(source, name: cu_path)
          begin
            ptx = prog.compile(options: options)
          rescue CompileError => e
            if get_bool_env_variable('CUMO_DUMP_CUDA_SOURCE_ON_ERROR', false)
              e.dump($stderr)
            end
            raise e
          ensure
            prog.destroy
          end
          return ptx
        end
      end

      # Compiles +source+ and loads it as a Module, reusing
      # <cache_dir>/<md5>_2.cubin when a valid cached cubin exists.
      #
      # @param source [String] CUDA source code.
      # @param options [Array<String>] extra NVRTC options; -ftz=true is added.
      # @param arch [String, nil] target architecture; defaults to the current
      #   device's compute capability.
      # @param cache_dir [String, nil] defaults to $CUMO_CACHE_DIR or
      #   DEFAULT_CACHE_DIR.
      # @param extra_source [String, nil] not compiled; only mixed into the
      #   cache key.
      # @return [Module] the loaded kernel module.
      def compile_with_cache(source, options: [], arch: nil, cache_dir: nil, extra_source: nil)
        # NVRTC does not use extra_source. extra_source is used for cache key.
        cache_dir ||= get_cache_dir
        arch ||= get_arch

        options += ['-ftz=true']

        env = [arch, options, get_nvrtc_version]
        base = @@empty_file_preprocess_cache[env]
        if base.nil?
          # This is checking of NVRTC compiler internal version
          base = preprocess('', options, arch)
          @@empty_file_preprocess_cache[env] = base
        end
        key_src = "#{env} #{base} #{source} #{extra_source}"

        key_src.encode!('utf-8')
        digest = Digest::MD5.hexdigest(key_src)
        name = "#{digest}_2.cubin"

        unless Dir.exist?(cache_dir)
          FileUtils.mkdir_p(cache_dir)
        end

        # TODO(sonots): thread-safe?
        path = File.join(cache_dir, name)
        cubin = load_cache(path)
        if cubin
          mod = Module.new
          mod.load(cubin)
          return mod
        end

        ptx = compile_using_nvrtc(source, options: options, arch: arch)
        cubin = nil
        cubin_hash = nil
        LinkState.new do |ls|
          ls.add_ptr_data(ptx, 'cumo.ptx')
          cubin = ls.complete()
          cubin_hash = Digest::MD5.hexdigest(cubin)
        end

        save_cache(path, cubin_hash, cubin)

        # Save .cu source file along with .cubin
        if get_bool_env_variable('CUMO_CACHE_SAVE_CUDA_SOURCE', false)
          File.open("#{path}.cu", 'w') do |f|
            f.write(source)
          end
        end

        mod = Module.new
        mod.load(cubin)
        return mod
      end

      private

      # Writes "<32-char hex MD5><cubin bytes>" to +path+.
      #
      # The data is written to a temporary file in the same directory, closed
      # (which flushes it) and then renamed over +path+. Keeping the temp file
      # on the destination's filesystem lets File.rename succeed (rename(2)
      # cannot cross devices) and, on POSIX, replace +path+ atomically, so
      # readers never observe a partially written entry.
      def save_cache(path, cubin_hash, cubin)
        tf = Tempfile.create(['cumo', '.cubin.tmp'], File.dirname(path))
        tmp_path = tf.path
        begin
          tf.binmode
          tf.write(cubin_hash)
          tf.write(cubin)
          tf.close
          File.rename(tmp_path, path)
        rescue StandardError
          # Do not leave a stray temp file behind on failure.
          tf.close unless tf.closed?
          File.unlink(tmp_path) if File.exist?(tmp_path)
          raise
        end
      end

      # Reads a cache entry written by #save_cache.
      #
      # @return [String, nil] the cubin bytes, or nil when the file is missing,
      #   shorter than the 32-byte hash header, or fails the MD5 check.
      def load_cache(path)
        return nil unless File.exist?(path)
        File.open(path, 'rb') do |file|
          data = file.read
          return nil unless data.size >= 32
          hash = data[0...32]
          cubin = data[32..-1]
          cubin_hash = Digest::MD5.hexdigest(cubin)
          return nil unless hash == cubin_hash
          return cubin
        end
        nil
      end

      def get_cache_dir
        ENV.fetch('CUMO_CACHE_DIR', DEFAULT_CACHE_DIR)
      end

      def get_nvrtc_version
        @@nvrtc_version ||= NVRTC.nvrtcVersion
      end

      # @return [String] "compute_<major><minor>" for the current device.
      def get_arch
        cc = Device.new.compute_capability
        "compute_#{cc}"
      end

      # @return +default+ when the variable is unset or empty; otherwise true
      #   only if its value parses as the integer 1.
      def get_bool_env_variable(name, default)
        val = ENV[name]
        return default if val.nil? or val.size == 0
        Integer(val) == 1 rescue false
      end

      # Compiles +source+ with NVRTC for +arch+ and returns the result; used
      # with an empty source to fingerprint the compiler for the cache key.
      def preprocess(source, options, arch)
        options += ["-arch=#{arch}"]

        prog = NVRTCProgram.new(source, name: '')
        begin
          result = prog.compile(options: options)
          return result
        rescue CompileError => e
          if get_bool_env_variable('CUMO_DUMP_CUDA_SOURCE_ON_ERROR', false)
            e.dump($stderr)
          end
          raise e
        ensure
          prog.destroy
        end
      end
    end
  end
end
@@ -0,0 +1,47 @@
1
+ require_relative '../cuda'
2
+
3
module Cumo
  module CUDA
    # A CUDA device identified by its ordinal, with helpers to make it current.
    class Device
      # cudaDeviceAttr values queried by #compute_capability
      # (cudaDevAttrComputeCapabilityMajor / cudaDevAttrComputeCapabilityMinor).
      ATTR_COMPUTE_CAPABILITY_MAJOR = 75
      ATTR_COMPUTE_CAPABILITY_MINOR = 76

      attr_reader :id

      # @return [Integer] ordinal of the device that is currently selected.
      def self.get_current_id
        Runtime.cudaGetDevice
      end

      class << self
        # Misspelled original name, kept so existing callers keep working.
        alias_method :get_currend_id, :get_current_id
      end

      # @param device_id [Integer, nil] device ordinal; defaults to the device
      #   that is current when the object is created.
      def initialize(device_id = nil)
        if device_id
          @id = device_id
        else
          @id = Runtime.cudaGetDevice
        end
        @_device_stack = []
      end

      # Makes this device current.
      def use
        Runtime.cudaSetDevice(@id)
      end

      # Makes this device current for the duration of the block and restores
      # the previously current device afterwards, even if the block raises.
      #
      # @return the block's value.
      # @raise [RuntimeError] if no block is given.
      def with
        raise 'Device#with requires a block' unless block_given?
        prev_id = Runtime.cudaGetDevice
        @_device_stack << prev_id
        begin
          # Switch only when a different device is current.
          Runtime.cudaSetDevice(@id) if prev_id != @id
          yield
        ensure
          prev_id = @_device_stack.pop
          Runtime.cudaSetDevice(prev_id)
        end
      end

      # Blocks until the device has completed all preceding work.
      def synchronize
        Runtime.cudaDeviceSynchronize
      end

      # @return [String] "<major><minor>", e.g. "75" for compute capability 7.5.
      def compute_capability
        major = Runtime.cudaDeviceGetAttributes(ATTR_COMPUTE_CAPABILITY_MAJOR, @id)
        minor = Runtime.cudaDeviceGetAttributes(ATTR_COMPUTE_CAPABILITY_MINOR, @id)
        "#{major}#{minor}"
      end
    end
  end
end
@@ -0,0 +1,31 @@
1
+ require_relative '../cuda'
2
+
3
module Cumo
  module CUDA
    # Owns a CUDA driver link state used to turn PTX into a cubin.
    # When created with a block, the state is yielded and then destroyed,
    # even if the block raises.
    class LinkState
      def initialize
        @ptr = Driver.cuLinkCreate
        return unless block_given?

        begin
          yield(self)
        ensure
          destroy
        end
      end

      # Releases the link state; further calls do nothing.
      def destroy
        if @ptr
          Driver.cuLinkDestroy(@ptr)
          @ptr = nil
        end
      end

      # Adds PTX +data+ (reported under +name+) to the link.
      def add_ptr_data(data, name)
        Driver.cuLinkAddData(@ptr, Driver::CU_JIT_INPUT_PTX, data, name)
      end

      # Finishes linking and returns what Driver.cuLinkComplete returns.
      def complete
        Driver.cuLinkComplete(@ptr)
      end
    end
  end
end
@@ -0,0 +1,40 @@
1
+ require_relative '../cuda'
2
+
3
module Cumo
  module CUDA
    # Owns a CUDA driver module handle loaded from a file or an in-memory
    # image. When created with a block, the module is yielded and then
    # unloaded, even if the block raises.
    class Module
      def initialize
        @ptr = nil
        return unless block_given?

        begin
          yield(self)
        ensure
          unload
        end
      end

      # Unloads the module if one is loaded; further calls do nothing.
      def unload
        if @ptr
          Driver.cuModuleUnload(@ptr)
          @ptr = nil
        end
      end

      # Loads a module from the file +fname+.
      def load_file(fname)
        @ptr = Driver.cuModuleLoad(fname)
      end

      # Loads a module from an in-memory image such as a cubin.
      def load(cubin)
        @ptr = Driver.cuModuleLoadData(cubin)
      end

      # Looks up the global variable +name+ in the loaded module.
      def get_global_var(name)
        Driver.cuModuleGetGlobal(@ptr, name)
      end

      # Not implemented yet; always returns nil.
      def get_function(name)
        # Function(name)
        nil
      end
    end
  end
end
40
+
@@ -0,0 +1,27 @@
1
+ require_relative '../cuda'
2
+ require_relative 'compile_error'
3
+
4
module Cumo
  module CUDA
    # Wraps an NVRTC program handle created from CUDA source.
    class NVRTCProgram
      # @param src [String] CUDA source; should be UTF-8.
      # @param name [String] program name given to NVRTC; should be UTF-8.
      # @param headers [Array] header sources passed to nvrtcCreateProgram.
      # @param include_names [Array] include names corresponding to +headers+.
      def initialize(src, name: "default_program", headers: [], include_names: [])
        @ptr = nil
        @src = src # should be UTF-8
        @name = name # should be UTF-8
        @ptr = NVRTC.nvrtcCreateProgram(src, name, headers, include_names)
      end

      # Destroys the program. The handle is cleared so that calling this
      # again is a no-op instead of destroying the same handle twice
      # (matching LinkState#destroy and Module#unload).
      def destroy
        return unless @ptr
        NVRTC.nvrtcDestroyProgram(@ptr)
        @ptr = nil
      end

      # Compiles the program.
      #
      # @param options [Array<String>] NVRTC compiler options.
      # @return the PTX returned by NVRTC.nvrtcGetPTX.
      # @raise [CompileError] carrying the NVRTC program log when compilation
      #   fails with NVRTCError.
      def compile(options: [])
        begin
          NVRTC.nvrtcCompileProgram(@ptr, options)
          return NVRTC.nvrtcGetPTX(@ptr)
        rescue NVRTCError => e
          log = NVRTC.nvrtcGetProgramLog(@ptr)
          raise CompileError.new(log, @src, @name, options)
        end
      end
    end
  end
end
@@ -0,0 +1,12 @@
1
+ require 'cumo'
2
+
3
# Compatibility layer that mirrors the numo-linalg Blas module API.
module Cumo
  module Blas
    # Cumo::Blas.gemm(a, *args, **kwargs) forwards to a.gemm(*args, **kwargs).
    def self.gemm(a, *args, **kwargs)
      a.gemm(*args, **kwargs)
    end
  end
end
@@ -0,0 +1,2 @@
1
+ # This file is for compatibility with require 'numo/narray'
2
+ require_relative '../cumo'
@@ -0,0 +1,1278 @@
1
+ module Cumo
2
+ class NArray
3
+
4
# Builds an array of the receiver's class and shape without initializing it.
def new_narray
  self.class.new(*shape)
end

# Builds a zero-filled array of the receiver's class and shape.
def new_zeros
  self.class.zeros(*shape)
end

# Builds an array of ones with the receiver's class and shape.
def new_ones
  self.class.ones(*shape)
end

# Builds an array of the receiver's class and shape with every element set
# to +value+.
def new_fill(value)
  blank = self.class.new(*shape)
  blank.fill(value)
end

# Converts angles in radians to degrees (multiplies by 180/pi).
def rad2deg
  self * (180 / Math::PI)
end

# Converts angles in degrees to radians (multiplies by pi/180).
def deg2rad
  self * (Math::PI / 180)
end

# Reverses the order along axis 1 (each row left/right);
# equivalent to `a[true, (-1..0).step(-1), ...]`.
def fliplr
  reverse(1)
end

# Reverses the order along axis 0 (each column up/down);
# equivalent to `a[(-1..0).step(-1), ...]`.
def flipud
  reverse(0)
end
45
+
46
# Multi-dimensional array indexing.
# Same as [] for one-dimensional NArray.
# Similar to numpy's tuple indexing, i.e., `a[[1,2,..],[3,4,..]]`
# (This method will be rewritten in C)
#
# Takes one index list per dimension; all lists must have the same length
# and element k of the result is self[indices[0][k], indices[1][k], ...].
# Negative indices count from the end of their dimension.
#
# @param indices [Array] one one-dimensional index list per dimension
#   (each cast to Int64; the caller's arrays are not modified).
# @return [Cumo::NArray] one-dimensional view of self.
# @raise [DimensionError] if the number of index lists differs from ndim,
#   or an index list is not one-dimensional.
# @raise [IndexError] if an index is out of range after wrapping negatives.
# @raise [ShapeError] if the index lists have different lengths.
# @example
#   p x = Cumo::DFloat.new(3,3,3).seq
#   # Cumo::DFloat#shape=[3,3,3]
#   # [[[0, 1, 2],
#   #   [3, 4, 5],
#   #   [6, 7, 8]],
#   #  [[9, 10, 11],
#   #   [12, 13, 14],
#   #   [15, 16, 17]],
#   #  [[18, 19, 20],
#   #   [21, 22, 23],
#   #   [24, 25, 26]]]
#
#   p x.at([0,1,2],[0,1,2],[-1,-2,-3])
#   # Cumo::DFloat(view)#shape=[3]
#   # [2, 13, 24]
def at(*indices)
  if indices.size != ndim
    raise DimensionError, "argument length does not match dimension size"
  end
  idx = nil
  stride = 1
  # Accumulate row-major flat offsets, from the last dimension to the first.
  (indices.size-1).downto(0) do |i|
    ix = Int64.cast(indices[i])
    # Int64.cast may return its argument unchanged; copy it so the negative
    # index wrap below does not write into the caller's index array.
    ix = ix.copy if ix.equal?(indices[i])
    if ix.ndim != 1
      raise DimensionError, "index array is not one-dimensional"
    end
    ix[ix < 0] += shape[i]
    # An index is out of range if it is still negative OR past the end.
    # (`&` here could never be true, so bad indices went unreported.)
    if ((ix < 0) | (ix >= shape[i])).any?
      raise IndexError, "index array is out of range"
    end
    if idx
      if idx.size != ix.size
        raise ShapeError, "index array sizes mismatch"
      end
      idx += ix * stride
      stride *= shape[i]
    else
      idx = ix
      stride = shape[i]
    end
  end
  self[idx]
end
95
+
96
# Rotates the array by k quarter turns in the plane spanned by +axes+.
#
# @param k [Integer] number of 90-degree rotations; only k % 4 matters.
# @param axes [Array<Integer>] the two axes defining the rotation plane.
# @return [Cumo::NArray] a view of self.
# @example
#   p a = Cumo::Int32.new(2,2).seq
#   # Cumo::Int32#shape=[2,2]
#   # [[0, 1],
#   #  [2, 3]]
#
#   p a.rot90
#   # Cumo::Int32(view)#shape=[2,2]
#   # [[1, 3],
#   #  [0, 2]]
#
#   p a.rot90(2)
#   # Cumo::Int32(view)#shape=[2,2]
#   # [[3, 2],
#   #  [1, 0]]
#
#   p a.rot90(3)
#   # Cumo::Int32(view)#shape=[2,2]
#   # [[2, 0],
#   #  [3, 1]]
def rot90(k=1,axes=[0,1])
  turns = k % 4
  if turns == 0
    view
  elsif turns == 1
    swapaxes(*axes).reverse(axes[0])
  elsif turns == 2
    reverse(*axes)
  elsif turns == 3
    swapaxes(*axes).reverse(axes[1])
  end
end
129
+
130
# Returns the only element as an Integer.
# @raise [TypeError] unless the array holds exactly one element.
def to_i
  # TODO: multi-element arrays might convert to an integer array instead.
  raise TypeError, "can't convert #{self.class} into Integer" unless size==1
  self.extract_cpu.to_i
end

# Returns the only element as a Float.
# @raise [TypeError] unless the array holds exactly one element.
def to_f
  # TODO: multi-element arrays might convert to DFloat instead.
  raise TypeError, "can't convert #{self.class} into Float" unless size==1
  self.extract_cpu.to_f
end

# Returns the only element as a Complex.
# @raise [TypeError] unless the array holds exactly one element.
def to_c
  # TODO: multi-element arrays might convert to DComplex instead.
  raise TypeError, "can't convert #{self.class} into Complex" unless size==1
  Complex(self.extract_cpu)
end
156
+
157
# Returns +a+ unchanged when it is already an NArray; otherwise converts it
# with the NArray class chosen by NArray.array_type.
def self.cast(a)
  return a if a.kind_of?(NArray)
  NArray.array_type(a).cast(a)
end

# Converts +a+ to an array of the receiver class with at least one dimension:
# a 0-dimensional NArray gets a new leading axis, numerics and ranges go
# through the receiver's [] constructor, and anything else through cast.
def self.asarray(a)
  if NArray === a
    return a.ndim == 0 ? a[:new] : a
  end
  return self[a] if Numeric === a || Range === a
  cast(a)
end
172
+
173
# parse matrix like matlab, octave
#
# Blocks separated by blank lines (split3d) become 2-D slices, rows are
# separated by newlines or ';' (split2d), and items within a row by
# whitespace (split1d). With a single block the result is 2-D; otherwise
# the blocks are stacked and the result has one more dimension.
#
# SECURITY NOTE(review): every item is passed to Kernel#eval in this method's
# binding, so +str+ can execute arbitrary Ruby code (and can read the local
# variables below). Never call this on untrusted input.
#
# The capture group in split3d makes String#split also return the captured
# separators; they contain only whitespace, so presumably they yield no rows
# and are dropped by the `!b.empty?` check.
#
# @param str [String] matrix text.
# @return [Cumo::NArray] the parsed values, converted with self.cast.
# @example
#   a = Cumo::DFloat.parse %[
#   2 -3 5
#   4 9 7
#   2 -1 6
#   ]
#   # => Cumo::DFloat#shape=[3,3]
#   # [[2, -3, 5],
#   #  [4, 9, 7],
#   #  [2, -1, 6]]
def self.parse(str, split1d:/\s+/, split2d:/;?$|;/,
               split3d:/\s*\n(\s*\n)+/m)
  # a: list of 2-D blocks, b: rows of the current block, c: items of a row.
  a = []
  str.split(split3d).each do |block|
    b = []
    #print "b"; p block
    block.split(split2d).each do |line|
      #p line
      line.strip!
      if !line.empty?
        c = []
        line.split(split1d).each do |item|
          c << eval(item.strip) if !item.empty?
        end
        b << c if !c.empty?
      end
    end
    a << b if !b.empty?
  end
  if a.size==1
    self.cast(a[0])
  else
    self.cast(a)
  end
end
210
+
211
# Appends values to the end of the array, returning a new array.
#
# @param other [Object] values to append; converted with self.class.cast.
# @param axis [Integer, nil] when given, +other+ is concatenated along this
#   axis and must have the same ndim as self; when nil, both arrays are
#   laid out flat in a new one-dimensional array.
# @return [Cumo::NArray]
# @raise [DimensionError] when +axis+ is given and the ndims differ.
# @example
#   a = Cumo::DFloat[1, 2, 3]
#   p a.append([[4, 5, 6], [7, 8, 9]])
#   # Cumo::DFloat#shape=[9]
#   # [1, 2, 3, 4, 5, 6, 7, 8, 9]
#
#   a = Cumo::DFloat[[1, 2, 3]]
#   p a.append([[4, 5, 6], [7, 8, 9]],axis:0)
#   # Cumo::DFloat#shape=[3,3]
#   # [[1, 2, 3],
#   #  [4, 5, 6],
#   #  [7, 8, 9]]
#
#   a = Cumo::DFloat[[1, 2, 3], [4, 5, 6]]
#   p a.append([7, 8, 9], axis:0)
#   # in `append': dimension mismatch (Cumo::NArray::DimensionError)
def append(other,axis:nil)
  other = self.class.cast(other)
  unless axis
    joined = self.class.zeros(size + other.size)
    joined[0...size] = self[true]
    joined[size..-1] = other[true]
    return joined
  end
  raise DimensionError, "dimension mismatch" if ndim != other.ndim
  concatenate(other,axis:axis)
end
243
+
244
# Returns a new array with the sub-arrays at +indice+ along +axis+ removed.
# If axis is not given, indice is applied to the flattened array.
#
# @param indice [Integer, Array, Range, Enumerator] positions to delete.
# @param axis [Integer, nil] axis to delete along.
# @return [Cumo::NArray] a copy (not a view) of the remaining elements.
# @example
#   a = Cumo::DFloat[[1,2,3,4], [5,6,7,8], [9,10,11,12]]
#   p a.delete(1,0)
#   # Cumo::DFloat#shape=[2,4]
#   # [[1, 2, 3, 4],
#   #  [9, 10, 11, 12]]
#
#   p a.delete((0..-1).step(2),1)
#   # Cumo::DFloat#shape=[3,2]
#   # [[2, 4],
#   #  [6, 8],
#   #  [10, 12]]
#
#   p a.delete([1,3,5])
#   # Cumo::DFloat#shape=[9]
#   # [1, 3, 5, 7, 8, 9, 10, 11, 12]
def delete(indice,axis=nil)
  unless axis
    keep = Bit.ones(size)
    keep[indice] = 0
    return self[keep.where].copy
  end
  # Mark the positions to keep along +axis+ and select them with a
  # multi-dimensional index that takes every element on the other axes.
  keep = Bit.ones(shape[axis])
  keep[indice] = 0
  selector = [true] * ndim
  selector[axis] = keep.where
  self[*selector].copy
end
277
+
278
# Insert values along the axis before the indices.
#
# @param indice [Integer, Array, Range] position(s) before which +values+ are
#   inserted (converted with Int64.asarray); the examples below use all three.
# @param values [Object] values to insert, converted with self.class.asarray.
# @param axis [Integer, nil] axis to insert along; when nil the receiver is
#   flattened and the result is one-dimensional.
# @return [Cumo::NArray] a new array of the receiver's class.
# @example
#   p a = Cumo::Int32[[1, 1], [2, 2], [3, 3]]
#   # Cumo::Int32#shape=[3,2]
#   # [[1, 1],
#   #  [2, 2],
#   #  [3, 3]]
#
#   p a.insert(1,5)
#   # Cumo::Int32#shape=[7]
#   # [1, 5, 1, 2, 2, 3, 3]
#
#   p a.insert(1, 5, axis:1)
#   # Cumo::Int32#shape=[3,3]
#   # [[1, 5, 1],
#   #  [2, 5, 2],
#   #  [3, 5, 3]]
#
#   p a.insert([1], [[11],[12],[13]], axis:1)
#   # Cumo::Int32#shape=[3,3]
#   # [[1, 11, 1],
#   #  [2, 12, 2],
#   #  [3, 13, 3]]
#
#   p a.insert(1, [11, 12, 13], axis:1)
#   # Cumo::Int32#shape=[3,3]
#   # [[1, 11, 1],
#   #  [2, 12, 2],
#   #  [3, 13, 3]]
#
#   p a.insert([1], [11, 12, 13], axis:1)
#   # Cumo::Int32#shape=[3,5]
#   # [[1, 11, 12, 13, 1],
#   #  [2, 11, 12, 13, 2],
#   #  [3, 11, 12, 13, 3]]
#
#   p b = a.flatten
#   # Cumo::Int32(view)#shape=[6]
#   # [1, 1, 2, 2, 3, 3]
#
#   p b.insert(2,[15,16])
#   # Cumo::Int32#shape=[8]
#   # [1, 1, 15, 16, 2, 2, 3, 3]
#
#   p b.insert([2,2],[15,16])
#   # Cumo::Int32#shape=[8]
#   # [1, 1, 15, 16, 2, 2, 3, 3]
#
#   p b.insert([2,1],[15,16])
#   # Cumo::Int32#shape=[8]
#   # [1, 16, 1, 15, 2, 2, 3, 3]
#
#   p b.insert([2,0,1],[15,16,17])
#   # Cumo::Int32#shape=[9]
#   # [16, 1, 17, 1, 15, 2, 2, 3, 3]
#
#   p b.insert(2..3, [15, 16])
#   # Cumo::Int32#shape=[8]
#   # [1, 1, 15, 2, 16, 2, 3, 3]
#
#   p b.insert(2, [7.13, 0.5])
#   # Cumo::Int32#shape=[8]
#   # [1, 1, 7, 0, 2, 2, 3, 3]
#
#   p x = Cumo::DFloat.new(2,4).seq
#   # Cumo::DFloat#shape=[2,4]
#   # [[0, 1, 2, 3],
#   #  [4, 5, 6, 7]]
#
#   p x.insert([1,3],999,axis:1)
#   # Cumo::DFloat#shape=[2,6]
#   # [[0, 999, 1, 2, 999, 3],
#   #  [4, 999, 5, 6, 999, 7]]
def insert(indice,values,axis:nil)
  if axis
    values = self.class.asarray(values)
    nd = values.ndim
    # Index that prepends (ndim - nd) new axes so values gets ndim dimensions.
    midx = [:new]*(ndim-nd) + [true]*nd
    case indice
    when Numeric
      # Scalar position: turn one leading :new back into true and put :new at
      # +axis+, presumably so the inserted slab has length 1 along +axis+.
      # NOTE(review): when values.ndim == ndim, -nd-1 is out of range for
      # midx and Array#[]= raises IndexError — confirm this is intended.
      midx[-nd-1] = true
      midx[axis] = :new
    end
    values = values[*midx]
  else
    values = self.class.asarray(values).flatten
  end
  idx = Int64.asarray(indice)
  nidx = idx.size
  if nidx == 1
    # One insertion point: expand it to consecutive output positions, one per
    # value along the insertion axis (or per flattened value).
    nidx = values.shape[axis||0]
    idx = idx + Int64.new(nidx).seq
  else
    # Several insertion points: shift each by its rank in sorted order so the
    # indices refer to positions in the enlarged output.
    sidx = idx.sort_index
    idx[sidx] += Int64.new(nidx).seq
  end
  if axis
    # bit is 1 where the original elements go and 0 at the insertion slots.
    bit = Bit.ones(shape[axis]+nidx)
    bit[idx] = 0
    # Assumes #shape returns a fresh Array, so mutating new_shape does not
    # affect self — TODO confirm.
    new_shape = shape
    new_shape[axis] += nidx
    a = self.class.zeros(new_shape)
    mdidx = [true]*ndim
    mdidx[axis] = bit.where
    a[*mdidx] = self
    mdidx[axis] = idx
    a[*mdidx] = values
  else
    # Flat case: same scheme on the flattened receiver.
    bit = Bit.ones(size+nidx)
    bit[idx] = 0
    a = self.class.zeros(size+nidx)
    a[bit.where] = self.flatten
    a[idx] = values
  end
  return a
end
392
+
393
+ class << self
394
+ # @example
395
+ # p a = Cumo::DFloat[[1, 2], [3, 4]]
396
+ # # Cumo::DFloat#shape=[2,2]
397
+ # # [[1, 2],
398
+ # # [3, 4]]
399
+ #
400
+ # p b = Cumo::DFloat[[5, 6]]
401
+ # # Cumo::DFloat#shape=[1,2]
402
+ # # [[5, 6]]
403
+ #
404
+ # p Cumo::NArray.concatenate([a,b],axis:0)
405
+ # # Cumo::DFloat#shape=[3,2]
406
+ # # [[1, 2],
407
+ # # [3, 4],
408
+ # # [5, 6]]
409
+ #
410
+ # p Cumo::NArray.concatenate([a,b.transpose], axis:1)
411
+ # # Cumo::DFloat#shape=[2,3]
412
+ # # [[1, 2, 5],
413
+ # # [3, 4, 6]]
414
+
415
+ def concatenate(arrays,axis:0)
416
+ klass = (self==NArray) ? NArray.array_type(arrays) : self
417
+ nd = 0
418
+ arrays = arrays.map do |a|
419
+ case a
420
+ when NArray
421
+ # ok
422
+ when Numeric
423
+ a = klass[a]
424
+ when Array
425
+ a = klass.cast(a)
426
+ else
427
+ raise TypeError,"not Cumo::NArray: #{a.inspect[0..48]}"
428
+ end
429
+ if a.ndim > nd
430
+ nd = a.ndim
431
+ end
432
+ a
433
+ end
434
+ if axis < 0
435
+ axis += nd
436
+ end
437
+ if axis < 0 || axis >= nd
438
+ raise ArgumentError,"axis is out of range"
439
+ end
440
+ new_shape = nil
441
+ sum_size = 0
442
+ arrays.each do |a|
443
+ a_shape = a.shape
444
+ if nd != a_shape.size
445
+ a_shape = [1]*(nd-a_shape.size) + a_shape
446
+ end
447
+ sum_size += a_shape.delete_at(axis)
448
+ if new_shape
449
+ if new_shape != a_shape
450
+ raise ShapeError,"shape mismatch"
451
+ end
452
+ else
453
+ new_shape = a_shape
454
+ end
455
+ end
456
+ new_shape.insert(axis,sum_size)
457
+ result = klass.zeros(*new_shape)
458
+ lst = 0
459
+ refs = [true] * nd
460
+ arrays.each do |a|
461
+ fst = lst
462
+ lst = fst + (a.shape[axis-nd]||1)
463
+ refs[axis] = fst...lst
464
+ result[*refs] = a
465
+ end
466
+ result
467
+ end
468
+
469
+ # Stack arrays vertically (row wise).
470
+ # @example
471
+ # a = Cumo::Int32[1,2,3]
472
+ # b = Cumo::Int32[2,3,4]
473
+ # p Cumo::NArray.vstack([a,b])
474
+ # # Cumo::Int32#shape=[2,3]
475
+ # # [[1, 2, 3],
476
+ # # [2, 3, 4]]
477
+ #
478
+ # a = Cumo::Int32[[1],[2],[3]]
479
+ # b = Cumo::Int32[[2],[3],[4]]
480
+ # p Cumo::NArray.vstack([a,b])
481
+ # # Cumo::Int32#shape=[6,1]
482
+ # # [[1],
483
+ # # [2],
484
+ # # [3],
485
+ # # [2],
486
+ # # [3],
487
+ # # [4]]
488
+
489
+ def vstack(arrays)
490
+ arys = arrays.map do |a|
491
+ _atleast_2d(cast(a))
492
+ end
493
+ concatenate(arys,axis:0)
494
+ end
495
+
496
+ # Stack arrays horizontally (column wise).
497
+ # @example
498
+ # a = Cumo::Int32[1,2,3]
499
+ # b = Cumo::Int32[2,3,4]
500
+ # p Cumo::NArray.hstack([a,b])
501
+ # # Cumo::Int32#shape=[6]
502
+ # # [1, 2, 3, 2, 3, 4]
503
+ #
504
+ # a = Cumo::Int32[[1],[2],[3]]
505
+ # b = Cumo::Int32[[2],[3],[4]]
506
+ # p Cumo::NArray.hstack([a,b])
507
+ # # Cumo::Int32#shape=[3,2]
508
+ # # [[1, 2],
509
+ # # [2, 3],
510
+ # # [3, 4]]
511
+
512
+ def hstack(arrays)
513
+ klass = (self==NArray) ? NArray.array_type(arrays) : self
514
+ nd = 0
515
+ arys = arrays.map do |a|
516
+ a = klass.cast(a)
517
+ nd = a.ndim if a.ndim > nd
518
+ a
519
+ end
520
+ dim = (nd >= 2) ? 1 : 0
521
+ concatenate(arys,axis:dim)
522
+ end
523
+
524
+ # Stack arrays in depth wise (along third axis).
525
+ # @example
526
+ # a = Cumo::Int32[1,2,3]
527
+ # b = Cumo::Int32[2,3,4]
528
+ # p Cumo::NArray.dstack([a,b])
529
+ # # Cumo::Int32#shape=[1,3,2]
530
+ # # [[[1, 2],
531
+ # # [2, 3],
532
+ # # [3, 4]]]
533
+ #
534
+ # a = Cumo::Int32[[1],[2],[3]]
535
+ # b = Cumo::Int32[[2],[3],[4]]
536
+ # p Cumo::NArray.dstack([a,b])
537
+ # # Cumo::Int32#shape=[3,1,2]
538
+ # # [[[1, 2]],
539
+ # # [[2, 3]],
540
+ # # [[3, 4]]]
541
+
542
+ def dstack(arrays)
543
+ arys = arrays.map do |a|
544
+ _atleast_3d(cast(a))
545
+ end
546
+ concatenate(arys,axis:2)
547
+ end
548
+
549
+ # Stack 1-d arrays into columns of a 2-d array.
550
+ # @example
551
+ # x = Cumo::Int32[1,2,3]
552
+ # y = Cumo::Int32[2,3,4]
553
+ # p Cumo::NArray.column_stack([x,y])
554
+ # # Cumo::Int32#shape=[3,2]
555
+ # # [[1, 2],
556
+ # # [2, 3],
557
+ # # [3, 4]]
558
+
559
+ def column_stack(arrays)
560
+ arys = arrays.map do |a|
561
+ a = cast(a)
562
+ case a.ndim
563
+ when 0; a[:new,:new]
564
+ when 1; a[true,:new]
565
+ else; a
566
+ end
567
+ end
568
+ concatenate(arys,axis:1)
569
+ end
570
+
571
+ private
572
+ # Return an narray with at least two dimension.
573
+ def _atleast_2d(a)
574
+ case a.ndim
575
+ when 0; a[:new,:new]
576
+ when 1; a[:new,true]
577
+ else; a
578
+ end
579
+ end
580
+
581
+ # Return an narray with at least three dimension.
582
+ def _atleast_3d(a)
583
+ case a.ndim
584
+ when 0; a[:new,:new,:new]
585
+ when 1; a[:new,true,:new]
586
+ when 2; a[true,true,:new]
587
+ else; a
588
+ end
589
+ end
590
+
591
+ end # class << self
592
+
593
+ # @example
594
+ # p a = Cumo::DFloat[[1, 2], [3, 4]]
595
+ # # Cumo::DFloat#shape=[2,2]
596
+ # # [[1, 2],
597
+ # # [3, 4]]
598
+ #
599
+ # p b = Cumo::DFloat[[5, 6]]
600
+ # # Cumo::DFloat#shape=[1,2]
601
+ # # [[5, 6]]
602
+ #
603
+ # p a.concatenate(b,axis:0)
604
+ # # Cumo::DFloat#shape=[3,2]
605
+ # # [[1, 2],
606
+ # # [3, 4],
607
+ # # [5, 6]]
608
+ #
609
+ # p a.concatenate(b.transpose, axis:1)
610
+ # # Cumo::DFloat#shape=[2,3]
611
+ # # [[1, 2, 5],
612
+ # # [3, 4, 6]]
613
+
614
+ def concatenate(*arrays,axis:0)
615
+ axis = check_axis(axis)
616
+ self_shape = shape
617
+ self_shape.delete_at(axis)
618
+ sum_size = shape[axis]
619
+ arrays.map! do |a|
620
+ case a
621
+ when NArray
622
+ # ok
623
+ when Numeric
624
+ a = self.class.new(1).store(a)
625
+ when Array
626
+ a = self.class.cast(a)
627
+ else
628
+ raise TypeError,"not Cumo::NArray: #{a.inspect[0..48]}"
629
+ end
630
+ if a.ndim > ndim
631
+ raise ShapeError,"dimension mismatch"
632
+ end
633
+ a_shape = a.shape
634
+ sum_size += a_shape.delete_at(axis-ndim) || 1
635
+ if self_shape != a_shape
636
+ raise ShapeError,"shape mismatch"
637
+ end
638
+ a
639
+ end
640
+ self_shape.insert(axis,sum_size)
641
+ result = self.class.zeros(*self_shape)
642
+ lst = shape[axis]
643
+ refs = [true] * ndim
644
+ refs[axis] = 0...lst
645
+ result[*refs] = self
646
+ arrays.each do |a|
647
+ fst = lst
648
+ lst = fst + (a.shape[axis-ndim] || 1)
649
+ refs[axis] = fst...lst
650
+ result[*refs] = a
651
+ end
652
+ result
653
+ end
654
+
655
+ # @example
656
+ # p x = Cumo::DFloat.new(9).seq
657
+ # # Cumo::DFloat#shape=[9]
658
+ # # [0, 1, 2, 3, 4, 5, 6, 7, 8]
659
+ #
660
+ # pp x.split(3)
661
+ # # [Cumo::DFloat(view)#shape=[3]
662
+ # # [0, 1, 2],
663
+ # # Cumo::DFloat(view)#shape=[3]
664
+ # # [3, 4, 5],
665
+ # # Cumo::DFloat(view)#shape=[3]
666
+ # # [6, 7, 8]]
667
+ #
668
+ # p x = Cumo::DFloat.new(8).seq
669
+ # # Cumo::DFloat#shape=[8]
670
+ # # [0, 1, 2, 3, 4, 5, 6, 7]
671
+ #
672
+ # pp x.split([3, 5, 6, 10])
673
+ # # [Cumo::DFloat(view)#shape=[3]
674
+ # # [0, 1, 2],
675
+ # # Cumo::DFloat(view)#shape=[2]
676
+ # # [3, 4],
677
+ # # Cumo::DFloat(view)#shape=[1]
678
+ # # [5],
679
+ # # Cumo::DFloat(view)#shape=[2]
680
+ # # [6, 7],
681
+ # # Cumo::DFloat(view)#shape=[0][]]
682
+
683
+ def split(indices_or_sections, axis:0)
684
+ axis = check_axis(axis)
685
+ size_axis = shape[axis]
686
+ case indices_or_sections
687
+ when Integer
688
+ div_axis, mod_axis = size_axis.divmod(indices_or_sections)
689
+ refs = [true]*ndim
690
+ beg_idx = 0
691
+ mod_axis.times.map do |i|
692
+ end_idx = beg_idx + div_axis + 1
693
+ refs[axis] = beg_idx ... end_idx
694
+ beg_idx = end_idx
695
+ self[*refs]
696
+ end +
697
+ (indices_or_sections-mod_axis).times.map do |i|
698
+ end_idx = beg_idx + div_axis
699
+ refs[axis] = beg_idx ... end_idx
700
+ beg_idx = end_idx
701
+ self[*refs]
702
+ end
703
+ when NArray
704
+ split(indices_or_sections.to_a,axis:axis)
705
+ when Array
706
+ refs = [true]*ndim
707
+ fst = 0
708
+ (indices_or_sections + [size_axis]).map do |lst|
709
+ lst = size_axis if lst > size_axis
710
+ refs[axis] = (fst < size_axis) ? fst...lst : -1...-1
711
+ fst = lst
712
+ self[*refs]
713
+ end
714
+ else
715
+ raise TypeError,"argument must be Integer or Array"
716
+ end
717
+ end
718
+
719
+ # @example
720
+ # p x = Cumo::DFloat.new(4,4).seq
721
+ # # Cumo::DFloat#shape=[4,4]
722
+ # # [[0, 1, 2, 3],
723
+ # # [4, 5, 6, 7],
724
+ # # [8, 9, 10, 11],
725
+ # # [12, 13, 14, 15]]
726
+ #
727
+ # pp x.hsplit(2)
728
+ # # [Cumo::DFloat(view)#shape=[4,2]
729
+ # # [[0, 1],
730
+ # # [4, 5],
731
+ # # [8, 9],
732
+ # # [12, 13]],
733
+ # # Cumo::DFloat(view)#shape=[4,2]
734
+ # # [[2, 3],
735
+ # # [6, 7],
736
+ # # [10, 11],
737
+ # # [14, 15]]]
738
+ #
739
+ # pp x.hsplit([3, 6])
740
+ # # [Cumo::DFloat(view)#shape=[4,3]
741
+ # # [[0, 1, 2],
742
+ # # [4, 5, 6],
743
+ # # [8, 9, 10],
744
+ # # [12, 13, 14]],
745
+ # # Cumo::DFloat(view)#shape=[4,1]
746
+ # # [[3],
747
+ # # [7],
748
+ # # [11],
749
+ # # [15]],
750
+ # # Cumo::DFloat(view)#shape=[4,0][]]
751
+
752
+ def vsplit(indices_or_sections)
753
+ split(indices_or_sections, axis:0)
754
+ end
755
+
756
+ def hsplit(indices_or_sections)
757
+ split(indices_or_sections, axis:1)
758
+ end
759
+
760
+ def dsplit(indices_or_sections)
761
+ split(indices_or_sections, axis:2)
762
+ end
763
+
764
+ # @example
765
+ # p a = Cumo::NArray[0,1,2]
766
+ # # Cumo::Int32#shape=[3]
767
+ # # [0, 1, 2]
768
+ #
769
+ # p a.tile(2)
770
+ # # Cumo::Int32#shape=[6]
771
+ # # [0, 1, 2, 0, 1, 2]
772
+ #
773
+ # p a.tile(2,2)
774
+ # # Cumo::Int32#shape=[2,6]
775
+ # # [[0, 1, 2, 0, 1, 2],
776
+ # # [0, 1, 2, 0, 1, 2]]
777
+ #
778
+ # p a.tile(2,1,2)
779
+ # # Cumo::Int32#shape=[2,1,6]
780
+ # # [[[0, 1, 2, 0, 1, 2]],
781
+ # # [[0, 1, 2, 0, 1, 2]]]
782
+ #
783
+ # p b = Cumo::NArray[[1, 2], [3, 4]]
784
+ # # Cumo::Int32#shape=[2,2]
785
+ # # [[1, 2],
786
+ # # [3, 4]]
787
+ #
788
+ # p b.tile(2)
789
+ # # Cumo::Int32#shape=[2,4]
790
+ # # [[1, 2, 1, 2],
791
+ # # [3, 4, 3, 4]]
792
+ #
793
+ # p b.tile(2,1)
794
+ # # Cumo::Int32#shape=[4,2]
795
+ # # [[1, 2],
796
+ # # [3, 4],
797
+ # # [1, 2],
798
+ # # [3, 4]]
799
+ #
800
+ # p c = Cumo::NArray[1,2,3,4]
801
+ # # Cumo::Int32#shape=[4]
802
+ # # [1, 2, 3, 4]
803
+ #
804
+ # p c.tile(4,1)
805
+ # # Cumo::Int32#shape=[4,4]
806
+ # # [[1, 2, 3, 4],
807
+ # # [1, 2, 3, 4],
808
+ # # [1, 2, 3, 4],
809
+ # # [1, 2, 3, 4]]
810
+
811
+ def tile(*arg)
812
+ arg.each do |i|
813
+ if !i.kind_of?(Integer) || i<1
814
+ raise ArgumentError,"argument should be positive integer"
815
+ end
816
+ end
817
+ ns = arg.size
818
+ nd = self.ndim
819
+ shp = self.shape
820
+ new_shp = []
821
+ src_shp = []
822
+ res_shp = []
823
+ (nd-ns).times do
824
+ new_shp << 1
825
+ new_shp << (n = shp.shift)
826
+ src_shp << :new
827
+ src_shp << true
828
+ res_shp << n
829
+ end
830
+ (ns-nd).times do
831
+ new_shp << (m = arg.shift)
832
+ new_shp << 1
833
+ src_shp << :new
834
+ src_shp << :new
835
+ res_shp << m
836
+ end
837
+ [nd,ns].min.times do
838
+ new_shp << (m = arg.shift)
839
+ new_shp << (n = shp.shift)
840
+ src_shp << :new
841
+ src_shp << true
842
+ res_shp << n*m
843
+ end
844
+ self.class.new(*new_shp).store(self[*src_shp]).reshape(*res_shp)
845
+ end
846
+
847
+ # @example
848
+ # p Cumo::NArray[3].repeat(4)
849
+ # # Cumo::Int32#shape=[4]
850
+ # # [3, 3, 3, 3]
851
+ #
852
+ # p x = Cumo::NArray[[1,2],[3,4]]
853
+ # # Cumo::Int32#shape=[2,2]
854
+ # # [[1, 2],
855
+ # # [3, 4]]
856
+ #
857
+ # p x.repeat(2)
858
+ # # Cumo::Int32#shape=[8]
859
+ # # [1, 1, 2, 2, 3, 3, 4, 4]
860
+ #
861
+ # p x.repeat(3,axis:1)
862
+ # # Cumo::Int32#shape=[2,6]
863
+ # # [[1, 1, 1, 2, 2, 2],
864
+ # # [3, 3, 3, 4, 4, 4]]
865
+ #
866
+ # p x.repeat([1,2],axis:0)
867
+ # # Cumo::Int32#shape=[3,2]
868
+ # # [[1, 2],
869
+ # # [3, 4],
870
+ # # [3, 4]]
871
+
872
+ def repeat(arg,axis:nil)
873
+ case axis
874
+ when Integer
875
+ axis = check_axis(axis)
876
+ c = self
877
+ when NilClass
878
+ c = self.flatten
879
+ axis = 0
880
+ else
881
+ raise ArgumentError,"invalid axis"
882
+ end
883
+ case arg
884
+ when Integer
885
+ if !arg.kind_of?(Integer) || arg<1
886
+ raise ArgumentError,"argument should be positive integer"
887
+ end
888
+ idx = c.shape[axis].times.map{|i| [i]*arg}.flatten
889
+ else
890
+ arg = arg.to_a
891
+ if arg.size != c.shape[axis]
892
+ raise ArgumentError,"repeat size shoud be equal to size along axis"
893
+ end
894
+ arg.each do |i|
895
+ if !i.kind_of?(Integer) || i<0
896
+ raise ArgumentError,"argument should be non-negative integer"
897
+ end
898
+ end
899
+ idx = arg.each_with_index.map{|a,i| [i]*a}.flatten
900
+ end
901
+ ref = [true] * c.ndim
902
+ ref[axis] = idx
903
+ c[*ref].copy
904
+ end
905
+
906
+ # Calculate the n-th discrete difference along given axis.
907
+ # @example
908
+ # p x = Cumo::DFloat[1, 2, 4, 7, 0]
909
+ # # Cumo::DFloat#shape=[5]
910
+ # # [1, 2, 4, 7, 0]
911
+ #
912
+ # p x.diff
913
+ # # Cumo::DFloat#shape=[4]
914
+ # # [1, 2, 3, -7]
915
+ #
916
+ # p x.diff(2)
917
+ # # Cumo::DFloat#shape=[3]
918
+ # # [1, 1, -10]
919
+ #
920
+ # p x = Cumo::DFloat[[1, 3, 6, 10], [0, 5, 6, 8]]
921
+ # # Cumo::DFloat#shape=[2,4]
922
+ # # [[1, 3, 6, 10],
923
+ # # [0, 5, 6, 8]]
924
+ #
925
+ # p x.diff
926
+ # # Cumo::DFloat#shape=[2,3]
927
+ # # [[2, 3, 4],
928
+ # # [5, 1, 2]]
929
+ #
930
+ # p x.diff(axis:0)
931
+ # # Cumo::DFloat#shape=[1,4]
932
+ # # [[-1, 2, 0, -2]]
933
+
934
+ def diff(n=1,axis:-1)
935
+ axis = check_axis(axis)
936
+ if n < 0 || n >= shape[axis]
937
+ raise ShapeError,"n=#{n} is invalid for shape[#{axis}]=#{shape[axis]}"
938
+ end
939
+ # calculate polynomial coefficient
940
+ c = self.class[-1,1]
941
+ 2.upto(n) do |i|
942
+ x = self.class.zeros(i+1)
943
+ x[0..-2] = c
944
+ y = self.class.zeros(i+1)
945
+ y[1..-1] = c
946
+ c = y - x
947
+ end
948
+ s = [true]*ndim
949
+ s[axis] = n..-1
950
+ result = self[*s].dup
951
+ sum = result.inplace
952
+ (n-1).downto(0) do |i|
953
+ s = [true]*ndim
954
+ s[axis] = i..-n-1+i
955
+ sum + self[*s] * c[i] # inplace addition
956
+ end
957
+ return result
958
+ end
959
+
960
+
961
+ # Upper triangular matrix.
962
+ # Return a copy with the elements below the k-th diagonal filled with zero.
963
+ def triu(k=0)
964
+ dup.triu!(k)
965
+ end
966
+
967
+ # Upper triangular matrix.
968
+ # Fill the self elements below the k-th diagonal with zero.
969
+ def triu!(k=0)
970
+ if ndim < 2
971
+ raise NArray::ShapeError, "must be >= 2-dimensional array"
972
+ end
973
+ if contiguous?
974
+ *shp,m,n = shape
975
+ idx = tril_indices(k-1)
976
+ reshape!(*shp,m*n)
977
+ self[false,idx] = 0
978
+ reshape!(*shp,m,n)
979
+ else
980
+ store(triu(k))
981
+ end
982
+ end
983
+
984
+ # Return the indices for the uppler-triangle on and above the k-th diagonal.
985
+ def triu_indices(k=0)
986
+ if ndim < 2
987
+ raise NArray::ShapeError, "must be >= 2-dimensional array"
988
+ end
989
+ m,n = shape[-2..-1]
990
+ NArray.triu_indices(m,n,k=0)
991
+ end
992
+
993
+ # Return the indices for the uppler-triangle on and above the k-th diagonal.
994
+ def self.triu_indices(m,n,k=0)
995
+ x = Cumo::Int64.new(m,1).seq + k
996
+ y = Cumo::Int64.new(1,n).seq
997
+ (x<=y).where
998
+ end
999
+
1000
+ # Lower triangular matrix.
1001
+ # Return a copy with the elements above the k-th diagonal filled with zero.
1002
+ def tril(k=0)
1003
+ dup.tril!(k)
1004
+ end
1005
+
1006
+ # Lower triangular matrix.
1007
+ # Fill the self elements above the k-th diagonal with zero.
1008
+ def tril!(k=0)
1009
+ if ndim < 2
1010
+ raise NArray::ShapeError, "must be >= 2-dimensional array"
1011
+ end
1012
+ if contiguous?
1013
+ idx = triu_indices(k+1)
1014
+ *shp,m,n = shape
1015
+ reshape!(*shp,m*n)
1016
+ self[false,idx] = 0
1017
+ reshape!(*shp,m,n)
1018
+ else
1019
+ store(tril(k))
1020
+ end
1021
+ end
1022
+
1023
+ # Return the indices for the lower-triangle on and below the k-th diagonal.
1024
+ def tril_indices(k=0)
1025
+ if ndim < 2
1026
+ raise NArray::ShapeError, "must be >= 2-dimensional array"
1027
+ end
1028
+ m,n = shape[-2..-1]
1029
+ NArray.tril_indices(m,n,k)
1030
+ end
1031
+
1032
+ # Return the indices for the lower-triangle on and below the k-th diagonal.
1033
+ def self.tril_indices(m,n,k=0)
1034
+ x = Cumo::Int64.new(m,1).seq + k
1035
+ y = Cumo::Int64.new(1,n).seq
1036
+ (x>=y).where
1037
+ end
1038
+
1039
+ # Return the k-th diagonal indices.
1040
+ def diag_indices(k=0)
1041
+ if ndim < 2
1042
+ raise NArray::ShapeError, "must be >= 2-dimensional array"
1043
+ end
1044
+ m,n = shape[-2..-1]
1045
+ NArray.diag_indices(m,n,k)
1046
+ end
1047
+
1048
+ # Return the k-th diagonal indices.
1049
+ def self.diag_indices(m,n,k=0)
1050
+ x = Cumo::Int64.new(m,1).seq + k
1051
+ y = Cumo::Int64.new(1,n).seq
1052
+ (x.eq y).where
1053
+ end
1054
+
1055
+ # Return a matrix whose diagonal is constructed by self along the last axis.
1056
+ def diag(k=0)
1057
+ *shp,n = shape
1058
+ n += k.abs
1059
+ a = self.class.zeros(*shp,n,n)
1060
+ a.diagonal(k).store(self)
1061
+ a
1062
+ end
1063
+
1064
+ # Return the sum along diagonals of the array.
1065
+ #
1066
+ # If 2-D array, computes the summation along its diagonal with the
1067
+ # given offset, i.e., sum of `a[i,i+offset]`.
1068
+ # If more than 2-D array, the diagonal is determined from the axes
1069
+ # specified by axis argument. The default is axis=[-2,-1].
1070
+ # @param offset [Integer] (optional, default=0) diagonal offset
1071
+ # @param axis [Array] (optional, default=[-2,-1]) diagonal axis
1072
+ # @param nan [Bool] (optional, default=false) nan-aware algorithm, i.e., if true then it ignores nan.
1073
+
1074
+ def trace(offset=nil,axis=nil,nan:false)
1075
+ diagonal(offset,axis).sum(nan:nan,axis:-1)
1076
+ end
1077
+
1078
+
1079
+ @@warn_slow_dot = false
1080
+
1081
+ # Dot product of two arrays.
1082
+ # @param b [Cumo::NArray]
1083
+ # @return [Cumo::NArray] return dot product
1084
+
1085
+ def dot(b)
1086
+ t = self.class::UPCAST[b.class]
1087
+ if [SFloat, DFloat, SComplex, DComplex].include?(t)
1088
+ b = self.class.asarray(b)
1089
+ case self.ndim
1090
+ when 1
1091
+ case b.ndim
1092
+ when 1
1093
+ self.mulsum(b, axis:-1)
1094
+ else
1095
+ self[:new, false].gemm(b).flatten
1096
+ end
1097
+ else
1098
+ case b.ndim
1099
+ when 1
1100
+ self.gemm(b[false, :new]).flatten
1101
+ else
1102
+ self.gemm(b)
1103
+ end
1104
+ end
1105
+ else
1106
+ b = self.class.asarray(b)
1107
+ case b.ndim
1108
+ when 1
1109
+ mulsum(b, axis:-1)
1110
+ else
1111
+ case ndim
1112
+ when 0
1113
+ b.mulsum(self, axis:-2)
1114
+ when 1
1115
+ self[true,:new].mulsum(b, axis:-2)
1116
+ else
1117
+ unless @@warn_slow_dot
1118
+ nx = 200
1119
+ ns = 200000
1120
+ am,an = shape[-2..-1]
1121
+ bm,bn = b.shape[-2..-1]
1122
+ if am > nx && an > nx && bm > nx && bn > nx &&
1123
+ size > ns && b.size > ns
1124
+ @@warn_slow_dot = true
1125
+ warn "\nwarning: matrix dot for #{t} is slow. Consider SFloat, DFloat, SComplex, or DComplex to use cuBLAS.\n\n"
1126
+ end
1127
+ end
1128
+ self[false,:new].mulsum(b[false,:new,true,true], axis:-2)
1129
+ end
1130
+ end
1131
+ end
1132
+ end
1133
+
1134
+ # Inner product of two arrays.
1135
+ # Same as `(a*b).sum(axis:-1)`.
1136
+ # @param b [Cumo::NArray]
1137
+ # @param axis [Integer] applied axis
1138
+ # @return [Cumo::NArray] return (a*b).sum(axis:axis)
1139
+
1140
+ def inner(b, axis:-1)
1141
+ mulsum(b, axis:axis)
1142
+ end
1143
+
1144
+ # Outer product of two arrays.
1145
+ # Same as `self[false,:new] * b[false,:new,true]`.
1146
+ #
1147
+ # @param b [Cumo::NArray]
1148
+ # @param axis [Integer] applied axis (default=-1)
1149
+ # @return [Cumo::NArray] return outer product
1150
+ # @example
1151
+ # a = Cumo::DFloat.ones(5)
1152
+ # => Cumo::DFloat#shape=[5]
1153
+ # [1, 1, 1, 1, 1]
1154
+ # b = Cumo::DFloat.linspace(-2,2,5)
1155
+ # => Cumo::DFloat#shape=[5]
1156
+ # [-2, -1, 0, 1, 2]
1157
+ # a.outer(b)
1158
+ # => Cumo::DFloat#shape=[5,5]
1159
+ # [[-2, -1, 0, 1, 2],
1160
+ # [-2, -1, 0, 1, 2],
1161
+ # [-2, -1, 0, 1, 2],
1162
+ # [-2, -1, 0, 1, 2],
1163
+ # [-2, -1, 0, 1, 2]]
1164
+
1165
+ def outer(b, axis:nil)
1166
+ b = NArray.cast(b)
1167
+ if axis.nil?
1168
+ self[false,:new] * ((b.ndim==0) ? b : b[false,:new,true])
1169
+ else
1170
+ md,nd = [ndim,b.ndim].minmax
1171
+ axis = check_axis(axis) - nd
1172
+ if axis < -md
1173
+ raise ArgumentError,"axis=#{axis} is out of range"
1174
+ end
1175
+ adim = [true]*ndim
1176
+ adim[axis+ndim+1,0] = :new
1177
+ bdim = [true]*b.ndim
1178
+ bdim[axis+b.ndim,0] = :new
1179
+ self[*adim] * b[*bdim]
1180
+ end
1181
+ end
1182
+
1183
+ # Kronecker product of two arrays.
1184
+ #
1185
+ # kron(a,b)[k_0, k_1, ...] = a[i_0, i_1, ...] * b[j_0, j_1, ...]
1186
+ # where: k_n = i_n * b.shape[n] + j_n
1187
+ #
1188
+ # @param b [Cumo::NArray]
1189
+ # @return [Cumo::NArray] return Kronecker product
1190
+ # @example
1191
+ # Cumo::DFloat[1,10,100].kron([5,6,7])
1192
+ # => Cumo::DFloat#shape=[9]
1193
+ # [5, 6, 7, 50, 60, 70, 500, 600, 700]
1194
+ # Cumo::DFloat[5,6,7].kron([1,10,100])
1195
+ # => Cumo::DFloat#shape=[9]
1196
+ # [5, 50, 500, 6, 60, 600, 7, 70, 700]
1197
+ # Cumo::DFloat.eye(2).kron(Cumo::DFloat.ones(2,2))
1198
+ # => Cumo::DFloat#shape=[4,4]
1199
+ # [[1, 1, 0, 0],
1200
+ # [1, 1, 0, 0],
1201
+ # [0, 0, 1, 1],
1202
+ # [0, 0, 1, 1]]
1203
+
1204
+ def kron(b)
1205
+ b = NArray.cast(b)
1206
+ nda = ndim
1207
+ ndb = b.ndim
1208
+ shpa = shape
1209
+ shpb = b.shape
1210
+ adim = [:new]*(2*[ndb-nda,0].max) + [true,:new]*nda
1211
+ bdim = [:new]*(2*[nda-ndb,0].max) + [:new,true]*ndb
1212
+ shpr = (-[nda,ndb].max..-1).map{|i| (shpa[i]||1) * (shpb[i]||1)}
1213
+ (self[*adim] * b[*bdim]).reshape(*shpr)
1214
+ end
1215
+
1216
+
1217
+ # under construction
1218
+ def cov(y=nil, ddof:1, fweights:nil, aweights:nil)
1219
+ if y
1220
+ m = NArray.vstack([self,y])
1221
+ else
1222
+ m = self
1223
+ end
1224
+ w = nil
1225
+ if fweights
1226
+ f = fweights
1227
+ w = f
1228
+ end
1229
+ if aweights
1230
+ a = aweights
1231
+ w = w ? w*a : a
1232
+ end
1233
+ if w
1234
+ w_sum = w.sum(axis:-1, keepdims:true)
1235
+ if ddof == 0
1236
+ fact = w_sum
1237
+ elsif aweights.nil?
1238
+ fact = w_sum - ddof
1239
+ else
1240
+ wa_sum = (w*a).sum(axis:-1, keepdims:true)
1241
+ fact = w_sum - ddof * wa_sum / w_sum
1242
+ end
1243
+ if (fact <= 0).any?
1244
+ raise StandardError,"Degrees of freedom <= 0 for slice"
1245
+ end
1246
+ else
1247
+ fact = m.shape[-1] - ddof
1248
+ end
1249
+ if w
1250
+ m -= (m*w).sum(axis:-1, keepdims:true) / w_sum
1251
+ mw = m*w
1252
+ else
1253
+ m -= m.mean(axis:-1, keepdims:true)
1254
+ mw = m
1255
+ end
1256
+ mt = (m.ndim < 2) ? m : m.swapaxes(-2,-1)
1257
+ mw.dot(mt.conj) / fact
1258
+ end
1259
+
1260
+ private
1261
+
1262
+ # @!visibility private
1263
+ def check_axis(axis)
1264
+ unless Integer===axis
1265
+ raise ArgumentError,"axis=#{axis} must be Integer"
1266
+ end
1267
+ a = axis
1268
+ if a < 0
1269
+ a += ndim
1270
+ end
1271
+ if a < 0 || a >= ndim
1272
+ raise ArgumentError,"axis=#{axis} is invalid"
1273
+ end
1274
+ a
1275
+ end
1276
+
1277
+ end
1278
+ end