cumo 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +27 -0
- data/.travis.yml +5 -0
- data/3rd_party/mkmf-cu/.gitignore +36 -0
- data/3rd_party/mkmf-cu/Gemfile +3 -0
- data/3rd_party/mkmf-cu/LICENSE +21 -0
- data/3rd_party/mkmf-cu/README.md +36 -0
- data/3rd_party/mkmf-cu/Rakefile +11 -0
- data/3rd_party/mkmf-cu/bin/mkmf-cu-nvcc +4 -0
- data/3rd_party/mkmf-cu/lib/mkmf-cu.rb +32 -0
- data/3rd_party/mkmf-cu/lib/mkmf-cu/cli.rb +80 -0
- data/3rd_party/mkmf-cu/lib/mkmf-cu/nvcc.rb +157 -0
- data/3rd_party/mkmf-cu/mkmf-cu.gemspec +16 -0
- data/3rd_party/mkmf-cu/test/test_mkmf-cu.rb +67 -0
- data/CODE_OF_CONDUCT.md +46 -0
- data/Gemfile +8 -0
- data/LICENSE.txt +82 -0
- data/README.md +252 -0
- data/Rakefile +43 -0
- data/bench/broadcast_fp32.rb +138 -0
- data/bench/cumo_bench.rb +193 -0
- data/bench/numo_bench.rb +138 -0
- data/bench/reduction_fp32.rb +117 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/cumo.gemspec +32 -0
- data/ext/cumo/cuda/cublas.c +278 -0
- data/ext/cumo/cuda/driver.c +421 -0
- data/ext/cumo/cuda/memory_pool.cpp +185 -0
- data/ext/cumo/cuda/memory_pool_impl.cpp +308 -0
- data/ext/cumo/cuda/memory_pool_impl.hpp +370 -0
- data/ext/cumo/cuda/memory_pool_impl_test.cpp +554 -0
- data/ext/cumo/cuda/nvrtc.c +207 -0
- data/ext/cumo/cuda/runtime.c +167 -0
- data/ext/cumo/cumo.c +148 -0
- data/ext/cumo/depend.erb +58 -0
- data/ext/cumo/extconf.rb +179 -0
- data/ext/cumo/include/cumo.h +25 -0
- data/ext/cumo/include/cumo/compat.h +23 -0
- data/ext/cumo/include/cumo/cuda/cublas.h +153 -0
- data/ext/cumo/include/cumo/cuda/cumo_thrust.hpp +187 -0
- data/ext/cumo/include/cumo/cuda/cumo_thrust_complex.hpp +79 -0
- data/ext/cumo/include/cumo/cuda/driver.h +22 -0
- data/ext/cumo/include/cumo/cuda/memory_pool.h +28 -0
- data/ext/cumo/include/cumo/cuda/nvrtc.h +22 -0
- data/ext/cumo/include/cumo/cuda/runtime.h +40 -0
- data/ext/cumo/include/cumo/indexer.h +238 -0
- data/ext/cumo/include/cumo/intern.h +142 -0
- data/ext/cumo/include/cumo/intern_fwd.h +38 -0
- data/ext/cumo/include/cumo/intern_kernel.h +6 -0
- data/ext/cumo/include/cumo/narray.h +429 -0
- data/ext/cumo/include/cumo/narray_kernel.h +149 -0
- data/ext/cumo/include/cumo/ndloop.h +95 -0
- data/ext/cumo/include/cumo/reduce_kernel.h +126 -0
- data/ext/cumo/include/cumo/template.h +158 -0
- data/ext/cumo/include/cumo/template_kernel.h +77 -0
- data/ext/cumo/include/cumo/types/bit.h +40 -0
- data/ext/cumo/include/cumo/types/bit_kernel.h +34 -0
- data/ext/cumo/include/cumo/types/complex.h +402 -0
- data/ext/cumo/include/cumo/types/complex_kernel.h +414 -0
- data/ext/cumo/include/cumo/types/complex_macro.h +382 -0
- data/ext/cumo/include/cumo/types/complex_macro_kernel.h +186 -0
- data/ext/cumo/include/cumo/types/dcomplex.h +46 -0
- data/ext/cumo/include/cumo/types/dcomplex_kernel.h +13 -0
- data/ext/cumo/include/cumo/types/dfloat.h +47 -0
- data/ext/cumo/include/cumo/types/dfloat_kernel.h +14 -0
- data/ext/cumo/include/cumo/types/float_def.h +34 -0
- data/ext/cumo/include/cumo/types/float_def_kernel.h +39 -0
- data/ext/cumo/include/cumo/types/float_macro.h +191 -0
- data/ext/cumo/include/cumo/types/float_macro_kernel.h +158 -0
- data/ext/cumo/include/cumo/types/int16.h +24 -0
- data/ext/cumo/include/cumo/types/int16_kernel.h +23 -0
- data/ext/cumo/include/cumo/types/int32.h +24 -0
- data/ext/cumo/include/cumo/types/int32_kernel.h +19 -0
- data/ext/cumo/include/cumo/types/int64.h +24 -0
- data/ext/cumo/include/cumo/types/int64_kernel.h +19 -0
- data/ext/cumo/include/cumo/types/int8.h +24 -0
- data/ext/cumo/include/cumo/types/int8_kernel.h +19 -0
- data/ext/cumo/include/cumo/types/int_macro.h +67 -0
- data/ext/cumo/include/cumo/types/int_macro_kernel.h +48 -0
- data/ext/cumo/include/cumo/types/real_accum.h +486 -0
- data/ext/cumo/include/cumo/types/real_accum_kernel.h +101 -0
- data/ext/cumo/include/cumo/types/robj_macro.h +80 -0
- data/ext/cumo/include/cumo/types/robj_macro_kernel.h +0 -0
- data/ext/cumo/include/cumo/types/robject.h +27 -0
- data/ext/cumo/include/cumo/types/robject_kernel.h +7 -0
- data/ext/cumo/include/cumo/types/scomplex.h +46 -0
- data/ext/cumo/include/cumo/types/scomplex_kernel.h +13 -0
- data/ext/cumo/include/cumo/types/sfloat.h +48 -0
- data/ext/cumo/include/cumo/types/sfloat_kernel.h +14 -0
- data/ext/cumo/include/cumo/types/uint16.h +25 -0
- data/ext/cumo/include/cumo/types/uint16_kernel.h +20 -0
- data/ext/cumo/include/cumo/types/uint32.h +25 -0
- data/ext/cumo/include/cumo/types/uint32_kernel.h +20 -0
- data/ext/cumo/include/cumo/types/uint64.h +25 -0
- data/ext/cumo/include/cumo/types/uint64_kernel.h +20 -0
- data/ext/cumo/include/cumo/types/uint8.h +25 -0
- data/ext/cumo/include/cumo/types/uint8_kernel.h +20 -0
- data/ext/cumo/include/cumo/types/uint_macro.h +58 -0
- data/ext/cumo/include/cumo/types/uint_macro_kernel.h +38 -0
- data/ext/cumo/include/cumo/types/xint_macro.h +169 -0
- data/ext/cumo/include/cumo/types/xint_macro_kernel.h +88 -0
- data/ext/cumo/narray/SFMT-params.h +97 -0
- data/ext/cumo/narray/SFMT-params19937.h +46 -0
- data/ext/cumo/narray/SFMT.c +620 -0
- data/ext/cumo/narray/SFMT.h +167 -0
- data/ext/cumo/narray/array.c +638 -0
- data/ext/cumo/narray/data.c +961 -0
- data/ext/cumo/narray/gen/cogen.rb +56 -0
- data/ext/cumo/narray/gen/cogen_kernel.rb +58 -0
- data/ext/cumo/narray/gen/def/bit.rb +37 -0
- data/ext/cumo/narray/gen/def/dcomplex.rb +39 -0
- data/ext/cumo/narray/gen/def/dfloat.rb +37 -0
- data/ext/cumo/narray/gen/def/int16.rb +36 -0
- data/ext/cumo/narray/gen/def/int32.rb +36 -0
- data/ext/cumo/narray/gen/def/int64.rb +36 -0
- data/ext/cumo/narray/gen/def/int8.rb +36 -0
- data/ext/cumo/narray/gen/def/robject.rb +37 -0
- data/ext/cumo/narray/gen/def/scomplex.rb +39 -0
- data/ext/cumo/narray/gen/def/sfloat.rb +37 -0
- data/ext/cumo/narray/gen/def/uint16.rb +36 -0
- data/ext/cumo/narray/gen/def/uint32.rb +36 -0
- data/ext/cumo/narray/gen/def/uint64.rb +36 -0
- data/ext/cumo/narray/gen/def/uint8.rb +36 -0
- data/ext/cumo/narray/gen/erbpp2.rb +346 -0
- data/ext/cumo/narray/gen/narray_def.rb +268 -0
- data/ext/cumo/narray/gen/spec.rb +425 -0
- data/ext/cumo/narray/gen/tmpl/accum.c +86 -0
- data/ext/cumo/narray/gen/tmpl/accum_binary.c +121 -0
- data/ext/cumo/narray/gen/tmpl/accum_binary_kernel.cu +61 -0
- data/ext/cumo/narray/gen/tmpl/accum_index.c +119 -0
- data/ext/cumo/narray/gen/tmpl/accum_index_kernel.cu +66 -0
- data/ext/cumo/narray/gen/tmpl/accum_kernel.cu +12 -0
- data/ext/cumo/narray/gen/tmpl/alloc_func.c +107 -0
- data/ext/cumo/narray/gen/tmpl/allocate.c +37 -0
- data/ext/cumo/narray/gen/tmpl/aref.c +66 -0
- data/ext/cumo/narray/gen/tmpl/aref_cpu.c +50 -0
- data/ext/cumo/narray/gen/tmpl/aset.c +56 -0
- data/ext/cumo/narray/gen/tmpl/binary.c +162 -0
- data/ext/cumo/narray/gen/tmpl/binary2.c +70 -0
- data/ext/cumo/narray/gen/tmpl/binary2_kernel.cu +15 -0
- data/ext/cumo/narray/gen/tmpl/binary_kernel.cu +31 -0
- data/ext/cumo/narray/gen/tmpl/binary_s.c +45 -0
- data/ext/cumo/narray/gen/tmpl/binary_s_kernel.cu +15 -0
- data/ext/cumo/narray/gen/tmpl/bincount.c +181 -0
- data/ext/cumo/narray/gen/tmpl/cast.c +44 -0
- data/ext/cumo/narray/gen/tmpl/cast_array.c +13 -0
- data/ext/cumo/narray/gen/tmpl/class.c +9 -0
- data/ext/cumo/narray/gen/tmpl/class_kernel.cu +6 -0
- data/ext/cumo/narray/gen/tmpl/clip.c +121 -0
- data/ext/cumo/narray/gen/tmpl/coerce_cast.c +10 -0
- data/ext/cumo/narray/gen/tmpl/complex_accum_kernel.cu +129 -0
- data/ext/cumo/narray/gen/tmpl/cond_binary.c +68 -0
- data/ext/cumo/narray/gen/tmpl/cond_binary_kernel.cu +18 -0
- data/ext/cumo/narray/gen/tmpl/cond_unary.c +46 -0
- data/ext/cumo/narray/gen/tmpl/cum.c +50 -0
- data/ext/cumo/narray/gen/tmpl/each.c +47 -0
- data/ext/cumo/narray/gen/tmpl/each_with_index.c +70 -0
- data/ext/cumo/narray/gen/tmpl/ewcomp.c +79 -0
- data/ext/cumo/narray/gen/tmpl/ewcomp_kernel.cu +19 -0
- data/ext/cumo/narray/gen/tmpl/extract.c +22 -0
- data/ext/cumo/narray/gen/tmpl/extract_cpu.c +26 -0
- data/ext/cumo/narray/gen/tmpl/extract_data.c +53 -0
- data/ext/cumo/narray/gen/tmpl/eye.c +105 -0
- data/ext/cumo/narray/gen/tmpl/eye_kernel.cu +19 -0
- data/ext/cumo/narray/gen/tmpl/fill.c +52 -0
- data/ext/cumo/narray/gen/tmpl/fill_kernel.cu +29 -0
- data/ext/cumo/narray/gen/tmpl/float_accum_kernel.cu +106 -0
- data/ext/cumo/narray/gen/tmpl/format.c +62 -0
- data/ext/cumo/narray/gen/tmpl/format_to_a.c +49 -0
- data/ext/cumo/narray/gen/tmpl/frexp.c +38 -0
- data/ext/cumo/narray/gen/tmpl/gemm.c +203 -0
- data/ext/cumo/narray/gen/tmpl/init_class.c +20 -0
- data/ext/cumo/narray/gen/tmpl/init_module.c +12 -0
- data/ext/cumo/narray/gen/tmpl/inspect.c +21 -0
- data/ext/cumo/narray/gen/tmpl/lib.c +50 -0
- data/ext/cumo/narray/gen/tmpl/lib_kernel.cu +24 -0
- data/ext/cumo/narray/gen/tmpl/logseq.c +102 -0
- data/ext/cumo/narray/gen/tmpl/logseq_kernel.cu +31 -0
- data/ext/cumo/narray/gen/tmpl/map_with_index.c +98 -0
- data/ext/cumo/narray/gen/tmpl/median.c +66 -0
- data/ext/cumo/narray/gen/tmpl/minmax.c +47 -0
- data/ext/cumo/narray/gen/tmpl/module.c +9 -0
- data/ext/cumo/narray/gen/tmpl/module_kernel.cu +1 -0
- data/ext/cumo/narray/gen/tmpl/new_dim0.c +15 -0
- data/ext/cumo/narray/gen/tmpl/new_dim0_kernel.cu +8 -0
- data/ext/cumo/narray/gen/tmpl/poly.c +50 -0
- data/ext/cumo/narray/gen/tmpl/pow.c +97 -0
- data/ext/cumo/narray/gen/tmpl/pow_kernel.cu +29 -0
- data/ext/cumo/narray/gen/tmpl/powint.c +17 -0
- data/ext/cumo/narray/gen/tmpl/qsort.c +212 -0
- data/ext/cumo/narray/gen/tmpl/rand.c +168 -0
- data/ext/cumo/narray/gen/tmpl/rand_norm.c +121 -0
- data/ext/cumo/narray/gen/tmpl/real_accum_kernel.cu +75 -0
- data/ext/cumo/narray/gen/tmpl/seq.c +112 -0
- data/ext/cumo/narray/gen/tmpl/seq_kernel.cu +43 -0
- data/ext/cumo/narray/gen/tmpl/set2.c +57 -0
- data/ext/cumo/narray/gen/tmpl/sort.c +48 -0
- data/ext/cumo/narray/gen/tmpl/sort_index.c +111 -0
- data/ext/cumo/narray/gen/tmpl/store.c +41 -0
- data/ext/cumo/narray/gen/tmpl/store_array.c +187 -0
- data/ext/cumo/narray/gen/tmpl/store_array_kernel.cu +58 -0
- data/ext/cumo/narray/gen/tmpl/store_bit.c +86 -0
- data/ext/cumo/narray/gen/tmpl/store_bit_kernel.cu +66 -0
- data/ext/cumo/narray/gen/tmpl/store_from.c +81 -0
- data/ext/cumo/narray/gen/tmpl/store_from_kernel.cu +58 -0
- data/ext/cumo/narray/gen/tmpl/store_kernel.cu +3 -0
- data/ext/cumo/narray/gen/tmpl/store_numeric.c +9 -0
- data/ext/cumo/narray/gen/tmpl/to_a.c +43 -0
- data/ext/cumo/narray/gen/tmpl/unary.c +132 -0
- data/ext/cumo/narray/gen/tmpl/unary2.c +60 -0
- data/ext/cumo/narray/gen/tmpl/unary_kernel.cu +72 -0
- data/ext/cumo/narray/gen/tmpl/unary_ret2.c +34 -0
- data/ext/cumo/narray/gen/tmpl/unary_s.c +86 -0
- data/ext/cumo/narray/gen/tmpl/unary_s_kernel.cu +58 -0
- data/ext/cumo/narray/gen/tmpl_bit/allocate.c +24 -0
- data/ext/cumo/narray/gen/tmpl_bit/aref.c +54 -0
- data/ext/cumo/narray/gen/tmpl_bit/aref_cpu.c +57 -0
- data/ext/cumo/narray/gen/tmpl_bit/aset.c +56 -0
- data/ext/cumo/narray/gen/tmpl_bit/binary.c +98 -0
- data/ext/cumo/narray/gen/tmpl_bit/bit_count.c +64 -0
- data/ext/cumo/narray/gen/tmpl_bit/bit_count_cpu.c +88 -0
- data/ext/cumo/narray/gen/tmpl_bit/bit_count_kernel.cu +76 -0
- data/ext/cumo/narray/gen/tmpl_bit/bit_reduce.c +133 -0
- data/ext/cumo/narray/gen/tmpl_bit/each.c +48 -0
- data/ext/cumo/narray/gen/tmpl_bit/each_with_index.c +70 -0
- data/ext/cumo/narray/gen/tmpl_bit/extract.c +30 -0
- data/ext/cumo/narray/gen/tmpl_bit/extract_cpu.c +29 -0
- data/ext/cumo/narray/gen/tmpl_bit/fill.c +69 -0
- data/ext/cumo/narray/gen/tmpl_bit/format.c +64 -0
- data/ext/cumo/narray/gen/tmpl_bit/format_to_a.c +51 -0
- data/ext/cumo/narray/gen/tmpl_bit/inspect.c +21 -0
- data/ext/cumo/narray/gen/tmpl_bit/mask.c +136 -0
- data/ext/cumo/narray/gen/tmpl_bit/none_p.c +14 -0
- data/ext/cumo/narray/gen/tmpl_bit/store_array.c +108 -0
- data/ext/cumo/narray/gen/tmpl_bit/store_bit.c +70 -0
- data/ext/cumo/narray/gen/tmpl_bit/store_from.c +60 -0
- data/ext/cumo/narray/gen/tmpl_bit/to_a.c +47 -0
- data/ext/cumo/narray/gen/tmpl_bit/unary.c +81 -0
- data/ext/cumo/narray/gen/tmpl_bit/where.c +90 -0
- data/ext/cumo/narray/gen/tmpl_bit/where2.c +95 -0
- data/ext/cumo/narray/index.c +880 -0
- data/ext/cumo/narray/kwargs.c +153 -0
- data/ext/cumo/narray/math.c +142 -0
- data/ext/cumo/narray/narray.c +1948 -0
- data/ext/cumo/narray/ndloop.c +2105 -0
- data/ext/cumo/narray/rand.c +45 -0
- data/ext/cumo/narray/step.c +474 -0
- data/ext/cumo/narray/struct.c +886 -0
- data/lib/cumo.rb +3 -0
- data/lib/cumo/cuda.rb +11 -0
- data/lib/cumo/cuda/compile_error.rb +36 -0
- data/lib/cumo/cuda/compiler.rb +161 -0
- data/lib/cumo/cuda/device.rb +47 -0
- data/lib/cumo/cuda/link_state.rb +31 -0
- data/lib/cumo/cuda/module.rb +40 -0
- data/lib/cumo/cuda/nvrtc_program.rb +27 -0
- data/lib/cumo/linalg.rb +12 -0
- data/lib/cumo/narray.rb +2 -0
- data/lib/cumo/narray/extra.rb +1278 -0
- data/lib/erbpp.rb +294 -0
- data/lib/erbpp/line_number.rb +137 -0
- data/lib/erbpp/narray_def.rb +381 -0
- data/numo-narray-version +1 -0
- data/run.gdb +7 -0
- metadata +353 -0
data/lib/cumo.rb
ADDED
data/lib/cumo/cuda.rb
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
# Root namespace of the Cumo gem.
module Cumo
  # Namespace for the CUDA helpers (compile error, compiler, device, module,
  # link state, NVRTC program) that are pulled in by the require_relative
  # lines following this declaration.
  module CUDA
  end
end
|
|
5
|
+
|
|
6
|
+
require_relative 'cuda/compile_error'
|
|
7
|
+
require_relative 'cuda/compiler'
|
|
8
|
+
require_relative 'cuda/device'
|
|
9
|
+
require_relative 'cuda/module'
|
|
10
|
+
require_relative 'cuda/link_state'
|
|
11
|
+
require_relative 'cuda/nvrtc_program'
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
require_relative '../cuda'
|
|
2
|
+
|
|
3
|
+
module Cumo
  module CUDA
    # Raised when NVRTC fails to compile a CUDA source.
    #
    # The NVRTC program log is the exception message. The failing source,
    # program name and compiler options are kept so #dump can print a full
    # diagnostic report.
    class CompileError < StandardError
      attr_reader :source, :name, :options

      # @param msg [String] program log reported by NVRTC
      # @param source [String] CUDA source that failed to compile
      # @param name [String] program name given to NVRTC
      # @param options [Array<String>] compiler options used
      def initialize(msg, source, name, options)
        super(msg)
        @msg = msg
        @source = source
        @name = name
        @options = options
      end

      # @return [String] the NVRTC program log
      def message
        @msg
      end

      # @return [String] the NVRTC program log
      def to_s
        @msg
      end

      # Write a human-readable report to +io+ and flush it. The report holds
      # the log, program name, options, and the source with zero-padded line
      # numbers.
      #
      # @param io [IO] destination stream (anything with #puts and #flush), e.g. $stderr
      def dump(io)
        lines = @source.to_s.split("\n")
        # Width of the largest line number. An empty source still gets one
        # digit; Math.log10(0) would be -Infinity and make #floor raise.
        digits = lines.size.to_s.length
        linum_fmt = "%0#{digits}d "
        io.puts("NVRTC compilation error: #{@msg}")
        io.puts("-----")
        io.puts("Name: #{@name}")
        io.puts("Options: #{Array(@options).join(' ')}")
        io.puts("CUDA source:")
        lines.each.with_index(1) do |line, linum|
          # String has no #sprintf method; Kernel#format applies the padding.
          io.puts(format(linum_fmt, linum) + line.rstrip)
        end
        io.puts("-----")
        io.flush
      end
    end
  end
end
|
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
require 'tmpdir'
|
|
2
|
+
require 'tempfile'
|
|
3
|
+
require 'fileutils'
|
|
4
|
+
require 'digest/md5'
|
|
5
|
+
require_relative '../cuda'
|
|
6
|
+
|
|
7
|
+
module Cumo
  module CUDA
    # Compiles CUDA source with NVRTC into a loaded kernel Module.
    #
    # #compile_with_cache caches the linked cubin on disk. The cache key is
    # an MD5 of the architecture, options, NVRTC version, the preprocessed
    # empty program, the source, and any extra_source.
    class Compiler
      VALID_KERNEL_NAME = /\A[a-zA-Z_][a-zA-Z_0-9]*\z/
      DEFAULT_CACHE_DIR = File.expand_path('~/.cumo/kernel_cache')

      # Compiler output for an empty program, memoized per
      # [arch, options, nvrtc_version] and folded into every cache key.
      @@empty_file_preprocess_cache ||= {}

      # @param name [String]
      # @return [Boolean] whether +name+ is a plain C identifier
      def self.valid_kernel_name?(name)
        VALID_KERNEL_NAME.match?(name)
      end

      # Compile +source+ with NVRTC for +arch+.
      #
      # The source is also written to a temporary .cu file whose path is used
      # as the NVRTC program name.
      #
      # @param source [String] CUDA source
      # @param options [Array<String>] extra NVRTC options
      # @param arch [String, nil] e.g. "compute_70"; defaults to the current device
      # @return [String] the value of NVRTCProgram#compile (PTX)
      # @raise [CompileError] if NVRTC rejects the source
      def compile_using_nvrtc(source, options: [], arch: nil)
        arch ||= get_arch
        options += ["-arch=#{arch}"]

        Dir.mktmpdir do |root_dir|
          cu_path = File.join(root_dir, 'kern.cu')
          File.write(cu_path, source)
          nvrtc_compile(source, cu_path, options)
        end
      end

      # Compile +source+ and load it as a Module, reusing a cubin from
      # +cache_dir+ when one with a matching key and checksum exists.
      #
      # @param source [String] CUDA source
      # @param options [Array<String>] extra NVRTC options ("-ftz=true" is appended)
      # @param arch [String, nil] defaults to the current device's architecture
      # @param cache_dir [String, nil] defaults to $CUMO_CACHE_DIR or DEFAULT_CACHE_DIR
      # @param extra_source [String, nil] not compiled; only mixed into the cache key
      # @return [Module] the loaded kernel module
      def compile_with_cache(source, options: [], arch: nil, cache_dir: nil, extra_source: nil)
        cache_dir ||= get_cache_dir
        arch ||= get_arch

        options += ['-ftz=true']

        env = [arch, options, get_nvrtc_version]
        base = @@empty_file_preprocess_cache[env]
        if base.nil?
          # The output for an empty program reflects the NVRTC compiler's
          # internal version, so it is part of the cache key.
          base = preprocess('', options, arch)
          @@empty_file_preprocess_cache[env] = base
        end
        key_src = "#{env} #{base} #{source} #{extra_source}"
        key_src.encode!('utf-8')
        digest = Digest::MD5.hexdigest(key_src)
        path = File.join(cache_dir, "#{digest}_2.cubin")

        FileUtils.mkdir_p(cache_dir)

        # TODO(sonots): thread-safe?
        cached = load_cache(path)
        return load_module(cached) if cached

        ptx = compile_using_nvrtc(source, options: options, arch: arch)
        cubin = nil
        LinkState.new do |ls|
          ls.add_ptr_data(ptx, 'cumo.ptx')
          cubin = ls.complete
        end

        save_cache(path, Digest::MD5.hexdigest(cubin), cubin)

        # Save the .cu source next to the .cubin when requested.
        if get_bool_env_variable('CUMO_CACHE_SAVE_CUDA_SOURCE', false)
          File.write("#{path}.cu", source)
        end

        load_module(cubin)
      end

      private

      # Build a Module from a cubin image.
      def load_module(cubin)
        mod = Module.new
        mod.load(cubin)
        mod
      end

      # Compile +source+ as an NVRTCProgram named +name+ and always destroy
      # the program. On failure, dump a report to $stderr when
      # CUMO_DUMP_CUDA_SOURCE_ON_ERROR=1, then re-raise.
      def nvrtc_compile(source, name, options)
        prog = NVRTCProgram.new(source, name: name)
        begin
          prog.compile(options: options)
        rescue CompileError => e
          e.dump($stderr) if get_bool_env_variable('CUMO_DUMP_CUDA_SOURCE_ON_ERROR', false)
          raise e
        ensure
          prog.destroy
        end
      end

      # Write "<md5 hex of cubin><cubin>" to +path+.
      #
      # The bytes go to a temp file in the same directory, which is closed
      # (flushing it) before File.rename. This keeps the rename on a single
      # filesystem, so it cannot fail with EXDEV, and readers never see a
      # partially written file. The temp file is removed if anything fails.
      def save_cache(path, cubin_hash, cubin)
        tf = Tempfile.create([File.basename(path), '.tmp'], File.dirname(path))
        renamed = false
        begin
          tf.binmode
          tf.write(cubin_hash)
          tf.write(cubin)
          tf.close
          File.rename(tf.path, path)
          renamed = true
        ensure
          tf.close unless tf.closed?
          File.unlink(tf.path) if !renamed && File.exist?(tf.path)
        end
      end

      # Read a cache file written by #save_cache.
      #
      # @return [String, nil] the cubin. Returns nil if the file is missing,
      #   too short, or its checksum does not match.
      def load_cache(path)
        return nil unless File.exist?(path)
        data = File.binread(path)
        return nil if data.size < 32
        hash = data[0...32]
        cubin = data[32..-1]
        hash == Digest::MD5.hexdigest(cubin) ? cubin : nil
      rescue Errno::ENOENT
        # The file vanished between the existence check and the read.
        nil
      end

      def get_cache_dir
        ENV.fetch('CUMO_CACHE_DIR', DEFAULT_CACHE_DIR)
      end

      def get_nvrtc_version
        @@nvrtc_version ||= NVRTC.nvrtcVersion
      end

      # @return [String] "compute_<major><minor>" for the current device
      def get_arch
        cc = Device.new.compute_capability
        "compute_#{cc}"
      end

      # Interpret an environment variable as a flag.
      #
      # @return +default+ when the variable is unset or empty, true when it
      #   parses as the integer 1, false otherwise (including non-integers).
      def get_bool_env_variable(name, default)
        val = ENV[name]
        return default if val.nil? || val.empty?
        begin
          Integer(val) == 1
        rescue ArgumentError
          false
        end
      end

      # Compile +source+ for +arch+ with +options+.
      # compile_with_cache uses this on an empty program to fingerprint NVRTC.
      def preprocess(source, options, arch)
        nvrtc_compile(source, '', options + ["-arch=#{arch}"])
      end
    end
  end
end
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
require_relative '../cuda'
|
|
2
|
+
|
|
3
|
+
module Cumo
  module CUDA
    # A CUDA device identified by its integer id.
    class Device
      # cudaDeviceAttr values queried by #compute_capability
      # (cudaDevAttrComputeCapabilityMajor / cudaDevAttrComputeCapabilityMinor).
      ATTR_COMPUTE_CAPABILITY_MAJOR = 75
      ATTR_COMPUTE_CAPABILITY_MINOR = 76

      attr_reader :id

      # @return [Integer] the device id reported by Runtime.cudaGetDevice
      def self.get_current_id
        Runtime.cudaGetDevice
      end

      # Misspelled name, kept for backward compatibility.
      def self.get_currend_id
        get_current_id
      end

      # @param device_id [Integer, nil] device id; defaults to the current device
      def initialize(device_id = nil)
        @id = device_id || Runtime.cudaGetDevice
        @_device_stack = []
      end

      # Make this device current (Runtime.cudaSetDevice).
      def use
        Runtime.cudaSetDevice(@id)
      end

      # Run the block with this device current, then restore the device that
      # was current before, even if the block raises.
      #
      # @return the block's value
      # @raise [RuntimeError] when no block is given
      def with
        raise 'Device#with requires a block' unless block_given?
        prev_id = Runtime.cudaGetDevice
        @_device_stack << prev_id
        begin
          # Switch only when a different device is currently selected.
          Runtime.cudaSetDevice(@id) if prev_id != @id
          yield
        ensure
          Runtime.cudaSetDevice(@_device_stack.pop)
        end
      end

      # Block until the device has finished its work (cudaDeviceSynchronize).
      def synchronize
        Runtime.cudaDeviceSynchronize
      end

      # @return [String] major and minor compute capability concatenated, e.g. "70"
      def compute_capability
        major = Runtime.cudaDeviceGetAttributes(ATTR_COMPUTE_CAPABILITY_MAJOR, @id)
        minor = Runtime.cudaDeviceGetAttributes(ATTR_COMPUTE_CAPABILITY_MINOR, @id)
        "#{major}#{minor}"
      end
    end
  end
end
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
require_relative '../cuda'
|
|
2
|
+
|
|
3
|
+
module Cumo
  module CUDA
    # CUDA link state wrapping the handle returned by Driver.cuLinkCreate.
    #
    # When constructed with a block, the state is yielded and then destroyed
    # when the block exits, even if it raises.
    class LinkState
      def initialize
        @ptr = Driver.cuLinkCreate
        return unless block_given?
        begin
          yield(self)
        ensure
          destroy
        end
      end

      # Release the link state. Calling it again is a no-op.
      def destroy
        return unless @ptr
        Driver.cuLinkDestroy(@ptr)
        @ptr = nil
      end

      # Add PTX text as a link input (Driver::CU_JIT_INPUT_PTX).
      #
      # @param data [String] PTX text
      # @param name [String] input name passed through to cuLinkAddData
      def add_ptr_data(data, name)
        Driver.cuLinkAddData(@ptr, Driver::CU_JIT_INPUT_PTX, data, name)
      end
      # Correctly spelled alias; add_ptr_data is kept for existing callers.
      alias add_ptx_data add_ptr_data

      # Finish linking.
      #
      # @return the result of Driver.cuLinkComplete (the linked cubin image)
      def complete
        Driver.cuLinkComplete(@ptr)
      end
    end
  end
end
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
require_relative '../cuda'
|
|
2
|
+
|
|
3
|
+
module Cumo
  module CUDA
    # CUDA kernel module wrapping a handle from Driver.cuModuleLoad or
    # Driver.cuModuleLoadData.
    #
    # When constructed with a block, the module is yielded and unloaded when
    # the block exits, even if it raises.
    class Module
      def initialize
        @ptr = nil
        return unless block_given?
        begin
          yield(self)
        ensure
          unload
        end
      end

      # Unload the held module, if any. Calling it again is a no-op.
      def unload
        return unless @ptr
        Driver.cuModuleUnload(@ptr)
        @ptr = nil
      end

      # Load a module from file +fname+, unloading any module already held.
      def load_file(fname)
        replace_handle(Driver.cuModuleLoad(fname))
      end

      # Load a module from an in-memory image (e.g. a cubin), unloading any
      # module already held.
      def load(cubin)
        replace_handle(Driver.cuModuleLoadData(cubin))
      end

      # @return the result of Driver.cuModuleGetGlobal for global +name+
      def get_global_var(name)
        Driver.cuModuleGetGlobal(@ptr, name)
      end

      # TODO: not implemented yet; currently always returns nil.
      def get_function(name)
        # Function(name)
      end

      private

      # Install +new_ptr+ as the held handle after unloading the previous one,
      # so reloading does not leak it. The new module is loaded before this
      # is called, so a failed load leaves the old module in place.
      def replace_handle(new_ptr)
        unload
        @ptr = new_ptr
      end
    end
  end
end
|
|
40
|
+
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
require_relative '../cuda'
|
|
2
|
+
require_relative 'compile_error'
|
|
3
|
+
|
|
4
|
+
module Cumo
  module CUDA
    # Wrapper around an NVRTC program handle.
    class NVRTCProgram
      # @param src [String] CUDA source (should be UTF-8)
      # @param name [String] program name (should be UTF-8)
      # @param headers [Array<String>] passed through to NVRTC.nvrtcCreateProgram
      # @param include_names [Array<String>] passed through to NVRTC.nvrtcCreateProgram
      def initialize(src, name: "default_program", headers: [], include_names: [])
        @ptr = nil
        @src = src
        @name = name
        @ptr = NVRTC.nvrtcCreateProgram(src, name, headers, include_names)
      end

      # Destroy the program. The handle is cleared, so a second call is a
      # no-op instead of a double destroy.
      def destroy
        return unless @ptr
        NVRTC.nvrtcDestroyProgram(@ptr)
        @ptr = nil
      end

      # Compile the program.
      #
      # @param options [Array<String>] NVRTC options
      # @return [String] the result of NVRTC.nvrtcGetPTX
      # @raise [CompileError] carrying the program log when NVRTC reports an NVRTCError
      def compile(options: [])
        NVRTC.nvrtcCompileProgram(@ptr, options)
        NVRTC.nvrtcGetPTX(@ptr)
      rescue NVRTCError
        log = NVRTC.nvrtcGetProgramLog(@ptr)
        raise CompileError.new(log, @src, @name, options)
      end
    end
  end
end
|
data/lib/cumo/linalg.rb
ADDED
data/lib/cumo/narray.rb
ADDED
|
@@ -0,0 +1,1278 @@
|
|
|
1
|
+
module Cumo
|
|
2
|
+
class NArray
|
|
3
|
+
|
|
4
|
+
# Allocate a new array of the same class and shape as self, with contents
# left uninitialized.
# @return [Cumo::NArray]
def new_narray
  klass = self.class
  klass.new(*shape)
end
|
|
8
|
+
|
|
9
|
+
# Build a zero-filled array of the same class and shape as self.
# @return [Cumo::NArray]
def new_zeros
  klass = self.class
  klass.zeros(*shape)
end
|
|
13
|
+
|
|
14
|
+
# Build an array of ones with the same class and shape as self.
# @return [Cumo::NArray]
def new_ones
  klass = self.class
  klass.ones(*shape)
end
|
|
18
|
+
|
|
19
|
+
# Build an array with the same class and shape as self, with every element
# set to +value+.
# @param value [Numeric] fill value
# @return the result of #fill on the new array
def new_fill(value)
  blank = self.class.new(*shape)
  blank.fill(value)
end
|
|
23
|
+
|
|
24
|
+
# Convert angles from radians to degrees.
# @return self multiplied by 180/PI
def rad2deg
  degrees_per_radian = 180 / Math::PI
  self * degrees_per_radian
end
|
|
28
|
+
|
|
29
|
+
# Convert angles from degrees to radians.
# @return self multiplied by PI/180
def deg2rad
  radians_per_degree = Math::PI / 180
  self * radians_per_degree
end
|
|
33
|
+
|
|
34
|
+
# Flip each row in the left/right direction by reversing axis 1.
# Same as `a[true, (-1..0).step(-1), ...]`.
def fliplr
  column_axis = 1
  reverse(column_axis)
end
|
|
39
|
+
|
|
40
|
+
# Flip each column in the up/down direction by reversing axis 0.
# Same as `a[(-1..0).step(-1), ...]`.
def flipud
  row_axis = 0
  reverse(row_axis)
end
|
|
45
|
+
|
|
46
|
+
# Multi-dimensional array indexing.
# Same as [] for one-dimensional NArray.
# Similar to numpy's tuple indexing, i.e., `a[[1,2,..],[3,4,..]]`
# (This method will be rewritten in C)
#
# Each argument is a one-dimensional list of indices for one axis, and all
# lists must have the same length. Negative indices count from the end of
# their axis. The index lists passed in are not modified.
#
# @param indices one index list per dimension (anything Int64.cast accepts)
# @return [Cumo::NArray] one-dimensional view of self.
# @raise [DimensionError] if the number of lists differs from ndim, or a list is not 1-D
# @raise [IndexError] if an index is outside its axis after negative wrapping
# @raise [ShapeError] if the index lists differ in length
# @example
#   p x = Cumo::DFloat.new(3,3,3).seq
#   # Cumo::DFloat#shape=[3,3,3]
#   # [[[0, 1, 2],
#   #   [3, 4, 5],
#   #   [6, 7, 8]],
#   #  [[9, 10, 11],
#   #   [12, 13, 14],
#   #   [15, 16, 17]],
#   #  [[18, 19, 20],
#   #   [21, 22, 23],
#   #   [24, 25, 26]]]
#
#   p x.at([0,1,2],[0,1,2],[-1,-2,-3])
#   # Cumo::DFloat(view)#shape=[3]
#   # [2, 13, 24]
def at(*indices)
  if indices.size != ndim
    raise DimensionError, "argument length does not match dimension size"
  end
  idx = nil
  stride = 1
  (indices.size-1).downto(0) do |i|
    ix = Int64.cast(indices[i])
    if ix.ndim != 1
      raise DimensionError, "index array is not one-dimensional"
    end
    # Work on a copy. The cast may return its argument unchanged (as
    # NArray.cast does for NArray input), and the in-place wrap below would
    # otherwise modify the caller's index array.
    ix = ix.dup
    ix[ix < 0] += shape[i]
    # After wrapping, an index is invalid if it is still negative OR not
    # smaller than the axis length.
    if ((ix < 0) | (ix >= shape[i])).any?
      raise IndexError, "index array is out of range"
    end
    if idx
      if idx.size != ix.size
        raise ShapeError, "index array sizes mismatch"
      end
      # Accumulate row-major flat offsets, innermost axis first.
      idx += ix * stride
      stride *= shape[i]
    else
      idx = ix
      stride = shape[i]
    end
  end
  self[idx]
end
|
|
95
|
+
|
|
96
|
+
# Rotate by 90 degrees, k times, in the plane spanned by the two axes.
#
# @param k [Integer] number of quarter turns (taken modulo 4)
# @param axes [Array<Integer>] the two axes defining the rotation plane
# @return [Cumo::NArray] a view of self
# @example
#   p a = Cumo::Int32.new(2,2).seq
#   # Cumo::Int32#shape=[2,2]
#   # [[0, 1],
#   #  [2, 3]]
#
#   p a.rot90
#   # Cumo::Int32(view)#shape=[2,2]
#   # [[1, 3],
#   #  [0, 2]]
#
#   p a.rot90(2)
#   # Cumo::Int32(view)#shape=[2,2]
#   # [[3, 2],
#   #  [1, 0]]
#
#   p a.rot90(3)
#   # Cumo::Int32(view)#shape=[2,2]
#   # [[2, 0],
#   #  [3, 1]]
def rot90(k=1,axes=[0,1])
  quarter_turns = k % 4
  case quarter_turns
  when 0 then view
  when 1 then swapaxes(*axes).reverse(axes[0])
  when 2 then reverse(*axes)
  when 3 then swapaxes(*axes).reverse(axes[1])
  end
end
|
|
129
|
+
|
|
130
|
+
# Convert a single-element array to an Integer.
#
# NOTE(review): this relies on extract_cpu returning a Ruby scalar. If
# extract_cpu returns self for a size-1 array with ndim > 0, this would
# recurse into NArray#to_i again; confirm against the extract_cpu template.
#
# @return [Integer]
# @raise [TypeError] when the array does not hold exactly one element
def to_i
  raise TypeError, "can't convert #{self.class} into Integer" unless size == 1
  extract_cpu.to_i
end
|
|
138
|
+
|
|
139
|
+
# Convert a single-element array to a Float.
#
# NOTE(review): relies on extract_cpu returning a Ruby scalar for size-1
# arrays; confirm this for arrays with ndim > 0.
#
# @return [Float]
# @raise [TypeError] when the array does not hold exactly one element
def to_f
  raise TypeError, "can't convert #{self.class} into Float" unless size == 1
  extract_cpu.to_f
end
|
|
147
|
+
|
|
148
|
+
# Convert a single-element array to a Complex.
#
# NOTE(review): relies on extract_cpu returning a Ruby scalar for size-1
# arrays; confirm this for arrays with ndim > 0.
#
# @return [Complex]
# @raise [TypeError] when the array does not hold exactly one element
def to_c
  raise TypeError, "can't convert #{self.class} into Complex" unless size == 1
  Complex(extract_cpu)
end
|
|
156
|
+
|
|
157
|
+
# Convert the argument to an NArray unless it already is one.
#
# @param a [Object] an NArray, or any value NArray.array_type can classify
# @return [Cumo::NArray] +a+ itself, or the result of casting it with the
#   class chosen by NArray.array_type(a)
def self.cast(a)
  return a if a.kind_of?(NArray)
  NArray.array_type(a).cast(a)
end
|
|
161
|
+
|
|
162
|
+
# Convert the argument to an array with at least one dimension.
#
# A 0-dimensional NArray gets a new axis via a[:new]. Other NArrays are
# returned unchanged. Numerics and Ranges go through self[a], and anything
# else through cast.
def self.asarray(a)
  if a.is_a?(NArray)
    return a.ndim == 0 ? a[:new] : a
  end
  return self[a] if a.is_a?(Numeric) || a.is_a?(Range)
  cast(a)
end
|
|
172
|
+
|
|
173
|
+
# Parse a matrix written like MATLAB/Octave into an array of this class.
#
# Blocks separated by blank lines (+split3d+) become the outer dimension.
# Rows are split by line ends or ';' (+split2d+), and items within a row by
# whitespace (+split1d+). Empty blocks, rows and items are dropped. With a
# single block the result is 2-D; otherwise 3-D.
#
# SECURITY: every item is evaluated with Kernel#eval, so any Ruby code
# embedded in +str+ is executed. Never call this on untrusted input.
#
# @example
#   a = Cumo::DFloat.parse %[
#    2 -3 5
#    4 9 7
#    2 -1 6
#   ]
#   # => Cumo::DFloat#shape=[3,3]
#   # [[2, -3, 5],
#   #  [4, 9, 7],
#   #  [2, -1, 6]]
def self.parse(str, split1d:/\s+/, split2d:/;?$|;/,
               split3d:/\s*\n(\s*\n)+/m)
  blocks = str.split(split3d).map do |block|
    rows = block.split(split2d).map do |line|
      items = line.strip.split(split1d).reject(&:empty?)
      items.map { |item| eval(item.strip) }
    end
    rows.reject(&:empty?)
  end
  blocks = blocks.reject(&:empty?)
  self.cast(blocks.size == 1 ? blocks[0] : blocks)
end
|
|
210
|
+
|
|
211
|
+
# Append values to the end of an narray.
#
# Without +axis+ both arrays are flattened and joined into a new 1-D array.
# With +axis+ they are concatenated along that axis, and both must have the
# same number of dimensions.
#
# @param other values to append (cast to self's class)
# @param axis [Integer, nil] concatenation axis
# @return [Cumo::NArray]
# @raise [DimensionError] if +axis+ is given and ndim differs
# @example
#   a = Cumo::DFloat[1, 2, 3]
#   p a.append([[4, 5, 6], [7, 8, 9]])
#   # Cumo::DFloat#shape=[9]
#   # [1, 2, 3, 4, 5, 6, 7, 8, 9]
#
#   a = Cumo::DFloat[[1, 2, 3]]
#   p a.append([[4, 5, 6], [7, 8, 9]],axis:0)
#   # Cumo::DFloat#shape=[3,3]
#   # [[1, 2, 3],
#   #  [4, 5, 6],
#   #  [7, 8, 9]]
#
#   a = Cumo::DFloat[[1, 2, 3], [4, 5, 6]]
#   p a.append([7, 8, 9], axis:0)
#   # in `append': dimension mismatch (Cumo::NArray::DimensionError)
def append(other,axis:nil)
  other = self.class.cast(other)
  unless axis
    merged = self.class.zeros(size + other.size)
    merged[0...size] = self[true]
    merged[size..-1] = other[true]
    return merged
  end
  raise DimensionError, "dimension mismatch" if ndim != other.ndim
  concatenate(other, axis: axis)
end
|
|
243
|
+
|
|
244
|
+
# Return a new array with sub-arrays along an axis deleted.
|
|
245
|
+
# If axis is not given, obj is applied to the flattened array.
|
|
246
|
+
|
|
247
|
+
# @example
|
|
248
|
+
# a = Cumo::DFloat[[1,2,3,4], [5,6,7,8], [9,10,11,12]]
|
|
249
|
+
# p a.delete(1,0)
|
|
250
|
+
# # Cumo::DFloat(view)#shape=[2,4]
|
|
251
|
+
# # [[1, 2, 3, 4],
|
|
252
|
+
# # [9, 10, 11, 12]]
|
|
253
|
+
#
|
|
254
|
+
# p a.delete((0..-1).step(2),1)
|
|
255
|
+
# # Cumo::DFloat(view)#shape=[3,2]
|
|
256
|
+
# # [[2, 4],
|
|
257
|
+
# # [6, 8],
|
|
258
|
+
# # [10, 12]]
|
|
259
|
+
#
|
|
260
|
+
# p a.delete([1,3,5])
|
|
261
|
+
# # Cumo::DFloat(view)#shape=[9]
|
|
262
|
+
# # [1, 3, 5, 7, 8, 9, 10, 11, 12]
|
|
263
|
+
|
|
264
|
+
def delete(indice,axis=nil)
  if axis
    # Normalize and validate the axis like the other axis-taking methods do.
    # Without this, an out-of-range axis made shape[axis] nil, and the
    # failure surfaced inside Bit.ones instead of as an axis error.
    axis = check_axis(axis)
    # Keep-mask along `axis`: every position except the deleted ones.
    bit = Bit.ones(shape[axis])
    bit[indice] = 0
    idx = [true]*ndim
    idx[axis] = bit.where
    return self[*idx].copy
  else
    # No axis: the deletion applies to the flattened array.
    bit = Bit.ones(size)
    bit[indice] = 0
    return self[bit.where].copy
  end
end
|
|
277
|
+
|
|
278
|
+
# Insert values along the axis before the indices.
|
|
279
|
+
# @example
|
|
280
|
+
# a = Cumo::Int32[[1, 1], [2, 2], [3, 3]]
|
|
282
|
+
#
|
|
283
|
+
# p a.insert(1,5)
|
|
284
|
+
# # Cumo::Int32#shape=[7]
|
|
285
|
+
# # [1, 5, 1, 2, 2, 3, 3]
|
|
286
|
+
#
|
|
287
|
+
# p a.insert(1, 5, axis:1)
|
|
288
|
+
# # Cumo::Int32#shape=[3,3]
|
|
289
|
+
# # [[1, 5, 1],
|
|
290
|
+
# # [2, 5, 2],
|
|
291
|
+
# # [3, 5, 3]]
|
|
292
|
+
#
|
|
293
|
+
# p a.insert([1], [[11],[12],[13]], axis:1)
|
|
294
|
+
# # Cumo::Int32#shape=[3,3]
|
|
295
|
+
# # [[1, 11, 1],
|
|
296
|
+
# # [2, 12, 2],
|
|
297
|
+
# # [3, 13, 3]]
|
|
298
|
+
#
|
|
299
|
+
# p a.insert(1, [11, 12, 13], axis:1)
|
|
300
|
+
# # Cumo::Int32#shape=[3,3]
|
|
301
|
+
# # [[1, 11, 1],
|
|
302
|
+
# # [2, 12, 2],
|
|
303
|
+
# # [3, 13, 3]]
|
|
304
|
+
#
|
|
305
|
+
# p a.insert([1], [11, 12, 13], axis:1)
|
|
306
|
+
# # Cumo::Int32#shape=[3,5]
|
|
307
|
+
# # [[1, 11, 12, 13, 1],
|
|
308
|
+
# # [2, 11, 12, 13, 2],
|
|
309
|
+
# # [3, 11, 12, 13, 3]]
|
|
310
|
+
#
|
|
311
|
+
# p b = a.flatten
|
|
312
|
+
# # Cumo::Int32(view)#shape=[6]
|
|
313
|
+
# # [1, 1, 2, 2, 3, 3]
|
|
314
|
+
#
|
|
315
|
+
# p b.insert(2,[15,16])
|
|
316
|
+
# # Cumo::Int32#shape=[8]
|
|
317
|
+
# # [1, 1, 15, 16, 2, 2, 3, 3]
|
|
318
|
+
#
|
|
319
|
+
# p b.insert([2,2],[15,16])
|
|
320
|
+
# # Cumo::Int32#shape=[8]
|
|
321
|
+
# # [1, 1, 15, 16, 2, 2, 3, 3]
|
|
322
|
+
#
|
|
323
|
+
# p b.insert([2,1],[15,16])
|
|
324
|
+
# # Cumo::Int32#shape=[8]
|
|
325
|
+
# # [1, 16, 1, 15, 2, 2, 3, 3]
|
|
326
|
+
#
|
|
327
|
+
# p b.insert([2,0,1],[15,16,17])
|
|
328
|
+
# # Cumo::Int32#shape=[9]
|
|
329
|
+
# # [16, 1, 17, 1, 15, 2, 2, 3, 3]
|
|
330
|
+
#
|
|
331
|
+
# p b.insert(2..3, [15, 16])
|
|
332
|
+
# # Cumo::Int32#shape=[8]
|
|
333
|
+
# # [1, 1, 15, 2, 16, 2, 3, 3]
|
|
334
|
+
#
|
|
335
|
+
# p b.insert(2, [7.13, 0.5])
|
|
336
|
+
# # Cumo::Int32#shape=[8]
|
|
337
|
+
# # [1, 1, 7, 0, 2, 2, 3, 3]
|
|
338
|
+
#
|
|
339
|
+
# p x = Cumo::DFloat.new(2,4).seq
|
|
340
|
+
# # Cumo::DFloat#shape=[2,4]
|
|
341
|
+
# # [[0, 1, 2, 3],
|
|
342
|
+
# # [4, 5, 6, 7]]
|
|
343
|
+
#
|
|
344
|
+
# p x.insert([1,3],999,axis:1)
|
|
345
|
+
# # Cumo::DFloat#shape=[2,6]
|
|
346
|
+
# # [[0, 999, 1, 2, 999, 3],
|
|
347
|
+
# # [4, 999, 5, 6, 999, 7]]
|
|
348
|
+
|
|
349
|
+
def insert(indice,values,axis:nil)
  if axis
    values = self.class.asarray(values)
    nd = values.ndim
    # Prepend :new entries so values gets as many axes as self.
    midx = [:new]*(ndim-nd) + [true]*nd
    case indice
    when Numeric
      # Scalar index: relocate the added length-1 axis to position `axis`
      # (e.g. 1-D values become a column when axis=1, see the examples).
      midx[-nd-1] = true
      midx[axis] = :new
    end
    values = values[*midx]
  else
    # Without axis the insertion is done on the flattened array.
    values = self.class.asarray(values).flatten
  end
  idx = Int64.asarray(indice)
  nidx = idx.size
  if nidx == 1
    # Single index: insert all of values as a consecutive run starting there.
    nidx = values.shape[axis||0]
    idx = idx + Int64.new(nidx).seq
  else
    # Several indices: the i-th smallest index is shifted by i (the number of
    # insertions before it), giving its position in the enlarged array.
    sidx = idx.sort_index
    idx[sidx] += Int64.new(nidx).seq
  end
  if axis
    # bit is 1 where self's elements go and 0 at the inserted positions.
    bit = Bit.ones(shape[axis]+nidx)
    bit[idx] = 0
    # NOTE(review): relies on #shape returning a fresh Array, since it is mutated here.
    new_shape = shape
    new_shape[axis] += nidx
    a = self.class.zeros(new_shape)
    mdidx = [true]*ndim
    mdidx[axis] = bit.where
    a[*mdidx] = self
    mdidx[axis] = idx
    a[*mdidx] = values
  else
    bit = Bit.ones(size+nidx)
    bit[idx] = 0
    a = self.class.zeros(size+nidx)
    a[bit.where] = self.flatten
    a[idx] = values
  end
  return a
end
|
|
392
|
+
|
|
393
|
+
class << self
  # Join a sequence of arrays along an existing axis.
  #
  # When called on NArray itself, the result class is NArray.array_type(arrays);
  # otherwise it is the receiver class. Numeric and Array elements are cast to
  # that class. An input with fewer dimensions than the largest one is treated
  # as if padded with leading length-1 axes.
  # @param arrays [Array] elements to join (NArray, Array or Numeric)
  # @param axis [Integer] axis to join along; negative values count from the end
  # @return [Cumo::NArray] newly allocated result
  # @raise [TypeError] if an element is not an NArray, Array or Numeric
  # @raise [ArgumentError] if axis is out of range
  # @raise [Cumo::NArray::ShapeError] if the non-joined dimensions differ
  # @example
  #   p a = Cumo::DFloat[[1, 2], [3, 4]]
  #   # Cumo::DFloat#shape=[2,2]
  #   # [[1, 2],
  #   #  [3, 4]]
  #
  #   p b = Cumo::DFloat[[5, 6]]
  #   # Cumo::DFloat#shape=[1,2]
  #   # [[5, 6]]
  #
  #   p Cumo::NArray.concatenate([a,b],axis:0)
  #   # Cumo::DFloat#shape=[3,2]
  #   # [[1, 2],
  #   #  [3, 4],
  #   #  [5, 6]]
  #
  #   p Cumo::NArray.concatenate([a,b.transpose], axis:1)
  #   # Cumo::DFloat#shape=[2,3]
  #   # [[1, 2, 5],
  #   #  [3, 4, 6]]
  def concatenate(arrays,axis:0)
    klass = (self==NArray) ? NArray.array_type(arrays) : self
    arys = arrays.map do |a|
      case a
      when NArray then a
      when Numeric then klass[a]
      when Array then klass.cast(a)
      else raise TypeError,"not Cumo::NArray: #{a.inspect[0..48]}"
      end
    end
    nd = [0, *arys.map(&:ndim)].max
    axis += nd if axis < 0
    raise ArgumentError,"axis is out of range" if axis < 0 || axis >= nd

    common = nil
    total = 0
    arys.each do |a|
      dims = a.shape
      dims = [1]*(nd-dims.size) + dims # pad to nd axes
      total += dims.delete_at(axis)
      if common.nil?
        common = dims
      elsif common != dims
        raise ShapeError,"shape mismatch"
      end
    end
    common.insert(axis,total)

    result = klass.zeros(*common)
    refs = [true] * nd
    offset = 0
    arys.each do |a|
      len = a.shape[axis-nd] || 1
      refs[axis] = offset...(offset+len)
      result[*refs] = a
      offset += len
    end
    result
  end

  # Stack arrays vertically (row wise): each input is promoted to at least
  # 2-D (a 1-D array becomes a single row), then joined along axis 0.
  # @example
  #   a = Cumo::Int32[1,2,3]
  #   b = Cumo::Int32[2,3,4]
  #   p Cumo::NArray.vstack([a,b])
  #   # Cumo::Int32#shape=[2,3]
  #   # [[1, 2, 3],
  #   #  [2, 3, 4]]
  #
  #   a = Cumo::Int32[[1],[2],[3]]
  #   b = Cumo::Int32[[2],[3],[4]]
  #   p Cumo::NArray.vstack([a,b])
  #   # Cumo::Int32#shape=[6,1]
  #   # [[1],
  #   #  [2],
  #   #  [3],
  #   #  [2],
  #   #  [3],
  #   #  [4]]
  def vstack(arrays)
    concatenate(arrays.map { |a| _atleast_2d(cast(a)) }, axis:0)
  end

  # Stack arrays horizontally (column wise): joined along axis 1 when any
  # input has two or more dimensions, otherwise along axis 0.
  # @example
  #   a = Cumo::Int32[1,2,3]
  #   b = Cumo::Int32[2,3,4]
  #   p Cumo::NArray.hstack([a,b])
  #   # Cumo::Int32#shape=[6]
  #   # [1, 2, 3, 2, 3, 4]
  #
  #   a = Cumo::Int32[[1],[2],[3]]
  #   b = Cumo::Int32[[2],[3],[4]]
  #   p Cumo::NArray.hstack([a,b])
  #   # Cumo::Int32#shape=[3,2]
  #   # [[1, 2],
  #   #  [2, 3],
  #   #  [3, 4]]
  def hstack(arrays)
    klass = (self==NArray) ? NArray.array_type(arrays) : self
    arys = arrays.map { |a| klass.cast(a) }
    nd = [0, *arys.map(&:ndim)].max
    concatenate(arys, axis: (nd >= 2) ? 1 : 0)
  end

  # Stack arrays depth wise (along the third axis): each input is promoted
  # to at least 3-D, then joined along axis 2.
  # @example
  #   a = Cumo::Int32[1,2,3]
  #   b = Cumo::Int32[2,3,4]
  #   p Cumo::NArray.dstack([a,b])
  #   # Cumo::Int32#shape=[1,3,2]
  #   # [[[1, 2],
  #   #   [2, 3],
  #   #   [3, 4]]]
  #
  #   a = Cumo::Int32[[1],[2],[3]]
  #   b = Cumo::Int32[[2],[3],[4]]
  #   p Cumo::NArray.dstack([a,b])
  #   # Cumo::Int32#shape=[3,1,2]
  #   # [[[1, 2]],
  #   #  [[2, 3]],
  #   #  [[3, 4]]]
  def dstack(arrays)
    concatenate(arrays.map { |a| _atleast_3d(cast(a)) }, axis:2)
  end

  # Stack 1-d arrays into the columns of a 2-d array. A 0-D input becomes
  # [1,1], a 1-D input becomes a column [n,1], and higher-dimensional inputs
  # are used as is; everything is joined along axis 1.
  # @example
  #   x = Cumo::Int32[1,2,3]
  #   y = Cumo::Int32[2,3,4]
  #   p Cumo::NArray.column_stack([x,y])
  #   # Cumo::Int32#shape=[3,2]
  #   # [[1, 2],
  #   #  [2, 3],
  #   #  [3, 4]]
  def column_stack(arrays)
    cols = arrays.map do |a|
      v = cast(a)
      next v[:new,:new] if v.ndim.zero?
      next v[true,:new] if v.ndim == 1
      v
    end
    concatenate(cols, axis:1)
  end

  private

  # Return an narray with at least two dimensions: 0-D becomes [1,1],
  # 1-D (length n) becomes a row [1,n], anything else is returned as is.
  def _atleast_2d(a)
    return a[:new,:new] if a.ndim.zero?
    return a[:new,true] if a.ndim == 1
    a
  end

  # Return an narray with at least three dimensions: 0-D becomes [1,1,1],
  # 1-D (length n) becomes [1,n,1], 2-D [m,n] becomes [m,n,1], anything
  # else is returned as is.
  def _atleast_3d(a)
    return a[:new,:new,:new] if a.ndim.zero?
    return a[:new,true,:new] if a.ndim == 1
    return a[true,true,:new] if a.ndim == 2
    a
  end

end # class << self
|
|
592
|
+
|
|
593
|
+
# @example
|
|
594
|
+
# p a = Cumo::DFloat[[1, 2], [3, 4]]
|
|
595
|
+
# # Cumo::DFloat#shape=[2,2]
|
|
596
|
+
# # [[1, 2],
|
|
597
|
+
# # [3, 4]]
|
|
598
|
+
#
|
|
599
|
+
# p b = Cumo::DFloat[[5, 6]]
|
|
600
|
+
# # Cumo::DFloat#shape=[1,2]
|
|
601
|
+
# # [[5, 6]]
|
|
602
|
+
#
|
|
603
|
+
# p a.concatenate(b,axis:0)
|
|
604
|
+
# # Cumo::DFloat#shape=[3,2]
|
|
605
|
+
# # [[1, 2],
|
|
606
|
+
# # [3, 4],
|
|
607
|
+
# # [5, 6]]
|
|
608
|
+
#
|
|
609
|
+
# p a.concatenate(b.transpose, axis:1)
|
|
610
|
+
# # Cumo::DFloat#shape=[2,3]
|
|
611
|
+
# # [[1, 2, 5],
|
|
612
|
+
# # [3, 4, 6]]
|
|
613
|
+
|
|
614
|
+
def concatenate(*arrays,axis:0)
  axis = check_axis(axis)
  rank = ndim
  # Shape of self without the joined axis; every other operand must match it.
  rest_shape = shape
  own_len = rest_shape.delete_at(axis)
  total = own_len
  parts = arrays.map do |a|
    a = case a
        when NArray then a
        when Numeric then self.class.new(1).store(a)
        when Array then self.class.cast(a)
        else raise TypeError,"not Cumo::NArray: #{a.inspect[0..48]}"
        end
    raise ShapeError,"dimension mismatch" if a.ndim > rank
    dims = a.shape
    # A missing (leading) axis counts as length 1.
    total += dims.delete_at(axis-rank) || 1
    raise ShapeError,"shape mismatch" if dims != rest_shape
    a
  end
  rest_shape.insert(axis,total)
  result = self.class.zeros(*rest_shape)

  # Copy self first, then each operand into consecutive slices along axis.
  refs = [true] * rank
  refs[axis] = 0...own_len
  result[*refs] = self
  offset = own_len
  parts.each do |a|
    len = a.shape[axis-rank] || 1
    refs[axis] = offset...(offset+len)
    result[*refs] = a
    offset += len
  end
  result
end
|
|
654
|
+
|
|
655
|
+
# @example
|
|
656
|
+
# p x = Cumo::DFloat.new(9).seq
|
|
657
|
+
# # Cumo::DFloat#shape=[9]
|
|
658
|
+
# # [0, 1, 2, 3, 4, 5, 6, 7, 8]
|
|
659
|
+
#
|
|
660
|
+
# pp x.split(3)
|
|
661
|
+
# # [Cumo::DFloat(view)#shape=[3]
|
|
662
|
+
# # [0, 1, 2],
|
|
663
|
+
# # Cumo::DFloat(view)#shape=[3]
|
|
664
|
+
# # [3, 4, 5],
|
|
665
|
+
# # Cumo::DFloat(view)#shape=[3]
|
|
666
|
+
# # [6, 7, 8]]
|
|
667
|
+
#
|
|
668
|
+
# p x = Cumo::DFloat.new(8).seq
|
|
669
|
+
# # Cumo::DFloat#shape=[8]
|
|
670
|
+
# # [0, 1, 2, 3, 4, 5, 6, 7]
|
|
671
|
+
#
|
|
672
|
+
# pp x.split([3, 5, 6, 10])
|
|
673
|
+
# # [Cumo::DFloat(view)#shape=[3]
|
|
674
|
+
# # [0, 1, 2],
|
|
675
|
+
# # Cumo::DFloat(view)#shape=[2]
|
|
676
|
+
# # [3, 4],
|
|
677
|
+
# # Cumo::DFloat(view)#shape=[1]
|
|
678
|
+
# # [5],
|
|
679
|
+
# # Cumo::DFloat(view)#shape=[2]
|
|
680
|
+
# # [6, 7],
|
|
681
|
+
# # Cumo::DFloat(view)#shape=[0][]]
|
|
682
|
+
|
|
683
|
+
def split(indices_or_sections, axis:0)
  axis = check_axis(axis)
  size_axis = shape[axis]
  refs = [true]*ndim
  case indices_or_sections
  when Integer
    # N roughly equal pieces: the first `extra` pieces get one more element.
    base, extra = size_axis.divmod(indices_or_sections)
    lengths = extra.times.map { base + 1 } +
              (indices_or_sections - extra).times.map { base }
    start = 0
    lengths.map do |len|
      refs[axis] = start...(start + len)
      start += len
      self[*refs]
    end
  when NArray
    split(indices_or_sections.to_a, axis:axis)
  when Array
    # Split points are clamped to the axis length; pieces that start past
    # the end come back empty.
    start = 0
    (indices_or_sections + [size_axis]).map do |stop|
      stop = size_axis if size_axis < stop
      refs[axis] = (start < size_axis) ? (start...stop) : (-1...-1)
      start = stop
      self[*refs]
    end
  else
    raise TypeError,"argument must be Integer or Array"
  end
end
|
|
718
|
+
|
|
719
|
+
# Split an array into sub-arrays along axis 0 (row wise).
# Same as `split(indices_or_sections, axis:0)`; see #split for the meaning
# of an Integer versus an Array argument.
def vsplit(indices_or_sections)
  split(indices_or_sections, axis:0)
end

# Split an array into sub-arrays along axis 1 (column wise).
# Same as `split(indices_or_sections, axis:1)`.
# @example
#   p x = Cumo::DFloat.new(4,4).seq
#   # Cumo::DFloat#shape=[4,4]
#   # [[0, 1, 2, 3],
#   #  [4, 5, 6, 7],
#   #  [8, 9, 10, 11],
#   #  [12, 13, 14, 15]]
#
#   pp x.hsplit(2)
#   # [Cumo::DFloat(view)#shape=[4,2]
#   # [[0, 1],
#   #  [4, 5],
#   #  [8, 9],
#   #  [12, 13]],
#   #  Cumo::DFloat(view)#shape=[4,2]
#   # [[2, 3],
#   #  [6, 7],
#   #  [10, 11],
#   #  [14, 15]]]
#
#   pp x.hsplit([3, 6])
#   # [Cumo::DFloat(view)#shape=[4,3]
#   # [[0, 1, 2],
#   #  [4, 5, 6],
#   #  [8, 9, 10],
#   #  [12, 13, 14]],
#   #  Cumo::DFloat(view)#shape=[4,1]
#   # [[3],
#   #  [7],
#   #  [11],
#   #  [15]],
#   #  Cumo::DFloat(view)#shape=[4,0][]]
def hsplit(indices_or_sections)
  split(indices_or_sections, axis:1)
end

# Split an array into sub-arrays along axis 2 (depth wise).
# Same as `split(indices_or_sections, axis:2)`.
def dsplit(indices_or_sections)
  split(indices_or_sections, axis:2)
end
|
|
763
|
+
|
|
764
|
+
# @example
|
|
765
|
+
# p a = Cumo::NArray[0,1,2]
|
|
766
|
+
# # Cumo::Int32#shape=[3]
|
|
767
|
+
# # [0, 1, 2]
|
|
768
|
+
#
|
|
769
|
+
# p a.tile(2)
|
|
770
|
+
# # Cumo::Int32#shape=[6]
|
|
771
|
+
# # [0, 1, 2, 0, 1, 2]
|
|
772
|
+
#
|
|
773
|
+
# p a.tile(2,2)
|
|
774
|
+
# # Cumo::Int32#shape=[2,6]
|
|
775
|
+
# # [[0, 1, 2, 0, 1, 2],
|
|
776
|
+
# # [0, 1, 2, 0, 1, 2]]
|
|
777
|
+
#
|
|
778
|
+
# p a.tile(2,1,2)
|
|
779
|
+
# # Cumo::Int32#shape=[2,1,6]
|
|
780
|
+
# # [[[0, 1, 2, 0, 1, 2]],
|
|
781
|
+
# # [[0, 1, 2, 0, 1, 2]]]
|
|
782
|
+
#
|
|
783
|
+
# p b = Cumo::NArray[[1, 2], [3, 4]]
|
|
784
|
+
# # Cumo::Int32#shape=[2,2]
|
|
785
|
+
# # [[1, 2],
|
|
786
|
+
# # [3, 4]]
|
|
787
|
+
#
|
|
788
|
+
# p b.tile(2)
|
|
789
|
+
# # Cumo::Int32#shape=[2,4]
|
|
790
|
+
# # [[1, 2, 1, 2],
|
|
791
|
+
# # [3, 4, 3, 4]]
|
|
792
|
+
#
|
|
793
|
+
# p b.tile(2,1)
|
|
794
|
+
# # Cumo::Int32#shape=[4,2]
|
|
795
|
+
# # [[1, 2],
|
|
796
|
+
# # [3, 4],
|
|
797
|
+
# # [1, 2],
|
|
798
|
+
# # [3, 4]]
|
|
799
|
+
#
|
|
800
|
+
# p c = Cumo::NArray[1,2,3,4]
|
|
801
|
+
# # Cumo::Int32#shape=[4]
|
|
802
|
+
# # [1, 2, 3, 4]
|
|
803
|
+
#
|
|
804
|
+
# p c.tile(4,1)
|
|
805
|
+
# # Cumo::Int32#shape=[4,4]
|
|
806
|
+
# # [[1, 2, 3, 4],
|
|
807
|
+
# # [1, 2, 3, 4],
|
|
808
|
+
# # [1, 2, 3, 4],
|
|
809
|
+
# # [1, 2, 3, 4]]
|
|
810
|
+
|
|
811
|
+
def tile(*arg)
  if arg.any? { |r| !r.kind_of?(Integer) || r < 1 }
    raise ArgumentError,"argument should be positive integer"
  end
  reps = arg.dup
  dims = shape
  # One entry per output axis: [repeat count, source length, index on self].
  plan = []
  lead = ndim - reps.size
  if lead > 0
    # self has more axes than reps: its leading axes are not repeated.
    dims.shift(lead).each { |len| plan << [1, len, true] }
  elsif lead < 0
    # More reps than axes: the extra repeats act on new length-1 axes.
    reps.shift(-lead).each { |cnt| plan << [cnt, 1, :new] }
  end
  reps.zip(dims) { |cnt, len| plan << [cnt, len, true] }

  # Broadcast self into a (count, length) pair per axis, then merge each pair.
  tiled_shape = plan.flat_map { |cnt, len, _| [cnt, len] }
  src_index   = plan.flat_map { |_, _, sel| [:new, sel] }
  out_shape   = plan.map { |cnt, len, _| cnt * len }
  self.class.new(*tiled_shape).store(self[*src_index]).reshape(*out_shape)
end
|
|
846
|
+
|
|
847
|
+
# @example
|
|
848
|
+
# p Cumo::NArray[3].repeat(4)
|
|
849
|
+
# # Cumo::Int32#shape=[4]
|
|
850
|
+
# # [3, 3, 3, 3]
|
|
851
|
+
#
|
|
852
|
+
# p x = Cumo::NArray[[1,2],[3,4]]
|
|
853
|
+
# # Cumo::Int32#shape=[2,2]
|
|
854
|
+
# # [[1, 2],
|
|
855
|
+
# # [3, 4]]
|
|
856
|
+
#
|
|
857
|
+
# p x.repeat(2)
|
|
858
|
+
# # Cumo::Int32#shape=[8]
|
|
859
|
+
# # [1, 1, 2, 2, 3, 3, 4, 4]
|
|
860
|
+
#
|
|
861
|
+
# p x.repeat(3,axis:1)
|
|
862
|
+
# # Cumo::Int32#shape=[2,6]
|
|
863
|
+
# # [[1, 1, 1, 2, 2, 2],
|
|
864
|
+
# # [3, 3, 3, 4, 4, 4]]
|
|
865
|
+
#
|
|
866
|
+
# p x.repeat([1,2],axis:0)
|
|
867
|
+
# # Cumo::Int32#shape=[3,2]
|
|
868
|
+
# # [[1, 2],
|
|
869
|
+
# # [3, 4],
|
|
870
|
+
# # [3, 4]]
|
|
871
|
+
|
|
872
|
+
def repeat(arg,axis:nil)
  # Without an axis the repetition runs over the flattened array.
  if axis.nil?
    src = self.flatten
    ax = 0
  elsif axis.kind_of?(Integer)
    ax = check_axis(axis)
    src = self
  else
    raise ArgumentError,"invalid axis"
  end

  # picks: the source index along `ax` for every output position.
  picks =
    if arg.kind_of?(Integer)
      raise ArgumentError,"argument should be positive integer" if arg < 1
      (0...src.shape[ax]).flat_map { |i| [i] * arg }
    else
      counts = arg.to_a
      if counts.size != src.shape[ax]
        raise ArgumentError,"repeat size shoud be equal to size along axis"
      end
      if counts.any? { |cnt| !cnt.kind_of?(Integer) || cnt < 0 }
        raise ArgumentError,"argument should be non-negative integer"
      end
      counts.each_with_index.flat_map { |cnt, i| [i] * cnt }
    end

  ref = Array.new(src.ndim, true)
  ref[ax] = picks
  src[*ref].copy
end
|
|
905
|
+
|
|
906
|
+
# Calculate the n-th discrete difference along given axis.
|
|
907
|
+
# @example
|
|
908
|
+
# p x = Cumo::DFloat[1, 2, 4, 7, 0]
|
|
909
|
+
# # Cumo::DFloat#shape=[5]
|
|
910
|
+
# # [1, 2, 4, 7, 0]
|
|
911
|
+
#
|
|
912
|
+
# p x.diff
|
|
913
|
+
# # Cumo::DFloat#shape=[4]
|
|
914
|
+
# # [1, 2, 3, -7]
|
|
915
|
+
#
|
|
916
|
+
# p x.diff(2)
|
|
917
|
+
# # Cumo::DFloat#shape=[3]
|
|
918
|
+
# # [1, 1, -10]
|
|
919
|
+
#
|
|
920
|
+
# p x = Cumo::DFloat[[1, 3, 6, 10], [0, 5, 6, 8]]
|
|
921
|
+
# # Cumo::DFloat#shape=[2,4]
|
|
922
|
+
# # [[1, 3, 6, 10],
|
|
923
|
+
# # [0, 5, 6, 8]]
|
|
924
|
+
#
|
|
925
|
+
# p x.diff
|
|
926
|
+
# # Cumo::DFloat#shape=[2,3]
|
|
927
|
+
# # [[2, 3, 4],
|
|
928
|
+
# # [5, 1, 2]]
|
|
929
|
+
#
|
|
930
|
+
# p x.diff(axis:0)
|
|
931
|
+
# # Cumo::DFloat#shape=[1,4]
|
|
932
|
+
# # [[-1, 2, 0, -2]]
|
|
933
|
+
|
|
934
|
+
def diff(n=1,axis:-1)
  axis = check_axis(axis)
  len = shape[axis]
  if n < 0 || n >= len
    raise ShapeError,"n=#{n} is invalid for shape[#{axis}]=#{len}"
  end

  # Coefficients of (x - 1)**n, lowest order first: coef[i] multiplies the
  # slice that starts at offset i along axis.
  coef = self.class[-1,1]
  2.upto(n) do |order|
    lo = self.class.zeros(order+1)
    lo[0..-2] = coef
    hi = self.class.zeros(order+1)
    hi[1..-1] = coef
    coef = hi - lo
  end

  # Slice of self along `axis` with the remaining axes taken whole.
  window = lambda do |range|
    refs = [true]*ndim
    refs[axis] = range
    self[*refs]
  end

  # Start from the highest-order term (coefficient 1), then accumulate the
  # lower-order terms in place.
  result = window.call(n..-1).dup
  acc = result.inplace
  (n-1).downto(0) do |i|
    acc + window.call(i..(i-n-1)) * coef[i] # inplace addition
  end
  result
end
|
|
959
|
+
|
|
960
|
+
|
|
961
|
+
# Upper triangle of the trailing two axes.
# Returns a new array in which the elements below the k-th diagonal are
# zero; the zeroing is done by #triu! on a dup of self.
def triu(k=0)
  result = dup
  result.triu!(k)
end
|
|
966
|
+
|
|
967
|
+
# Upper triangle of the trailing two axes, in place.
# Sets the elements of self below the k-th diagonal to zero and returns self.
# @raise [NArray::ShapeError] if self has fewer than two dimensions
def triu!(k=0)
  raise NArray::ShapeError, "must be >= 2-dimensional array" if ndim < 2
  # Non-contiguous data cannot be reshaped in place: compute a copy and store it back.
  return store(triu(k)) unless contiguous?
  # Strictly below diagonal k == on and below diagonal k-1.
  below = tril_indices(k-1)
  *batch, rows, cols = shape
  # The indices address the flattened trailing matrix, so temporarily
  # merge the last two axes, clear, and restore the shape.
  reshape!(*batch, rows*cols)
  self[false, below] = 0
  reshape!(*batch, rows, cols)
end
|
|
983
|
+
|
|
984
|
+
# Return the indices for the upper triangle on and above the k-th diagonal
# of the trailing two axes, as computed by NArray.triu_indices.
# @param k [Integer] diagonal offset (0 = main diagonal, >0 above, <0 below)
# @raise [NArray::ShapeError] if self has fewer than two dimensions
def triu_indices(k=0)
  if ndim < 2
    raise NArray::ShapeError, "must be >= 2-dimensional array"
  end
  m,n = shape[-2..-1]
  # Forward k. This used to read `k=0`, an assignment that always discarded
  # the offset; #tril!(k) relies on triu_indices(k+1) being honored.
  NArray.triu_indices(m,n,k)
end
|
|
992
|
+
|
|
993
|
+
# Return the indices of the upper triangle (on and above the k-th diagonal)
# of an m-by-n matrix: the #where positions of the (m,n) mask row + k <= col.
def self.triu_indices(m,n,k=0)
  row = Cumo::Int64.new(m,1).seq + k
  col = Cumo::Int64.new(1,n).seq
  (col >= row).where
end
|
|
999
|
+
|
|
1000
|
+
# Lower triangle of the trailing two axes.
# Returns a new array in which the elements above the k-th diagonal are
# zero; the zeroing is done by #tril! on a dup of self.
def tril(k=0)
  result = dup
  result.tril!(k)
end
|
|
1005
|
+
|
|
1006
|
+
# Lower triangle of the trailing two axes, in place.
# Sets the elements of self above the k-th diagonal to zero and returns self.
# @raise [NArray::ShapeError] if self has fewer than two dimensions
def tril!(k=0)
  raise NArray::ShapeError, "must be >= 2-dimensional array" if ndim < 2
  # Non-contiguous data cannot be reshaped in place: compute a copy and store it back.
  return store(tril(k)) unless contiguous?
  # Strictly above diagonal k == on and above diagonal k+1.
  above = triu_indices(k+1)
  *batch, rows, cols = shape
  # The indices address the flattened trailing matrix, so temporarily
  # merge the last two axes, clear, and restore the shape.
  reshape!(*batch, rows*cols)
  self[false, above] = 0
  reshape!(*batch, rows, cols)
end
|
|
1022
|
+
|
|
1023
|
+
# Return the indices for the lower triangle on and below the k-th diagonal
# of the trailing two axes, as computed by NArray.tril_indices.
# @raise [NArray::ShapeError] if self has fewer than two dimensions
def tril_indices(k=0)
  raise NArray::ShapeError, "must be >= 2-dimensional array" if ndim < 2
  rows, cols = shape.last(2)
  NArray.tril_indices(rows, cols, k)
end
|
|
1031
|
+
|
|
1032
|
+
# Return the indices of the lower triangle (on and below the k-th diagonal)
# of an m-by-n matrix: the #where positions of the (m,n) mask row + k >= col.
def self.tril_indices(m,n,k=0)
  row = Cumo::Int64.new(m,1).seq + k
  col = Cumo::Int64.new(1,n).seq
  (col <= row).where
end
|
|
1038
|
+
|
|
1039
|
+
# Return the indices of the k-th diagonal of the trailing two axes, as
# computed by NArray.diag_indices.
# @raise [NArray::ShapeError] if self has fewer than two dimensions
def diag_indices(k=0)
  raise NArray::ShapeError, "must be >= 2-dimensional array" if ndim < 2
  rows, cols = shape.last(2)
  NArray.diag_indices(rows, cols, k)
end
|
|
1047
|
+
|
|
1048
|
+
# Return the indices of the k-th diagonal of an m-by-n matrix: the #where
# positions of the (m,n) mask row + k == col.
def self.diag_indices(m,n,k=0)
  row = Cumo::Int64.new(m,1).seq + k
  col = Cumo::Int64.new(1,n).seq
  (col.eq row).where
end
|
|
1054
|
+
|
|
1055
|
+
# Return a matrix whose k-th diagonal is filled from self along the last axis.
# For self of shape [..., n] the result has shape [..., n+|k|, n+|k|] and is
# zero elsewhere.
def diag(k=0)
  *lead, len = shape
  dim = len + k.abs
  self.class.zeros(*lead, dim, dim).tap do |mat|
    mat.diagonal(k).store(self)
  end
end
|
|
1063
|
+
|
|
1064
|
+
# Return the sum along diagonals of the array.
#
# For a 2-D array this is the sum of `a[i,i+offset]`. For arrays with more
# dimensions, the diagonal is taken over the axes given by +axis+
# (default [-2,-1]) and summed over its last axis.
# @param offset [Integer] (optional, default=0) diagonal offset
# @param axis [Array] (optional, default=[-2,-1]) diagonal axis
# @param nan [Bool] (optional, default=false) nan-aware algorithm, i.e., if true then it ignores nan.
# @return [Cumo::NArray] diagonal sums
def trace(offset=nil,axis=nil,nan:false)
  diag = diagonal(offset,axis)
  diag.sum(axis:-1, nan:nan)
end
|
|
1077
|
+
|
|
1078
|
+
|
|
1079
|
+
# Set to true by #dot the first time it prints the "matrix dot ... is slow"
# warning, so that warning is emitted at most once (this is a class variable,
# shared with all subclasses).
@@warn_slow_dot = false
|
|
1080
|
+
|
|
1081
|
+
# Dot product of two arrays.
#
# If the upcast of self's type and b's type (self.class::UPCAST) is SFloat,
# DFloat, SComplex or DComplex, both operands are converted to that type and
# the product is computed with #gemm (cuBLAS) or #mulsum. Other types fall
# back to a broadcasting #mulsum, warning once when the matrices are large.
# @param b [Cumo::NArray]
# @return [Cumo::NArray] return dot product
def dot(b)
  t = self.class::UPCAST[b.class]
  if [SFloat, DFloat, SComplex, DComplex].include?(t)
    # Convert BOTH operands to the common type t. Previously only b went
    # through self.class.asarray, so self kept its own (possibly integer)
    # type and a Numeric b was built as self.class: e.g. Int32#dot with a
    # 2-D DFloat called gemm on an Int32 receiver, and Int32#dot(2.5)
    # built Int32[2.5].
    a = (self.class == t) ? self : t.cast(self)
    b = t.asarray(b)
    b = t.cast(b) unless b.class == t
    case a.ndim
    when 1
      case b.ndim
      when 1
        a.mulsum(b, axis:-1)
      else
        a[:new, false].gemm(b).flatten
      end
    else
      case b.ndim
      when 1
        a.gemm(b[false, :new]).flatten
      else
        a.gemm(b)
      end
    end
  else
    b = self.class.asarray(b)
    case b.ndim
    when 1
      mulsum(b, axis:-1)
    else
      case ndim
      when 0
        b.mulsum(self, axis:-2)
      when 1
        self[true,:new].mulsum(b, axis:-2)
      else
        unless @@warn_slow_dot
          nx = 200
          ns = 200000
          am,an = shape[-2..-1]
          bm,bn = b.shape[-2..-1]
          if am > nx && an > nx && bm > nx && bn > nx &&
              size > ns && b.size > ns
            @@warn_slow_dot = true
            warn "\nwarning: matrix dot for #{t} is slow. Consider SFloat, DFloat, SComplex, or DComplex to use cuBLAS.\n\n"
          end
        end
        # Broadcasting matrix product: sum over a's last axis against b's second-to-last.
        self[false,:new].mulsum(b[false,:new,true,true], axis:-2)
      end
    end
  end
end
|
|
1133
|
+
|
|
1134
|
+
# Inner product of two arrays, i.e. `(a*b).sum(axis:axis)` computed in a
# single #mulsum call.
# @param b [Cumo::NArray]
# @param axis [Integer] applied axis (default=-1)
# @return [Cumo::NArray] return (a*b).sum(axis:axis)
def inner(b, axis:-1)
  self.mulsum(b, axis: axis)
end
|
|
1143
|
+
|
|
1144
|
+
# Outer product of two arrays.
|
|
1145
|
+
# Same as `self[false,:new] * b[false,:new,true]`.
|
|
1146
|
+
#
|
|
1147
|
+
# @param b [Cumo::NArray]
|
|
1148
|
+
# @param axis [Integer] applied axis (default=-1)
|
|
1149
|
+
# @return [Cumo::NArray] return outer product
|
|
1150
|
+
# @example
|
|
1151
|
+
# a = Cumo::DFloat.ones(5)
|
|
1152
|
+
# => Cumo::DFloat#shape=[5]
|
|
1153
|
+
# [1, 1, 1, 1, 1]
|
|
1154
|
+
# b = Cumo::DFloat.linspace(-2,2,5)
|
|
1155
|
+
# => Cumo::DFloat#shape=[5]
|
|
1156
|
+
# [-2, -1, 0, 1, 2]
|
|
1157
|
+
# a.outer(b)
|
|
1158
|
+
# => Cumo::DFloat#shape=[5,5]
|
|
1159
|
+
# [[-2, -1, 0, 1, 2],
|
|
1160
|
+
# [-2, -1, 0, 1, 2],
|
|
1161
|
+
# [-2, -1, 0, 1, 2],
|
|
1162
|
+
# [-2, -1, 0, 1, 2],
|
|
1163
|
+
# [-2, -1, 0, 1, 2]]
|
|
1164
|
+
|
|
1165
|
+
def outer(b, axis:nil)
  other = NArray.cast(b)
  # Default: every element of self times every element of other.
  if axis.nil?
    rhs = other.ndim.zero? ? other : other[false,:new,true]
    return self[false,:new] * rhs
  end
  # With an axis, the outer product is taken only along that axis; the
  # other axes broadcast. `rel` is the axis counted from the end of the
  # higher-rank operand.
  low, high = [ndim,other.ndim].minmax
  rel = check_axis(axis) - high
  raise ArgumentError,"axis=#{rel} is out of range" if rel < -low
  lhs_idx = Array.new(ndim, true).insert(rel+ndim+1, :new)
  rhs_idx = Array.new(other.ndim, true).insert(rel+other.ndim, :new)
  self[*lhs_idx] * other[*rhs_idx]
end
|
|
1182
|
+
|
|
1183
|
+
# Kronecker product of two arrays.
|
|
1184
|
+
#
|
|
1185
|
+
# kron(a,b)[k_0, k_1, ...] = a[i_0, i_1, ...] * b[j_0, j_1, ...]
|
|
1186
|
+
# where: k_n = i_n * b.shape[n] + j_n
|
|
1187
|
+
#
|
|
1188
|
+
# @param b [Cumo::NArray]
|
|
1189
|
+
# @return [Cumo::NArray] return Kronecker product
|
|
1190
|
+
# @example
|
|
1191
|
+
# Cumo::DFloat[1,10,100].kron([5,6,7])
|
|
1192
|
+
# => Cumo::DFloat#shape=[9]
|
|
1193
|
+
# [5, 6, 7, 50, 60, 70, 500, 600, 700]
|
|
1194
|
+
# Cumo::DFloat[5,6,7].kron([1,10,100])
|
|
1195
|
+
# => Cumo::DFloat#shape=[9]
|
|
1196
|
+
# [5, 50, 500, 6, 60, 600, 7, 70, 700]
|
|
1197
|
+
# Cumo::DFloat.eye(2).kron(Cumo::DFloat.ones(2,2))
|
|
1198
|
+
# => Cumo::DFloat#shape=[4,4]
|
|
1199
|
+
# [[1, 1, 0, 0],
|
|
1200
|
+
# [1, 1, 0, 0],
|
|
1201
|
+
# [0, 0, 1, 1],
|
|
1202
|
+
# [0, 0, 1, 1]]
|
|
1203
|
+
|
|
1204
|
+
def kron(b)
  other = NArray.cast(b)
  na, nb = ndim, other.ndim
  sa, sb = shape, other.shape
  rank = [na, nb].max
  # Interleave the axes: self supplies the "outer" and other the "inner"
  # position of each pair; the lower-rank operand gets leading new axes.
  a_idx = [:new] * (2 * (rank - na)) + [true, :new] * na
  b_idx = [:new] * (2 * (rank - nb)) + [:new, true] * nb
  # Each output length is the product of the paired lengths (missing = 1).
  out_shape = (-rank..-1).map { |i| (sa[i] || 1) * (sb[i] || 1) }
  (self[*a_idx] * other[*b_idx]).reshape(*out_shape)
end
|
|
1215
|
+
|
|
1216
|
+
|
|
1217
|
+
# under construction
# Covariance matrix estimate (the original authors mark this as unfinished).
#
# Rows of m (self, or self stacked with +y+ via NArray.vstack) are treated as
# variables and the last axis as observations: the mean is removed along
# axis -1, and the result is mw.dot(mt.conj) / fact.
# @param y [optional] extra variables, stacked below self with NArray.vstack
# @param ddof [Integer] subtracted from the observation count (or from the
#   weighted equivalent) when forming the normalization factor `fact`
# @param fweights [optional] frequency weights; NOTE(review): `.sum(axis:,
#   keepdims:)` is called on the weights, so they presumably must already be
#   an NArray (a plain Ruby Array would not accept those keywords)
# @param aweights [optional] analytic weights; multiplied into fweights when both are given
# @raise [StandardError] if a weighted normalization factor is <= 0
def cov(y=nil, ddof:1, fweights:nil, aweights:nil)
  if y
    m = NArray.vstack([self,y])
  else
    m = self
  end
  # Combined weight w = fweights * aweights (either may be absent).
  w = nil
  if fweights
    f = fweights
    w = f
  end
  if aweights
    a = aweights
    w = w ? w*a : a
  end
  # Normalization factor: weighted form when weights are given, otherwise
  # (number of observations) - ddof. Only the weighted path checks fact <= 0.
  if w
    w_sum = w.sum(axis:-1, keepdims:true)
    if ddof == 0
      fact = w_sum
    elsif aweights.nil?
      fact = w_sum - ddof
    else
      wa_sum = (w*a).sum(axis:-1, keepdims:true)
      fact = w_sum - ddof * wa_sum / w_sum
    end
    if (fact <= 0).any?
      raise StandardError,"Degrees of freedom <= 0 for slice"
    end
  else
    fact = m.shape[-1] - ddof
  end
  # Center each variable with its (weighted) mean; `-=` rebinds m, self is not modified.
  if w
    m -= (m*w).sum(axis:-1, keepdims:true) / w_sum
    mw = m*w
  else
    m -= m.mean(axis:-1, keepdims:true)
    mw = m
  end
  mt = (m.ndim < 2) ? m : m.swapaxes(-2,-1)
  mw.dot(mt.conj) / fact
end
|
|
1259
|
+
|
|
1260
|
+
private

# @!visibility private
# Validate +axis+ against ndim and return it as a non-negative index
# (negative values count from the last dimension).
# @raise [ArgumentError] if axis is not an Integer or is out of range
def check_axis(axis)
  raise ArgumentError,"axis=#{axis} must be Integer" unless Integer === axis
  normalized = axis.negative? ? axis + ndim : axis
  unless (0...ndim).cover?(normalized)
    raise ArgumentError,"axis=#{axis} is invalid"
  end
  normalized
end
|
|
1276
|
+
|
|
1277
|
+
end
|
|
1278
|
+
end
|