cumo 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/.gitignore +27 -0
- data/.travis.yml +5 -0
- data/3rd_party/mkmf-cu/.gitignore +36 -0
- data/3rd_party/mkmf-cu/Gemfile +3 -0
- data/3rd_party/mkmf-cu/LICENSE +21 -0
- data/3rd_party/mkmf-cu/README.md +36 -0
- data/3rd_party/mkmf-cu/Rakefile +11 -0
- data/3rd_party/mkmf-cu/bin/mkmf-cu-nvcc +4 -0
- data/3rd_party/mkmf-cu/lib/mkmf-cu.rb +32 -0
- data/3rd_party/mkmf-cu/lib/mkmf-cu/cli.rb +80 -0
- data/3rd_party/mkmf-cu/lib/mkmf-cu/nvcc.rb +157 -0
- data/3rd_party/mkmf-cu/mkmf-cu.gemspec +16 -0
- data/3rd_party/mkmf-cu/test/test_mkmf-cu.rb +67 -0
- data/CODE_OF_CONDUCT.md +46 -0
- data/Gemfile +8 -0
- data/LICENSE.txt +82 -0
- data/README.md +252 -0
- data/Rakefile +43 -0
- data/bench/broadcast_fp32.rb +138 -0
- data/bench/cumo_bench.rb +193 -0
- data/bench/numo_bench.rb +138 -0
- data/bench/reduction_fp32.rb +117 -0
- data/bin/console +14 -0
- data/bin/setup +8 -0
- data/cumo.gemspec +32 -0
- data/ext/cumo/cuda/cublas.c +278 -0
- data/ext/cumo/cuda/driver.c +421 -0
- data/ext/cumo/cuda/memory_pool.cpp +185 -0
- data/ext/cumo/cuda/memory_pool_impl.cpp +308 -0
- data/ext/cumo/cuda/memory_pool_impl.hpp +370 -0
- data/ext/cumo/cuda/memory_pool_impl_test.cpp +554 -0
- data/ext/cumo/cuda/nvrtc.c +207 -0
- data/ext/cumo/cuda/runtime.c +167 -0
- data/ext/cumo/cumo.c +148 -0
- data/ext/cumo/depend.erb +58 -0
- data/ext/cumo/extconf.rb +179 -0
- data/ext/cumo/include/cumo.h +25 -0
- data/ext/cumo/include/cumo/compat.h +23 -0
- data/ext/cumo/include/cumo/cuda/cublas.h +153 -0
- data/ext/cumo/include/cumo/cuda/cumo_thrust.hpp +187 -0
- data/ext/cumo/include/cumo/cuda/cumo_thrust_complex.hpp +79 -0
- data/ext/cumo/include/cumo/cuda/driver.h +22 -0
- data/ext/cumo/include/cumo/cuda/memory_pool.h +28 -0
- data/ext/cumo/include/cumo/cuda/nvrtc.h +22 -0
- data/ext/cumo/include/cumo/cuda/runtime.h +40 -0
- data/ext/cumo/include/cumo/indexer.h +238 -0
- data/ext/cumo/include/cumo/intern.h +142 -0
- data/ext/cumo/include/cumo/intern_fwd.h +38 -0
- data/ext/cumo/include/cumo/intern_kernel.h +6 -0
- data/ext/cumo/include/cumo/narray.h +429 -0
- data/ext/cumo/include/cumo/narray_kernel.h +149 -0
- data/ext/cumo/include/cumo/ndloop.h +95 -0
- data/ext/cumo/include/cumo/reduce_kernel.h +126 -0
- data/ext/cumo/include/cumo/template.h +158 -0
- data/ext/cumo/include/cumo/template_kernel.h +77 -0
- data/ext/cumo/include/cumo/types/bit.h +40 -0
- data/ext/cumo/include/cumo/types/bit_kernel.h +34 -0
- data/ext/cumo/include/cumo/types/complex.h +402 -0
- data/ext/cumo/include/cumo/types/complex_kernel.h +414 -0
- data/ext/cumo/include/cumo/types/complex_macro.h +382 -0
- data/ext/cumo/include/cumo/types/complex_macro_kernel.h +186 -0
- data/ext/cumo/include/cumo/types/dcomplex.h +46 -0
- data/ext/cumo/include/cumo/types/dcomplex_kernel.h +13 -0
- data/ext/cumo/include/cumo/types/dfloat.h +47 -0
- data/ext/cumo/include/cumo/types/dfloat_kernel.h +14 -0
- data/ext/cumo/include/cumo/types/float_def.h +34 -0
- data/ext/cumo/include/cumo/types/float_def_kernel.h +39 -0
- data/ext/cumo/include/cumo/types/float_macro.h +191 -0
- data/ext/cumo/include/cumo/types/float_macro_kernel.h +158 -0
- data/ext/cumo/include/cumo/types/int16.h +24 -0
- data/ext/cumo/include/cumo/types/int16_kernel.h +23 -0
- data/ext/cumo/include/cumo/types/int32.h +24 -0
- data/ext/cumo/include/cumo/types/int32_kernel.h +19 -0
- data/ext/cumo/include/cumo/types/int64.h +24 -0
- data/ext/cumo/include/cumo/types/int64_kernel.h +19 -0
- data/ext/cumo/include/cumo/types/int8.h +24 -0
- data/ext/cumo/include/cumo/types/int8_kernel.h +19 -0
- data/ext/cumo/include/cumo/types/int_macro.h +67 -0
- data/ext/cumo/include/cumo/types/int_macro_kernel.h +48 -0
- data/ext/cumo/include/cumo/types/real_accum.h +486 -0
- data/ext/cumo/include/cumo/types/real_accum_kernel.h +101 -0
- data/ext/cumo/include/cumo/types/robj_macro.h +80 -0
- data/ext/cumo/include/cumo/types/robj_macro_kernel.h +0 -0
- data/ext/cumo/include/cumo/types/robject.h +27 -0
- data/ext/cumo/include/cumo/types/robject_kernel.h +7 -0
- data/ext/cumo/include/cumo/types/scomplex.h +46 -0
- data/ext/cumo/include/cumo/types/scomplex_kernel.h +13 -0
- data/ext/cumo/include/cumo/types/sfloat.h +48 -0
- data/ext/cumo/include/cumo/types/sfloat_kernel.h +14 -0
- data/ext/cumo/include/cumo/types/uint16.h +25 -0
- data/ext/cumo/include/cumo/types/uint16_kernel.h +20 -0
- data/ext/cumo/include/cumo/types/uint32.h +25 -0
- data/ext/cumo/include/cumo/types/uint32_kernel.h +20 -0
- data/ext/cumo/include/cumo/types/uint64.h +25 -0
- data/ext/cumo/include/cumo/types/uint64_kernel.h +20 -0
- data/ext/cumo/include/cumo/types/uint8.h +25 -0
- data/ext/cumo/include/cumo/types/uint8_kernel.h +20 -0
- data/ext/cumo/include/cumo/types/uint_macro.h +58 -0
- data/ext/cumo/include/cumo/types/uint_macro_kernel.h +38 -0
- data/ext/cumo/include/cumo/types/xint_macro.h +169 -0
- data/ext/cumo/include/cumo/types/xint_macro_kernel.h +88 -0
- data/ext/cumo/narray/SFMT-params.h +97 -0
- data/ext/cumo/narray/SFMT-params19937.h +46 -0
- data/ext/cumo/narray/SFMT.c +620 -0
- data/ext/cumo/narray/SFMT.h +167 -0
- data/ext/cumo/narray/array.c +638 -0
- data/ext/cumo/narray/data.c +961 -0
- data/ext/cumo/narray/gen/cogen.rb +56 -0
- data/ext/cumo/narray/gen/cogen_kernel.rb +58 -0
- data/ext/cumo/narray/gen/def/bit.rb +37 -0
- data/ext/cumo/narray/gen/def/dcomplex.rb +39 -0
- data/ext/cumo/narray/gen/def/dfloat.rb +37 -0
- data/ext/cumo/narray/gen/def/int16.rb +36 -0
- data/ext/cumo/narray/gen/def/int32.rb +36 -0
- data/ext/cumo/narray/gen/def/int64.rb +36 -0
- data/ext/cumo/narray/gen/def/int8.rb +36 -0
- data/ext/cumo/narray/gen/def/robject.rb +37 -0
- data/ext/cumo/narray/gen/def/scomplex.rb +39 -0
- data/ext/cumo/narray/gen/def/sfloat.rb +37 -0
- data/ext/cumo/narray/gen/def/uint16.rb +36 -0
- data/ext/cumo/narray/gen/def/uint32.rb +36 -0
- data/ext/cumo/narray/gen/def/uint64.rb +36 -0
- data/ext/cumo/narray/gen/def/uint8.rb +36 -0
- data/ext/cumo/narray/gen/erbpp2.rb +346 -0
- data/ext/cumo/narray/gen/narray_def.rb +268 -0
- data/ext/cumo/narray/gen/spec.rb +425 -0
- data/ext/cumo/narray/gen/tmpl/accum.c +86 -0
- data/ext/cumo/narray/gen/tmpl/accum_binary.c +121 -0
- data/ext/cumo/narray/gen/tmpl/accum_binary_kernel.cu +61 -0
- data/ext/cumo/narray/gen/tmpl/accum_index.c +119 -0
- data/ext/cumo/narray/gen/tmpl/accum_index_kernel.cu +66 -0
- data/ext/cumo/narray/gen/tmpl/accum_kernel.cu +12 -0
- data/ext/cumo/narray/gen/tmpl/alloc_func.c +107 -0
- data/ext/cumo/narray/gen/tmpl/allocate.c +37 -0
- data/ext/cumo/narray/gen/tmpl/aref.c +66 -0
- data/ext/cumo/narray/gen/tmpl/aref_cpu.c +50 -0
- data/ext/cumo/narray/gen/tmpl/aset.c +56 -0
- data/ext/cumo/narray/gen/tmpl/binary.c +162 -0
- data/ext/cumo/narray/gen/tmpl/binary2.c +70 -0
- data/ext/cumo/narray/gen/tmpl/binary2_kernel.cu +15 -0
- data/ext/cumo/narray/gen/tmpl/binary_kernel.cu +31 -0
- data/ext/cumo/narray/gen/tmpl/binary_s.c +45 -0
- data/ext/cumo/narray/gen/tmpl/binary_s_kernel.cu +15 -0
- data/ext/cumo/narray/gen/tmpl/bincount.c +181 -0
- data/ext/cumo/narray/gen/tmpl/cast.c +44 -0
- data/ext/cumo/narray/gen/tmpl/cast_array.c +13 -0
- data/ext/cumo/narray/gen/tmpl/class.c +9 -0
- data/ext/cumo/narray/gen/tmpl/class_kernel.cu +6 -0
- data/ext/cumo/narray/gen/tmpl/clip.c +121 -0
- data/ext/cumo/narray/gen/tmpl/coerce_cast.c +10 -0
- data/ext/cumo/narray/gen/tmpl/complex_accum_kernel.cu +129 -0
- data/ext/cumo/narray/gen/tmpl/cond_binary.c +68 -0
- data/ext/cumo/narray/gen/tmpl/cond_binary_kernel.cu +18 -0
- data/ext/cumo/narray/gen/tmpl/cond_unary.c +46 -0
- data/ext/cumo/narray/gen/tmpl/cum.c +50 -0
- data/ext/cumo/narray/gen/tmpl/each.c +47 -0
- data/ext/cumo/narray/gen/tmpl/each_with_index.c +70 -0
- data/ext/cumo/narray/gen/tmpl/ewcomp.c +79 -0
- data/ext/cumo/narray/gen/tmpl/ewcomp_kernel.cu +19 -0
- data/ext/cumo/narray/gen/tmpl/extract.c +22 -0
- data/ext/cumo/narray/gen/tmpl/extract_cpu.c +26 -0
- data/ext/cumo/narray/gen/tmpl/extract_data.c +53 -0
- data/ext/cumo/narray/gen/tmpl/eye.c +105 -0
- data/ext/cumo/narray/gen/tmpl/eye_kernel.cu +19 -0
- data/ext/cumo/narray/gen/tmpl/fill.c +52 -0
- data/ext/cumo/narray/gen/tmpl/fill_kernel.cu +29 -0
- data/ext/cumo/narray/gen/tmpl/float_accum_kernel.cu +106 -0
- data/ext/cumo/narray/gen/tmpl/format.c +62 -0
- data/ext/cumo/narray/gen/tmpl/format_to_a.c +49 -0
- data/ext/cumo/narray/gen/tmpl/frexp.c +38 -0
- data/ext/cumo/narray/gen/tmpl/gemm.c +203 -0
- data/ext/cumo/narray/gen/tmpl/init_class.c +20 -0
- data/ext/cumo/narray/gen/tmpl/init_module.c +12 -0
- data/ext/cumo/narray/gen/tmpl/inspect.c +21 -0
- data/ext/cumo/narray/gen/tmpl/lib.c +50 -0
- data/ext/cumo/narray/gen/tmpl/lib_kernel.cu +24 -0
- data/ext/cumo/narray/gen/tmpl/logseq.c +102 -0
- data/ext/cumo/narray/gen/tmpl/logseq_kernel.cu +31 -0
- data/ext/cumo/narray/gen/tmpl/map_with_index.c +98 -0
- data/ext/cumo/narray/gen/tmpl/median.c +66 -0
- data/ext/cumo/narray/gen/tmpl/minmax.c +47 -0
- data/ext/cumo/narray/gen/tmpl/module.c +9 -0
- data/ext/cumo/narray/gen/tmpl/module_kernel.cu +1 -0
- data/ext/cumo/narray/gen/tmpl/new_dim0.c +15 -0
- data/ext/cumo/narray/gen/tmpl/new_dim0_kernel.cu +8 -0
- data/ext/cumo/narray/gen/tmpl/poly.c +50 -0
- data/ext/cumo/narray/gen/tmpl/pow.c +97 -0
- data/ext/cumo/narray/gen/tmpl/pow_kernel.cu +29 -0
- data/ext/cumo/narray/gen/tmpl/powint.c +17 -0
- data/ext/cumo/narray/gen/tmpl/qsort.c +212 -0
- data/ext/cumo/narray/gen/tmpl/rand.c +168 -0
- data/ext/cumo/narray/gen/tmpl/rand_norm.c +121 -0
- data/ext/cumo/narray/gen/tmpl/real_accum_kernel.cu +75 -0
- data/ext/cumo/narray/gen/tmpl/seq.c +112 -0
- data/ext/cumo/narray/gen/tmpl/seq_kernel.cu +43 -0
- data/ext/cumo/narray/gen/tmpl/set2.c +57 -0
- data/ext/cumo/narray/gen/tmpl/sort.c +48 -0
- data/ext/cumo/narray/gen/tmpl/sort_index.c +111 -0
- data/ext/cumo/narray/gen/tmpl/store.c +41 -0
- data/ext/cumo/narray/gen/tmpl/store_array.c +187 -0
- data/ext/cumo/narray/gen/tmpl/store_array_kernel.cu +58 -0
- data/ext/cumo/narray/gen/tmpl/store_bit.c +86 -0
- data/ext/cumo/narray/gen/tmpl/store_bit_kernel.cu +66 -0
- data/ext/cumo/narray/gen/tmpl/store_from.c +81 -0
- data/ext/cumo/narray/gen/tmpl/store_from_kernel.cu +58 -0
- data/ext/cumo/narray/gen/tmpl/store_kernel.cu +3 -0
- data/ext/cumo/narray/gen/tmpl/store_numeric.c +9 -0
- data/ext/cumo/narray/gen/tmpl/to_a.c +43 -0
- data/ext/cumo/narray/gen/tmpl/unary.c +132 -0
- data/ext/cumo/narray/gen/tmpl/unary2.c +60 -0
- data/ext/cumo/narray/gen/tmpl/unary_kernel.cu +72 -0
- data/ext/cumo/narray/gen/tmpl/unary_ret2.c +34 -0
- data/ext/cumo/narray/gen/tmpl/unary_s.c +86 -0
- data/ext/cumo/narray/gen/tmpl/unary_s_kernel.cu +58 -0
- data/ext/cumo/narray/gen/tmpl_bit/allocate.c +24 -0
- data/ext/cumo/narray/gen/tmpl_bit/aref.c +54 -0
- data/ext/cumo/narray/gen/tmpl_bit/aref_cpu.c +57 -0
- data/ext/cumo/narray/gen/tmpl_bit/aset.c +56 -0
- data/ext/cumo/narray/gen/tmpl_bit/binary.c +98 -0
- data/ext/cumo/narray/gen/tmpl_bit/bit_count.c +64 -0
- data/ext/cumo/narray/gen/tmpl_bit/bit_count_cpu.c +88 -0
- data/ext/cumo/narray/gen/tmpl_bit/bit_count_kernel.cu +76 -0
- data/ext/cumo/narray/gen/tmpl_bit/bit_reduce.c +133 -0
- data/ext/cumo/narray/gen/tmpl_bit/each.c +48 -0
- data/ext/cumo/narray/gen/tmpl_bit/each_with_index.c +70 -0
- data/ext/cumo/narray/gen/tmpl_bit/extract.c +30 -0
- data/ext/cumo/narray/gen/tmpl_bit/extract_cpu.c +29 -0
- data/ext/cumo/narray/gen/tmpl_bit/fill.c +69 -0
- data/ext/cumo/narray/gen/tmpl_bit/format.c +64 -0
- data/ext/cumo/narray/gen/tmpl_bit/format_to_a.c +51 -0
- data/ext/cumo/narray/gen/tmpl_bit/inspect.c +21 -0
- data/ext/cumo/narray/gen/tmpl_bit/mask.c +136 -0
- data/ext/cumo/narray/gen/tmpl_bit/none_p.c +14 -0
- data/ext/cumo/narray/gen/tmpl_bit/store_array.c +108 -0
- data/ext/cumo/narray/gen/tmpl_bit/store_bit.c +70 -0
- data/ext/cumo/narray/gen/tmpl_bit/store_from.c +60 -0
- data/ext/cumo/narray/gen/tmpl_bit/to_a.c +47 -0
- data/ext/cumo/narray/gen/tmpl_bit/unary.c +81 -0
- data/ext/cumo/narray/gen/tmpl_bit/where.c +90 -0
- data/ext/cumo/narray/gen/tmpl_bit/where2.c +95 -0
- data/ext/cumo/narray/index.c +880 -0
- data/ext/cumo/narray/kwargs.c +153 -0
- data/ext/cumo/narray/math.c +142 -0
- data/ext/cumo/narray/narray.c +1948 -0
- data/ext/cumo/narray/ndloop.c +2105 -0
- data/ext/cumo/narray/rand.c +45 -0
- data/ext/cumo/narray/step.c +474 -0
- data/ext/cumo/narray/struct.c +886 -0
- data/lib/cumo.rb +3 -0
- data/lib/cumo/cuda.rb +11 -0
- data/lib/cumo/cuda/compile_error.rb +36 -0
- data/lib/cumo/cuda/compiler.rb +161 -0
- data/lib/cumo/cuda/device.rb +47 -0
- data/lib/cumo/cuda/link_state.rb +31 -0
- data/lib/cumo/cuda/module.rb +40 -0
- data/lib/cumo/cuda/nvrtc_program.rb +27 -0
- data/lib/cumo/linalg.rb +12 -0
- data/lib/cumo/narray.rb +2 -0
- data/lib/cumo/narray/extra.rb +1278 -0
- data/lib/erbpp.rb +294 -0
- data/lib/erbpp/line_number.rb +137 -0
- data/lib/erbpp/narray_def.rb +381 -0
- data/numo-narray-version +1 -0
- data/run.gdb +7 -0
- metadata +353 -0
data/ext/cumo/depend.erb
ADDED
@@ -0,0 +1,58 @@
|
|
1
|
+
TAGSRC = \
|
2
|
+
../../ruby/include/ruby/*.h \
|
3
|
+
../../ruby/*.c \
|
4
|
+
narray/*.h \
|
5
|
+
narray/types/*.h \
|
6
|
+
narray/*.c \
|
7
|
+
narray/types/*.c \
|
8
|
+
narray/types/*.cu
|
9
|
+
|
10
|
+
tags : TAGS
|
11
|
+
TAGS : $(TAGSRC)
|
12
|
+
etags $(TAGSRC)
|
13
|
+
|
14
|
+
C_TMPL = <%=Dir.glob("narray/gen/tmpl*/*.c").join(" ")%>
|
15
|
+
CU_TMPL = <%=Dir.glob("narray/gen/tmpl*/*.cu").join(" ")%>
|
16
|
+
|
17
|
+
C_COGEN = narray/gen/cogen.rb
|
18
|
+
CU_COGEN = narray/gen/cogen_kernel.rb
|
19
|
+
C_DEPENDS = $(C_TMPL) narray/gen/*.rb
|
20
|
+
CU_DEPENDS = $(CU_TMPL) narray/gen/*.rb
|
21
|
+
|
22
|
+
<%
  # Emit one make rule per element-type definition under narray/gen/def and
  # remember each generated C file name in list_type_c (used by the `src` target).
  # When the DTYPE environment variable is set, only types whose name contains
  # it (case-insensitively) are generated.
  list_type_c = []
  Dir.glob("narray/gen/def/*.rb").each do |type_rb|
    type_name = File.basename(type_rb, ".rb")
    dtype_filter = ENV['DTYPE']
    next if dtype_filter && !type_name.downcase.include?(dtype_filter.downcase)
    type_c = "narray/types/#{type_name}.c"
    list_type_c << type_c
%>
<%=type_c%>: <%=type_rb%> $(C_DEPENDS)
	$(MAKEDIRS) $(@D) types
	ruby $(C_COGEN) -l -o $@ <%=type_rb%>
<% end %>
|
34
|
+
|
35
|
+
<%
  # Emit one make rule per element-type definition under narray/gen/def and
  # remember each generated kernel file name in list_type_cu (used by the `src`
  # target).  When the DTYPE environment variable is set, only types whose name
  # contains it (case-insensitively) are generated.
  list_type_cu = []
  Dir.glob("narray/gen/def/*.rb").each do |type_rb|
    type_name = File.basename(type_rb, ".rb")
    dtype_filter = ENV['DTYPE']
    next if dtype_filter && !type_name.downcase.include?(dtype_filter.downcase)
    type_cu = "narray/types/#{type_name}_kernel.cu"
    list_type_cu << type_cu
%>
<%=type_cu%>: <%=type_rb%> $(CU_DEPENDS)
	$(MAKEDIRS) $(@D) types
	ruby $(CU_COGEN) -l -o $@ <%=type_rb%>
<% end %>
|
47
|
+
|
48
|
+
src : <%= list_type_cu.join(" ") %> <%= list_type_c.join(" ") %>
|
49
|
+
|
50
|
+
build-ctest : cuda/memory_pool_impl_test.exe
|
51
|
+
|
52
|
+
run-ctest : cuda/memory_pool_impl_test.exe
|
53
|
+
./$<
|
54
|
+
|
55
|
+
cuda/memory_pool_impl_test.exe: cuda/memory_pool_impl_test.cpp cuda/memory_pool_impl.cpp cuda/memory_pool_impl.hpp
|
56
|
+
nvcc -DNO_RUBY -std=c++14 <%= ENV['DEBUG'] ? '-g -O0 --compiler-options -Wall' : '' %> -L. -L$(libdir) -I. $(INCFLAGS) -o $@ $< cuda/memory_pool_impl.cpp
|
57
|
+
|
58
|
+
CLEANOBJS = *.o */*.o */*/*.o *.bak narray/types/*.c narray/types/*_kernel.cu *.exe */*.exe
|
data/ext/cumo/extconf.rb
ADDED
@@ -0,0 +1,179 @@
|
|
1
|
+
require 'rbconfig.rb'
|
2
|
+
require "erb"
|
3
|
+
require_relative '../../3rd_party/mkmf-cu/lib/mkmf-cu'
|
4
|
+
|
5
|
+
if RUBY_VERSION < "2.0.0"
|
6
|
+
puts "Cumo::NArray requires Ruby version 2.0 or later."
|
7
|
+
exit(1)
|
8
|
+
end
|
9
|
+
|
10
|
+
# Locates the numo-narray gem whose version is pinned in
# ../../numo-narray-version and adds its header directory to $INCFLAGS.
# On cygwin/mingw it additionally adds the gem's ext/numo directory to
# $LDFLAGS and requires the narray library to be linkable.
# Prints a message and exits with status 1 when numo/narray.h (or, on
# cygwin/mingw, the narray library) cannot be found.
def have_numo_narray!
  pinned_version = File.read(File.join(__dir__, "..", "..", "numo-narray-version")).strip
  gem_spec = Gem::Specification.find_by_name("numo-narray", pinned_version)

  $INCFLAGS += " -I#{gem_spec.gem_dir}/ext/numo/narray"
  unless have_header("numo/narray.h")
    puts "
Header numo/narray.h was not found. Give pathname as follows:
% ruby extconf.rb --with-narray-include=narray_h_dir"
    exit(1)
  end

  # Only cygwin/mingw builds need to link against the narray library explicitly.
  return unless RUBY_PLATFORM =~ /cygwin|mingw/

  $LDFLAGS += " -L#{gem_spec.gem_dir}/ext/numo"
  return if have_library("narray", "nary_new")

  puts "libnarray.a not found"
  exit(1)
end
|
31
|
+
|
32
|
+
# Renders depend.erb (next to this script) through ERB and writes the result
# to the sibling file `depend`.
def create_depend
  message "creating depend\n"
  template_path = File.join(__dir__, "depend.erb")
  File.open(File.join(__dir__, "depend"), "w") do |out|
    renderer = ERB.new(File.read(template_path))
    # Attribute template errors to depend.erb rather than to "(erb)".
    renderer.filename = template_path
    out.print(renderer.result)
  end
end
|
44
|
+
|
45
|
+
rm_f 'include/cumo/extconf.h'
|
46
|
+
|
47
|
+
MakeMakefileCuda.install!(cxx: true)
|
48
|
+
|
49
|
+
if ENV['DEBUG']
|
50
|
+
$CFLAGS="-g -O0 -Wall"
|
51
|
+
end
|
52
|
+
$CXXFLAGS += " -std=c++14 "
|
53
|
+
#$CFLAGS=" $(cflags) -O3 -m64 -msse2 -funroll-loops"
|
54
|
+
#$CFLAGS=" $(cflags) -O3"
|
55
|
+
$INCFLAGS = "-Iinclude -Inarray -Icuda #{$INCFLAGS}"
|
56
|
+
|
57
|
+
$INSTALLFILES = Dir.glob(%w[include/cumo/*.h include/cumo/types/*.h include/cumo/cuda/*.h]).map{|x| [x,'$(archdir)'] }
|
58
|
+
$INSTALLFILES << ['include/cumo/extconf.h','$(archdir)']
|
59
|
+
if /cygwin|mingw/ =~ RUBY_PLATFORM
|
60
|
+
$INSTALLFILES << ['libcumo.a', '$(archdir)']
|
61
|
+
end
|
62
|
+
|
63
|
+
# Element types; each one contributes a generated host source
# (narray/types/<type>.c) and a generated kernel source
# (narray/types/<type>_kernel.cu) to the extension.
dtype_names = %w[
  bit
  int8 int16 int32 int64
  uint8 uint16 uint32 uint64
  sfloat dfloat
  scomplex dcomplex
  robject
]

# Object files are derived from this list, so every entry is a path relative
# to ext/cumo without its file extension.
srcs = %w[
  cumo
  narray/narray
  narray/array
  narray/step
  narray/index
  narray/ndloop
  narray/data
]
srcs += dtype_names.map { |dtype| "narray/types/#{dtype}" }
srcs += dtype_names.map { |dtype| "narray/types/#{dtype}_kernel" }
srcs += %w[
  narray/math
  narray/SFMT
  narray/struct
  narray/rand
  cuda/cublas
  cuda/driver
  cuda/memory_pool
  cuda/memory_pool_impl
  cuda/runtime
  cuda/nvrtc
]

if RUBY_VERSION[0..3] == "2.1."
  puts "add kwargs"
  # kwargs.c lives in narray/ like the other narray sources; a bare "kwargs"
  # would produce an object name with no matching source file.
  srcs << "narray/kwargs"
end

$objs = srcs.map {|src| "#{src}.o" }
|
117
|
+
|
118
|
+
dir_config("narray")
|
119
|
+
|
120
|
+
have_numo_narray!
|
121
|
+
|
122
|
+
if have_header("dlfcn.h")
|
123
|
+
exit(1) unless have_library("dl")
|
124
|
+
exit(1) unless have_func("dlopen")
|
125
|
+
elsif have_header("windows.h")
|
126
|
+
exit(1) unless have_func("LoadLibrary")
|
127
|
+
end
|
128
|
+
|
129
|
+
# Choose the headers passed to the type probes below; nil when none is available.
stdbool = have_header("stdbool.h") ? "stdbool.h" : nil
stdint =
  if have_header("stdint.h")
    "stdint.h"
  elsif have_header("sys/types.h")
    "sys/types.h"
  end

have_type("bool", stdbool)
# For unsigned types, probe the u_intN_t spelling first and fall back to
# uintN_t only when the former is missing.
have_type("u_int8_t", stdint) || have_type("uint8_t", stdint)
have_type("u_int16_t", stdint) || have_type("uint16_t", stdint)
have_type("int32_t", stdint)
have_type("u_int32_t", stdint) || have_type("uint32_t", stdint)
have_type("int64_t", stdint)
have_type("u_int64_t", stdint) || have_type("uint64_t", stdint)
|
158
|
+
have_func("exp10")
|
159
|
+
|
160
|
+
have_var("rb_cComplex")
|
161
|
+
have_func("rb_thread_call_without_gvl")
|
162
|
+
|
163
|
+
create_header('include/cumo/extconf.h')
|
164
|
+
$extconf_h = nil # nvcc does not support #include RUBY_EXTCONF_H
|
165
|
+
|
166
|
+
create_depend
|
167
|
+
|
168
|
+
# CPATH and LIBRARY_PATH are path lists.  Split them on the platform's list
# separator (':' on POSIX, ';' on native Windows) instead of a hard-coded ':'
# so that drive-letter paths such as C:\cuda\include stay intact.
HEADER_DIRS = (ENV['CPATH'] || '').split(File::PATH_SEPARATOR)
LIB_DIRS = (ENV['LIBRARY_PATH'] || '').split(File::PATH_SEPARATOR)
dir_config('cumo', HEADER_DIRS, LIB_DIRS)
|
171
|
+
|
172
|
+
have_library('cuda')
|
173
|
+
have_library('cudart')
|
174
|
+
have_library('nvrtc')
|
175
|
+
have_library('cublas')
|
176
|
+
# have_library('cusolver')
|
177
|
+
# have_library('curand')
|
178
|
+
|
179
|
+
create_makefile('cumo')
|
@@ -0,0 +1,25 @@
|
|
1
|
+
#ifndef CUMO_H
|
2
|
+
#define CUMO_H
|
3
|
+
|
4
|
+
#include "cumo/narray.h"
|
5
|
+
|
6
|
+
#if defined(__cplusplus)
|
7
|
+
extern "C" {
|
8
|
+
#if 0
|
9
|
+
} /* satisfy cc-mode */
|
10
|
+
#endif
|
11
|
+
#endif
|
12
|
+
|
13
|
+
#define CUMO_VERSION "0.1.0"
|
14
|
+
#define CUMO_VERSION_CODE 10
|
15
|
+
|
16
|
+
/* NOTE(review): presumably reports whether Cumo's compatible mode is enabled;
 * the definition is not visible from this header -- confirm against cumo.c.
 * Declared with (void) so that C callers get a real prototype; an empty
 * parameter list in C leaves the arguments unchecked. */
bool cumo_compatible_mode_enabled_p(void);
|
17
|
+
|
18
|
+
#if defined(__cplusplus)
|
19
|
+
#if 0
|
20
|
+
{ /* satisfy cc-mode */
|
21
|
+
#endif
|
22
|
+
} /* extern "C" { */
|
23
|
+
#endif
|
24
|
+
|
25
|
+
#endif /* ifndef CUMO_H */
|
@@ -0,0 +1,23 @@
|
|
1
|
+
#ifndef CUMO_COMPAT_H
|
2
|
+
#define CUMO_COMPAT_H
|
3
|
+
|
4
|
+
/* Fallback accessors: each macro is defined only when the Ruby headers in
 * use have not already provided it. */

#ifndef RSTRING_LEN
#  define RSTRING_LEN(a) RSTRING(a)->len
#endif

#ifndef RSTRING_PTR
#  define RSTRING_PTR(a) RSTRING(a)->ptr
#endif

#ifndef RARRAY_LEN
#  define RARRAY_LEN(a) RARRAY(a)->len
#endif

#ifndef RARRAY_PTR
#  define RARRAY_PTR(a) RARRAY(a)->ptr
#endif

/* Element read/write expressed through RARRAY_PTR. */
#ifndef RARRAY_AREF
#  define RARRAY_AREF(a,i) RARRAY_PTR(a)[i]
#endif

#ifndef RARRAY_ASET
#  define RARRAY_ASET(a,i,v) (RARRAY_PTR(a)[i] = v)
#endif
|
22
|
+
|
23
|
+
#endif /* ifndef CUMO_COMPAT_H */
|
@@ -0,0 +1,153 @@
|
|
1
|
+
#ifndef CUMO_CUDA_CUBLAS_H
|
2
|
+
#define CUMO_CUDA_CUBLAS_H
|
3
|
+
|
4
|
+
#include <ruby.h>
|
5
|
+
#include "cublas_v2.h"
|
6
|
+
|
7
|
+
#if defined(__cplusplus)
|
8
|
+
extern "C" {
|
9
|
+
#if 0
|
10
|
+
} /* satisfy cc-mode */
|
11
|
+
#endif
|
12
|
+
#endif
|
13
|
+
|
14
|
+
#define option_value cumo_cublas_option_value
|
15
|
+
extern VALUE cumo_cublas_option_value(VALUE value, VALUE default_value);
|
16
|
+
|
17
|
+
//#define option_order cumo_cublas_option_order
|
18
|
+
//extern enum CBLAS_ORDER cumo_cublas_option_order(VALUE order);
|
19
|
+
|
20
|
+
#define option_trans cumo_cublas_option_trans
|
21
|
+
extern cublasOperation_t cumo_cublas_option_trans(VALUE trans);
|
22
|
+
|
23
|
+
#define option_uplo cumo_cublas_option_uplo
|
24
|
+
extern cublasFillMode_t cumo_cublas_option_uplo(VALUE uplo);
|
25
|
+
|
26
|
+
#define option_diag cumo_cublas_option_diag
|
27
|
+
extern cublasDiagType_t cumo_cublas_option_diag(VALUE diag);
|
28
|
+
|
29
|
+
#define option_side cumo_cublas_option_side
|
30
|
+
extern cublasSideMode_t cumo_cublas_option_side(VALUE side);
|
31
|
+
|
32
|
+
//#define check_func cumo_cublas_check_func
|
33
|
+
//extern void cumo_cublas_check_func(void **func, const char *name);
|
34
|
+
|
35
|
+
// TODO: Check if a and b are row_major?
|
36
|
+
#define SWAP_IFROW(a,b,tmp) \
|
37
|
+
{(tmp)=(a);(a)=(b);(b)=(tmp);}
|
38
|
+
|
39
|
+
#define SWAP_IFTR(trans,a,b,tmp) \
|
40
|
+
{ if ((trans)!=CUBLAS_OP_N) \
|
41
|
+
{(tmp)=(a);(a)=(b);(b)=(tmp);} \
|
42
|
+
}
|
43
|
+
|
44
|
+
/*
|
45
|
+
//#define SWAP_IFCOLTR(order,trans,a,b,tmp) \
|
46
|
+
// { if (((order)==CblasRowMajor && (trans)!=CblasNoTrans) || \
|
47
|
+
// ((order)!=CblasRowMajor && (trans)==CblasNoTrans)) \
|
48
|
+
// {(tmp)=(a);(a)=(b);(b)=(tmp);} \
|
49
|
+
// }
|
50
|
+
|
51
|
+
//#define SWAP_IFCOL(order,a,b,tmp) \
|
52
|
+
// { if ((order)==CblasColMajor) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
|
53
|
+
//
|
54
|
+
//#define SWAP_IFROW(order,a,b,tmp) \
|
55
|
+
// { if ((order)==CblasRowMajor) {(tmp)=(a);(a)=(b);(b)=(tmp);} }
|
56
|
+
//
|
57
|
+
//#define SWAP_IFCOLTR(order,trans,a,b,tmp) \
|
58
|
+
// { if (((order)==CblasRowMajor && (trans)!=CblasNoTrans) || \
|
59
|
+
// ((order)!=CblasRowMajor && (trans)==CblasNoTrans)) \
|
60
|
+
// {(tmp)=(a);(a)=(b);(b)=(tmp);} \
|
61
|
+
// }
|
62
|
+
//
|
63
|
+
//#define CHECK_FUNC(fptr, fname) \
|
64
|
+
// { if ((fptr)==0) { check_func((void*)(&(fptr)),fname); } }
|
65
|
+
*/
|
66
|
+
|
67
|
+
#define ROW_SIZE(na) ((na)->shape[(na)->ndim-2])
|
68
|
+
#define COL_SIZE(na) ((na)->shape[(na)->ndim-1])
|
69
|
+
|
70
|
+
#define CHECK_NARRAY_TYPE(x,t) \
|
71
|
+
if (CLASS_OF(x)!=(t)) { \
|
72
|
+
rb_raise(rb_eTypeError,"invalid NArray type (class)"); \
|
73
|
+
}
|
74
|
+
|
75
|
+
// Error Class ??
|
76
|
+
#define CHECK_DIM_GE(na,nd) \
|
77
|
+
if ((na)->ndim<(nd)) { \
|
78
|
+
rb_raise(nary_eShapeError, \
|
79
|
+
"n-dimension=%d, but >=%d is expected", \
|
80
|
+
(na)->ndim, (nd)); \
|
81
|
+
}
|
82
|
+
|
83
|
+
// Raises nary_eShapeError unless na1 has exactly nd dimensions.
// na1 is dereferenced as a pointer with an `ndim` member; both values are
// printed with %d.
#define CHECK_DIM_EQ(na1,nd)                                    \
    if ((na1)->ndim != (nd)) {                                  \
        rb_raise(nary_eShapeError,                              \
                 "dimension mismatch: %d != %d",                \
                 (na1)->ndim, (nd));                            \
    }
|
89
|
+
|
90
|
+
#define CHECK_SQUARE(name,na) \
|
91
|
+
if ((na)->shape[(na)->ndim-1] != (na)->shape[(na)->ndim-2]) { \
|
92
|
+
rb_raise(nary_eShapeError,"%s is not square matrix",name); \
|
93
|
+
}
|
94
|
+
|
95
|
+
#define CHECK_SIZE_GE(na,sz) \
|
96
|
+
if ((na)->size < (size_t)(sz)) { \
|
97
|
+
rb_raise(nary_eShapeError, \
|
98
|
+
"NArray size must be >= %"SZF"u",(size_t)(sz));\
|
99
|
+
}
|
100
|
+
#define CHECK_NON_EMPTY(na) \
|
101
|
+
if ((na)->size==0) { \
|
102
|
+
rb_raise(nary_eShapeError,"empty NArray"); \
|
103
|
+
}
|
104
|
+
|
105
|
+
// Raises nary_eShapeError unless n == m.
// Both values are cast to size_t for the message, so they are printed with
// the unsigned conversion (matching CHECK_SIZE_GE).  SZF is separated from the
// adjacent string literals by spaces so the concatenation is not parsed as a
// user-defined-literal suffix when this header is compiled as C++11 or later.
#define CHECK_SIZE_EQ(n,m)                                      \
    if ((n)!=(m)) {                                             \
        rb_raise(nary_eShapeError,                              \
                 "size mismatch: %" SZF "u != %" SZF "u",       \
                 (size_t)(n),(size_t)(m));                      \
    }
|
111
|
+
|
112
|
+
// Raises nary_eShapeError unless na1 and na2 have the same number of
// dimensions and the same extent along every dimension.
// The arguments are parenthesized at every use so that expressions such as
// `&obj` or `cond ? p : q` expand correctly.
#define CHECK_SAME_SHAPE(na1,na2)                                   \
    {   int i;                                                      \
        CHECK_DIM_EQ((na1),(na2)->ndim);                            \
        for (i=0; i<(na1)->ndim; i++) {                             \
            CHECK_SIZE_EQ((na1)->shape[i],(na2)->shape[i]);         \
        }                                                           \
    }
|
119
|
+
|
120
|
+
#define CHECK_INT_EQ(sm,m,sn,n) \
|
121
|
+
if ((m) != (n)) { \
|
122
|
+
rb_raise(nary_eShapeError, \
|
123
|
+
"%s must be == %s: %s=%d %s=%d", \
|
124
|
+
sm,sn,sm,m,sn,n); \
|
125
|
+
}
|
126
|
+
|
127
|
+
// Error Class ??
|
128
|
+
#define CHECK_LEADING_GE(sld,ld,sn,n) \
|
129
|
+
if ((ld) < (n)) { \
|
130
|
+
rb_raise(nary_eShapeError, \
|
131
|
+
"%s must be >= max(%s,1): %s=%d %s=%d", \
|
132
|
+
sld,sn,sld,ld,sn,n); \
|
133
|
+
}
|
134
|
+
|
135
|
+
#define COPY_OR_CAST_TO(a,T) \
|
136
|
+
{ \
|
137
|
+
if (CLASS_OF(a) == (T)) { \
|
138
|
+
if (!TEST_INPLACE(a)) { \
|
139
|
+
a = na_copy(a); \
|
140
|
+
} \
|
141
|
+
} else { \
|
142
|
+
a = rb_funcall(T,rb_intern("cast"),1,a); \
|
143
|
+
} \
|
144
|
+
}
|
145
|
+
|
146
|
+
#if defined(__cplusplus)
|
147
|
+
#if 0
|
148
|
+
{ /* satisfy cc-mode */
|
149
|
+
#endif
|
150
|
+
} /* extern "C" { */
|
151
|
+
#endif
|
152
|
+
|
153
|
+
#endif /* ifndef CUMO_CUDA_CUBLAS_H */
|
@@ -0,0 +1,187 @@
|
|
1
|
+
#ifndef CUMO_CUDA_THRUST_H
|
2
|
+
#define CUMO_CUDA_THRUST_H
|
3
|
+
|
4
|
+
#include <thrust/device_ptr.h>
|
5
|
+
#include <thrust/device_vector.h>
|
6
|
+
#include <thrust/extrema.h>
|
7
|
+
#include <thrust/functional.h>
|
8
|
+
#include <thrust/inner_product.h>
|
9
|
+
#include <thrust/iterator/counting_iterator.h>
|
10
|
+
#include <thrust/iterator/transform_iterator.h>
|
11
|
+
#include <thrust/iterator/permutation_iterator.h>
|
12
|
+
#include <thrust/reduce.h>
|
13
|
+
#include <thrust/system/cuda/execution_policy.h>
|
14
|
+
#include <thrust/transform_reduce.h>
|
15
|
+
|
16
|
+
// this example illustrates how to make strided access to a range of values
|
17
|
+
// examples:
|
18
|
+
// strided_range([0, 1, 2, 3, 4, 5, 6], 1) -> [0, 1, 2, 3, 4, 5, 6]
|
19
|
+
// strided_range([0, 1, 2, 3, 4, 5, 6], 2) -> [0, 2, 4, 6]
|
20
|
+
// strided_range([0, 1, 2, 3, 4, 5, 6], 3) -> [0, 3, 6]
|
21
|
+
// ...
|
22
|
+
// ref. https://github.com/thrust/thrust/blob/master/examples/strided_range.cu (Apache License)
|
23
|
+
|
24
|
+
// Presents every `stride`-th element of the range [first, last) as an
// iterator range of its own, e.g. stride 2 over [0, 1, 2, 3, 4, 5, 6]
// yields [0, 2, 4, 6].
template <typename Iterator>
class cumo_thrust_strided_range
{
public:
    typedef typename thrust::iterator_difference<Iterator>::type difference_type;

    // Maps a logical position i to the underlying offset stride * i.
    struct stride_functor : public thrust::unary_function<difference_type, difference_type>
    {
        difference_type stride;

        stride_functor(difference_type stride) : stride(stride) {}

        __host__ __device__
        difference_type operator()(const difference_type& i) const { return stride * i; }
    };

    typedef thrust::counting_iterator<difference_type>                   CountingIterator;
    typedef thrust::transform_iterator<stride_functor, CountingIterator> TransformIterator;
    typedef thrust::permutation_iterator<Iterator, TransformIterator>    PermutationIterator;

    // Iterator type returned by begin() and end().
    typedef PermutationIterator iterator;

    // Builds a strided view over [first, last).
    cumo_thrust_strided_range(Iterator first, Iterator last, difference_type stride)
        : first(first), last(last), stride(stride) {}

    // Iterator positioned on `first` (logical index 0).
    iterator begin() const
    {
        CountingIterator zero(0);
        TransformIterator offsets(zero, stride_functor(stride));
        return PermutationIterator(first, offsets);
    }

    // Iterator one past the last strided element.
    iterator end() const
    {
        // Number of strided elements: (last - first) divided by stride,
        // rounded up by adding stride - 1 before the integer division.
        return begin() + ((last - first) + (stride - 1)) / stride;
    }

protected:
    Iterator first;
    Iterator last;
    difference_type stride;
};
|
71
|
+
|
72
|
+
|
73
|
+
// compute minimum and maximum values in a single reduction
|
74
|
+
// ref. https://github.com/thrust/thrust/blob/master/examples/minmax.cu (Apache License)
|
75
|
+
|
76
|
+
// cumo_thrust_minmax_pair stores the minimum and maximum
|
77
|
+
// values that have been encountered so far
|
78
|
+
// Plain value pair carried through the min/max reduction.
template <typename T>
struct cumo_thrust_minmax_pair
{
    T min_val;  // minimum encountered so far
    T max_val;  // maximum encountered so far
};
|
84
|
+
|
85
|
+
// cumo_thrust_minmax_unary_op is a functor that takes in a value x and
|
86
|
+
// returns a cumo_thrust_minmax_pair whose minimum and maximum values
|
87
|
+
// are initialized to x.
|
88
|
+
// Transform step of the min/max reduction: wraps a single value x in a
// cumo_thrust_minmax_pair whose minimum and maximum are both x.
template <typename T>
struct cumo_thrust_minmax_unary_op : public thrust::unary_function< T, cumo_thrust_minmax_pair<T> >
{
    __host__ __device__
    cumo_thrust_minmax_pair<T> operator()(const T& x) const
    {
        cumo_thrust_minmax_pair<T> seed;
        seed.max_val = x;
        seed.min_val = x;
        return seed;
    }
};
|
99
|
+
|
100
|
+
// cumo_thrust_minmax_binary_op is a functor that accepts two cumo_thrust_minmax_pair
|
101
|
+
// structs and returns a new cumo_thrust_minmax_pair whose minimum and
|
102
|
+
// maximum values are the min() and max() respectively of
|
103
|
+
// the minimums and maximums of the input pairs
|
104
|
+
// Reduction step of the min/max reduction: merges two pairs into one holding
// the smaller of the two minimums and the larger of the two maximums.
template <typename T>
struct cumo_thrust_minmax_binary_op
    : public thrust::binary_function< cumo_thrust_minmax_pair<T>, cumo_thrust_minmax_pair<T>, cumo_thrust_minmax_pair<T> >
{
    __host__ __device__
    cumo_thrust_minmax_pair<T> operator()(const cumo_thrust_minmax_pair<T>& x,
                                          const cumo_thrust_minmax_pair<T>& y) const
    {
        cumo_thrust_minmax_pair<T> merged;
        merged.max_val = thrust::max(x.max_val, y.max_val);
        merged.min_val = thrust::min(x.min_val, y.min_val);
        return merged;
    }
};
|
115
|
+
|
116
|
+
// ref. https://github.com/thrust/thrust/blob/master/examples/summary_statistics.cu
|
117
|
+
|
118
|
+
// structure used to accumulate the moments and other
|
119
|
+
// statistical properties encountered so far.
|
120
|
+
// Accumulator for a single-pass variance reduction.
template <typename T>
struct cumo_thrust_variance_data
{
    T n;     // sample count (the unary op sets it to 1 per element)
    T mean;  // mean of the samples accumulated so far
    T M2;    // accumulated second moment; divided by n or n - 1 below

    // Resets the accumulator to the identity element of the reduction.
    // Marked __host__ __device__ like the other members so it can be used
    // from either side.
    __host__ __device__
    void initialize()
    {
        n = mean = M2 = 0;
    }

    // M2 / (n - 1).  No guard is applied for n <= 1.
    __host__ __device__ T variance() const { return M2 / (n - 1); }

    // M2 / n.  No guard is applied for n == 0.
    __host__ __device__ T variance_n() const { return M2 / n; }
};
|
136
|
+
|
137
|
+
// cumo_thrust_variance_unary_op is a functor that takes in a value x and
|
138
|
+
// returns a cumo_thrust_variance_data whose mean value is initialized to x.
|
139
|
+
// Transform step of the variance reduction: turns one value x into an
// accumulator describing a single sample (n = 1, mean = x, M2 = 0).
template <typename T>
struct cumo_thrust_variance_unary_op
{
    __host__ __device__
    cumo_thrust_variance_data<T> operator()(const T& x) const
    {
        cumo_thrust_variance_data<T> sample;
        sample.M2 = 0;
        sample.mean = x;
        sample.n = 1;
        return sample;
    }
};
|
153
|
+
|
154
|
+
// cumo_thrust_variance_binary_op is a functor that accepts two cumo_thrust_variance_data
|
155
|
+
// structs and returns a new cumo_thrust_variance_data which is an
|
156
|
+
// approximation to the variance statistics for
|
157
|
+
// all values that have been aggregated so far
|
158
|
+
// Reduction step of the variance reduction: merges two accumulators into one
// covering the samples of both.
template <typename T>
struct cumo_thrust_variance_binary_op
    : public thrust::binary_function<const cumo_thrust_variance_data<T>&,
                                     const cumo_thrust_variance_data<T>&,
                                           cumo_thrust_variance_data<T> >
{
    __host__ __device__
    cumo_thrust_variance_data<T> operator()(const cumo_thrust_variance_data<T>& x,
                                            const cumo_thrust_variance_data<T>& y) const
    {
        // Shared sub-expressions: combined count and the gap between the means.
        T total = x.n + y.n;
        T gap = y.mean - x.mean;
        T gap_sq = gap * gap;

        cumo_thrust_variance_data<T> merged;
        merged.n = total;

        // Shift x's mean towards y's in proportion to y's share of the samples.
        merged.mean = x.mean + gap * y.n / total;

        // Sum of both second moments plus a term for the gap between the means.
        merged.M2 = x.M2 + y.M2;
        merged.M2 += gap_sq * x.n * y.n / total;

        return merged;
    }
};
|
186
|
+
|
187
|
+
#endif /* ifndef CUMO_CUDA_THRUST_H */
|