cumo 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
Files changed (266)
  1. checksums.yaml +7 -0
  2. data/.gitignore +27 -0
  3. data/.travis.yml +5 -0
  4. data/3rd_party/mkmf-cu/.gitignore +36 -0
  5. data/3rd_party/mkmf-cu/Gemfile +3 -0
  6. data/3rd_party/mkmf-cu/LICENSE +21 -0
  7. data/3rd_party/mkmf-cu/README.md +36 -0
  8. data/3rd_party/mkmf-cu/Rakefile +11 -0
  9. data/3rd_party/mkmf-cu/bin/mkmf-cu-nvcc +4 -0
  10. data/3rd_party/mkmf-cu/lib/mkmf-cu.rb +32 -0
  11. data/3rd_party/mkmf-cu/lib/mkmf-cu/cli.rb +80 -0
  12. data/3rd_party/mkmf-cu/lib/mkmf-cu/nvcc.rb +157 -0
  13. data/3rd_party/mkmf-cu/mkmf-cu.gemspec +16 -0
  14. data/3rd_party/mkmf-cu/test/test_mkmf-cu.rb +67 -0
  15. data/CODE_OF_CONDUCT.md +46 -0
  16. data/Gemfile +8 -0
  17. data/LICENSE.txt +82 -0
  18. data/README.md +252 -0
  19. data/Rakefile +43 -0
  20. data/bench/broadcast_fp32.rb +138 -0
  21. data/bench/cumo_bench.rb +193 -0
  22. data/bench/numo_bench.rb +138 -0
  23. data/bench/reduction_fp32.rb +117 -0
  24. data/bin/console +14 -0
  25. data/bin/setup +8 -0
  26. data/cumo.gemspec +32 -0
  27. data/ext/cumo/cuda/cublas.c +278 -0
  28. data/ext/cumo/cuda/driver.c +421 -0
  29. data/ext/cumo/cuda/memory_pool.cpp +185 -0
  30. data/ext/cumo/cuda/memory_pool_impl.cpp +308 -0
  31. data/ext/cumo/cuda/memory_pool_impl.hpp +370 -0
  32. data/ext/cumo/cuda/memory_pool_impl_test.cpp +554 -0
  33. data/ext/cumo/cuda/nvrtc.c +207 -0
  34. data/ext/cumo/cuda/runtime.c +167 -0
  35. data/ext/cumo/cumo.c +148 -0
  36. data/ext/cumo/depend.erb +58 -0
  37. data/ext/cumo/extconf.rb +179 -0
  38. data/ext/cumo/include/cumo.h +25 -0
  39. data/ext/cumo/include/cumo/compat.h +23 -0
  40. data/ext/cumo/include/cumo/cuda/cublas.h +153 -0
  41. data/ext/cumo/include/cumo/cuda/cumo_thrust.hpp +187 -0
  42. data/ext/cumo/include/cumo/cuda/cumo_thrust_complex.hpp +79 -0
  43. data/ext/cumo/include/cumo/cuda/driver.h +22 -0
  44. data/ext/cumo/include/cumo/cuda/memory_pool.h +28 -0
  45. data/ext/cumo/include/cumo/cuda/nvrtc.h +22 -0
  46. data/ext/cumo/include/cumo/cuda/runtime.h +40 -0
  47. data/ext/cumo/include/cumo/indexer.h +238 -0
  48. data/ext/cumo/include/cumo/intern.h +142 -0
  49. data/ext/cumo/include/cumo/intern_fwd.h +38 -0
  50. data/ext/cumo/include/cumo/intern_kernel.h +6 -0
  51. data/ext/cumo/include/cumo/narray.h +429 -0
  52. data/ext/cumo/include/cumo/narray_kernel.h +149 -0
  53. data/ext/cumo/include/cumo/ndloop.h +95 -0
  54. data/ext/cumo/include/cumo/reduce_kernel.h +126 -0
  55. data/ext/cumo/include/cumo/template.h +158 -0
  56. data/ext/cumo/include/cumo/template_kernel.h +77 -0
  57. data/ext/cumo/include/cumo/types/bit.h +40 -0
  58. data/ext/cumo/include/cumo/types/bit_kernel.h +34 -0
  59. data/ext/cumo/include/cumo/types/complex.h +402 -0
  60. data/ext/cumo/include/cumo/types/complex_kernel.h +414 -0
  61. data/ext/cumo/include/cumo/types/complex_macro.h +382 -0
  62. data/ext/cumo/include/cumo/types/complex_macro_kernel.h +186 -0
  63. data/ext/cumo/include/cumo/types/dcomplex.h +46 -0
  64. data/ext/cumo/include/cumo/types/dcomplex_kernel.h +13 -0
  65. data/ext/cumo/include/cumo/types/dfloat.h +47 -0
  66. data/ext/cumo/include/cumo/types/dfloat_kernel.h +14 -0
  67. data/ext/cumo/include/cumo/types/float_def.h +34 -0
  68. data/ext/cumo/include/cumo/types/float_def_kernel.h +39 -0
  69. data/ext/cumo/include/cumo/types/float_macro.h +191 -0
  70. data/ext/cumo/include/cumo/types/float_macro_kernel.h +158 -0
  71. data/ext/cumo/include/cumo/types/int16.h +24 -0
  72. data/ext/cumo/include/cumo/types/int16_kernel.h +23 -0
  73. data/ext/cumo/include/cumo/types/int32.h +24 -0
  74. data/ext/cumo/include/cumo/types/int32_kernel.h +19 -0
  75. data/ext/cumo/include/cumo/types/int64.h +24 -0
  76. data/ext/cumo/include/cumo/types/int64_kernel.h +19 -0
  77. data/ext/cumo/include/cumo/types/int8.h +24 -0
  78. data/ext/cumo/include/cumo/types/int8_kernel.h +19 -0
  79. data/ext/cumo/include/cumo/types/int_macro.h +67 -0
  80. data/ext/cumo/include/cumo/types/int_macro_kernel.h +48 -0
  81. data/ext/cumo/include/cumo/types/real_accum.h +486 -0
  82. data/ext/cumo/include/cumo/types/real_accum_kernel.h +101 -0
  83. data/ext/cumo/include/cumo/types/robj_macro.h +80 -0
  84. data/ext/cumo/include/cumo/types/robj_macro_kernel.h +0 -0
  85. data/ext/cumo/include/cumo/types/robject.h +27 -0
  86. data/ext/cumo/include/cumo/types/robject_kernel.h +7 -0
  87. data/ext/cumo/include/cumo/types/scomplex.h +46 -0
  88. data/ext/cumo/include/cumo/types/scomplex_kernel.h +13 -0
  89. data/ext/cumo/include/cumo/types/sfloat.h +48 -0
  90. data/ext/cumo/include/cumo/types/sfloat_kernel.h +14 -0
  91. data/ext/cumo/include/cumo/types/uint16.h +25 -0
  92. data/ext/cumo/include/cumo/types/uint16_kernel.h +20 -0
  93. data/ext/cumo/include/cumo/types/uint32.h +25 -0
  94. data/ext/cumo/include/cumo/types/uint32_kernel.h +20 -0
  95. data/ext/cumo/include/cumo/types/uint64.h +25 -0
  96. data/ext/cumo/include/cumo/types/uint64_kernel.h +20 -0
  97. data/ext/cumo/include/cumo/types/uint8.h +25 -0
  98. data/ext/cumo/include/cumo/types/uint8_kernel.h +20 -0
  99. data/ext/cumo/include/cumo/types/uint_macro.h +58 -0
  100. data/ext/cumo/include/cumo/types/uint_macro_kernel.h +38 -0
  101. data/ext/cumo/include/cumo/types/xint_macro.h +169 -0
  102. data/ext/cumo/include/cumo/types/xint_macro_kernel.h +88 -0
  103. data/ext/cumo/narray/SFMT-params.h +97 -0
  104. data/ext/cumo/narray/SFMT-params19937.h +46 -0
  105. data/ext/cumo/narray/SFMT.c +620 -0
  106. data/ext/cumo/narray/SFMT.h +167 -0
  107. data/ext/cumo/narray/array.c +638 -0
  108. data/ext/cumo/narray/data.c +961 -0
  109. data/ext/cumo/narray/gen/cogen.rb +56 -0
  110. data/ext/cumo/narray/gen/cogen_kernel.rb +58 -0
  111. data/ext/cumo/narray/gen/def/bit.rb +37 -0
  112. data/ext/cumo/narray/gen/def/dcomplex.rb +39 -0
  113. data/ext/cumo/narray/gen/def/dfloat.rb +37 -0
  114. data/ext/cumo/narray/gen/def/int16.rb +36 -0
  115. data/ext/cumo/narray/gen/def/int32.rb +36 -0
  116. data/ext/cumo/narray/gen/def/int64.rb +36 -0
  117. data/ext/cumo/narray/gen/def/int8.rb +36 -0
  118. data/ext/cumo/narray/gen/def/robject.rb +37 -0
  119. data/ext/cumo/narray/gen/def/scomplex.rb +39 -0
  120. data/ext/cumo/narray/gen/def/sfloat.rb +37 -0
  121. data/ext/cumo/narray/gen/def/uint16.rb +36 -0
  122. data/ext/cumo/narray/gen/def/uint32.rb +36 -0
  123. data/ext/cumo/narray/gen/def/uint64.rb +36 -0
  124. data/ext/cumo/narray/gen/def/uint8.rb +36 -0
  125. data/ext/cumo/narray/gen/erbpp2.rb +346 -0
  126. data/ext/cumo/narray/gen/narray_def.rb +268 -0
  127. data/ext/cumo/narray/gen/spec.rb +425 -0
  128. data/ext/cumo/narray/gen/tmpl/accum.c +86 -0
  129. data/ext/cumo/narray/gen/tmpl/accum_binary.c +121 -0
  130. data/ext/cumo/narray/gen/tmpl/accum_binary_kernel.cu +61 -0
  131. data/ext/cumo/narray/gen/tmpl/accum_index.c +119 -0
  132. data/ext/cumo/narray/gen/tmpl/accum_index_kernel.cu +66 -0
  133. data/ext/cumo/narray/gen/tmpl/accum_kernel.cu +12 -0
  134. data/ext/cumo/narray/gen/tmpl/alloc_func.c +107 -0
  135. data/ext/cumo/narray/gen/tmpl/allocate.c +37 -0
  136. data/ext/cumo/narray/gen/tmpl/aref.c +66 -0
  137. data/ext/cumo/narray/gen/tmpl/aref_cpu.c +50 -0
  138. data/ext/cumo/narray/gen/tmpl/aset.c +56 -0
  139. data/ext/cumo/narray/gen/tmpl/binary.c +162 -0
  140. data/ext/cumo/narray/gen/tmpl/binary2.c +70 -0
  141. data/ext/cumo/narray/gen/tmpl/binary2_kernel.cu +15 -0
  142. data/ext/cumo/narray/gen/tmpl/binary_kernel.cu +31 -0
  143. data/ext/cumo/narray/gen/tmpl/binary_s.c +45 -0
  144. data/ext/cumo/narray/gen/tmpl/binary_s_kernel.cu +15 -0
  145. data/ext/cumo/narray/gen/tmpl/bincount.c +181 -0
  146. data/ext/cumo/narray/gen/tmpl/cast.c +44 -0
  147. data/ext/cumo/narray/gen/tmpl/cast_array.c +13 -0
  148. data/ext/cumo/narray/gen/tmpl/class.c +9 -0
  149. data/ext/cumo/narray/gen/tmpl/class_kernel.cu +6 -0
  150. data/ext/cumo/narray/gen/tmpl/clip.c +121 -0
  151. data/ext/cumo/narray/gen/tmpl/coerce_cast.c +10 -0
  152. data/ext/cumo/narray/gen/tmpl/complex_accum_kernel.cu +129 -0
  153. data/ext/cumo/narray/gen/tmpl/cond_binary.c +68 -0
  154. data/ext/cumo/narray/gen/tmpl/cond_binary_kernel.cu +18 -0
  155. data/ext/cumo/narray/gen/tmpl/cond_unary.c +46 -0
  156. data/ext/cumo/narray/gen/tmpl/cum.c +50 -0
  157. data/ext/cumo/narray/gen/tmpl/each.c +47 -0
  158. data/ext/cumo/narray/gen/tmpl/each_with_index.c +70 -0
  159. data/ext/cumo/narray/gen/tmpl/ewcomp.c +79 -0
  160. data/ext/cumo/narray/gen/tmpl/ewcomp_kernel.cu +19 -0
  161. data/ext/cumo/narray/gen/tmpl/extract.c +22 -0
  162. data/ext/cumo/narray/gen/tmpl/extract_cpu.c +26 -0
  163. data/ext/cumo/narray/gen/tmpl/extract_data.c +53 -0
  164. data/ext/cumo/narray/gen/tmpl/eye.c +105 -0
  165. data/ext/cumo/narray/gen/tmpl/eye_kernel.cu +19 -0
  166. data/ext/cumo/narray/gen/tmpl/fill.c +52 -0
  167. data/ext/cumo/narray/gen/tmpl/fill_kernel.cu +29 -0
  168. data/ext/cumo/narray/gen/tmpl/float_accum_kernel.cu +106 -0
  169. data/ext/cumo/narray/gen/tmpl/format.c +62 -0
  170. data/ext/cumo/narray/gen/tmpl/format_to_a.c +49 -0
  171. data/ext/cumo/narray/gen/tmpl/frexp.c +38 -0
  172. data/ext/cumo/narray/gen/tmpl/gemm.c +203 -0
  173. data/ext/cumo/narray/gen/tmpl/init_class.c +20 -0
  174. data/ext/cumo/narray/gen/tmpl/init_module.c +12 -0
  175. data/ext/cumo/narray/gen/tmpl/inspect.c +21 -0
  176. data/ext/cumo/narray/gen/tmpl/lib.c +50 -0
  177. data/ext/cumo/narray/gen/tmpl/lib_kernel.cu +24 -0
  178. data/ext/cumo/narray/gen/tmpl/logseq.c +102 -0
  179. data/ext/cumo/narray/gen/tmpl/logseq_kernel.cu +31 -0
  180. data/ext/cumo/narray/gen/tmpl/map_with_index.c +98 -0
  181. data/ext/cumo/narray/gen/tmpl/median.c +66 -0
  182. data/ext/cumo/narray/gen/tmpl/minmax.c +47 -0
  183. data/ext/cumo/narray/gen/tmpl/module.c +9 -0
  184. data/ext/cumo/narray/gen/tmpl/module_kernel.cu +1 -0
  185. data/ext/cumo/narray/gen/tmpl/new_dim0.c +15 -0
  186. data/ext/cumo/narray/gen/tmpl/new_dim0_kernel.cu +8 -0
  187. data/ext/cumo/narray/gen/tmpl/poly.c +50 -0
  188. data/ext/cumo/narray/gen/tmpl/pow.c +97 -0
  189. data/ext/cumo/narray/gen/tmpl/pow_kernel.cu +29 -0
  190. data/ext/cumo/narray/gen/tmpl/powint.c +17 -0
  191. data/ext/cumo/narray/gen/tmpl/qsort.c +212 -0
  192. data/ext/cumo/narray/gen/tmpl/rand.c +168 -0
  193. data/ext/cumo/narray/gen/tmpl/rand_norm.c +121 -0
  194. data/ext/cumo/narray/gen/tmpl/real_accum_kernel.cu +75 -0
  195. data/ext/cumo/narray/gen/tmpl/seq.c +112 -0
  196. data/ext/cumo/narray/gen/tmpl/seq_kernel.cu +43 -0
  197. data/ext/cumo/narray/gen/tmpl/set2.c +57 -0
  198. data/ext/cumo/narray/gen/tmpl/sort.c +48 -0
  199. data/ext/cumo/narray/gen/tmpl/sort_index.c +111 -0
  200. data/ext/cumo/narray/gen/tmpl/store.c +41 -0
  201. data/ext/cumo/narray/gen/tmpl/store_array.c +187 -0
  202. data/ext/cumo/narray/gen/tmpl/store_array_kernel.cu +58 -0
  203. data/ext/cumo/narray/gen/tmpl/store_bit.c +86 -0
  204. data/ext/cumo/narray/gen/tmpl/store_bit_kernel.cu +66 -0
  205. data/ext/cumo/narray/gen/tmpl/store_from.c +81 -0
  206. data/ext/cumo/narray/gen/tmpl/store_from_kernel.cu +58 -0
  207. data/ext/cumo/narray/gen/tmpl/store_kernel.cu +3 -0
  208. data/ext/cumo/narray/gen/tmpl/store_numeric.c +9 -0
  209. data/ext/cumo/narray/gen/tmpl/to_a.c +43 -0
  210. data/ext/cumo/narray/gen/tmpl/unary.c +132 -0
  211. data/ext/cumo/narray/gen/tmpl/unary2.c +60 -0
  212. data/ext/cumo/narray/gen/tmpl/unary_kernel.cu +72 -0
  213. data/ext/cumo/narray/gen/tmpl/unary_ret2.c +34 -0
  214. data/ext/cumo/narray/gen/tmpl/unary_s.c +86 -0
  215. data/ext/cumo/narray/gen/tmpl/unary_s_kernel.cu +58 -0
  216. data/ext/cumo/narray/gen/tmpl_bit/allocate.c +24 -0
  217. data/ext/cumo/narray/gen/tmpl_bit/aref.c +54 -0
  218. data/ext/cumo/narray/gen/tmpl_bit/aref_cpu.c +57 -0
  219. data/ext/cumo/narray/gen/tmpl_bit/aset.c +56 -0
  220. data/ext/cumo/narray/gen/tmpl_bit/binary.c +98 -0
  221. data/ext/cumo/narray/gen/tmpl_bit/bit_count.c +64 -0
  222. data/ext/cumo/narray/gen/tmpl_bit/bit_count_cpu.c +88 -0
  223. data/ext/cumo/narray/gen/tmpl_bit/bit_count_kernel.cu +76 -0
  224. data/ext/cumo/narray/gen/tmpl_bit/bit_reduce.c +133 -0
  225. data/ext/cumo/narray/gen/tmpl_bit/each.c +48 -0
  226. data/ext/cumo/narray/gen/tmpl_bit/each_with_index.c +70 -0
  227. data/ext/cumo/narray/gen/tmpl_bit/extract.c +30 -0
  228. data/ext/cumo/narray/gen/tmpl_bit/extract_cpu.c +29 -0
  229. data/ext/cumo/narray/gen/tmpl_bit/fill.c +69 -0
  230. data/ext/cumo/narray/gen/tmpl_bit/format.c +64 -0
  231. data/ext/cumo/narray/gen/tmpl_bit/format_to_a.c +51 -0
  232. data/ext/cumo/narray/gen/tmpl_bit/inspect.c +21 -0
  233. data/ext/cumo/narray/gen/tmpl_bit/mask.c +136 -0
  234. data/ext/cumo/narray/gen/tmpl_bit/none_p.c +14 -0
  235. data/ext/cumo/narray/gen/tmpl_bit/store_array.c +108 -0
  236. data/ext/cumo/narray/gen/tmpl_bit/store_bit.c +70 -0
  237. data/ext/cumo/narray/gen/tmpl_bit/store_from.c +60 -0
  238. data/ext/cumo/narray/gen/tmpl_bit/to_a.c +47 -0
  239. data/ext/cumo/narray/gen/tmpl_bit/unary.c +81 -0
  240. data/ext/cumo/narray/gen/tmpl_bit/where.c +90 -0
  241. data/ext/cumo/narray/gen/tmpl_bit/where2.c +95 -0
  242. data/ext/cumo/narray/index.c +880 -0
  243. data/ext/cumo/narray/kwargs.c +153 -0
  244. data/ext/cumo/narray/math.c +142 -0
  245. data/ext/cumo/narray/narray.c +1948 -0
  246. data/ext/cumo/narray/ndloop.c +2105 -0
  247. data/ext/cumo/narray/rand.c +45 -0
  248. data/ext/cumo/narray/step.c +474 -0
  249. data/ext/cumo/narray/struct.c +886 -0
  250. data/lib/cumo.rb +3 -0
  251. data/lib/cumo/cuda.rb +11 -0
  252. data/lib/cumo/cuda/compile_error.rb +36 -0
  253. data/lib/cumo/cuda/compiler.rb +161 -0
  254. data/lib/cumo/cuda/device.rb +47 -0
  255. data/lib/cumo/cuda/link_state.rb +31 -0
  256. data/lib/cumo/cuda/module.rb +40 -0
  257. data/lib/cumo/cuda/nvrtc_program.rb +27 -0
  258. data/lib/cumo/linalg.rb +12 -0
  259. data/lib/cumo/narray.rb +2 -0
  260. data/lib/cumo/narray/extra.rb +1278 -0
  261. data/lib/erbpp.rb +294 -0
  262. data/lib/erbpp/line_number.rb +137 -0
  263. data/lib/erbpp/narray_def.rb +381 -0
  264. data/numo-narray-version +1 -0
  265. data/run.gdb +7 -0
  266. metadata +353 -0
@@ -0,0 +1,554 @@
1
+ #include "memory_pool_impl.hpp"
2
+
3
+ #include <cassert>
4
+ #include <memory>
5
+ #include <iostream>
6
+
7
+ // TODO(sonots): Use googletest?
8
+ // TODO(sonots): Provide clean way to build this test outside extconf.rb
9
+
10
+ namespace cumo {
11
+ namespace internal {
12
+
13
+ class TestChunk {
14
+ private:
15
+ cudaStream_t stream_ptr_ = 0;
16
+
17
+ public:
18
+ TestChunk() {}
19
+
20
+ void Run() {
21
+ TestSplit();
22
+ TestMerge();
23
+ }
24
+
25
+ void TestSplit() {
26
+ auto mem = std::make_shared<Memory>(kRoundSize * 4);
27
+ auto chunk = std::make_shared<Chunk>(mem, 0, mem->size(), stream_ptr_);
28
+
29
+ auto tail = Split(chunk, kRoundSize * 2);
30
+ assert(chunk->ptr() == mem->ptr());
31
+ assert(chunk->offset() == 0);
32
+ assert(chunk->size() == kRoundSize * 2);
33
+ assert(chunk->prev() == nullptr);
34
+ assert(chunk->next()->ptr() == tail->ptr());
35
+ assert(chunk->stream_ptr() == stream_ptr_);
36
+ assert(tail->ptr() == mem->ptr() + kRoundSize * 2);
37
+ assert(tail->offset() == kRoundSize * 2);
38
+ assert(tail->size() == kRoundSize * 2);
39
+ assert(tail->prev()->ptr() == chunk->ptr());
40
+ assert(tail->next() == nullptr);
41
+ assert(tail->stream_ptr() == stream_ptr_);
42
+
43
+ auto tail_of_head = Split(chunk, kRoundSize);
44
+ assert(chunk->ptr() == mem->ptr());
45
+ assert(chunk->offset() == 0);
46
+ assert(chunk->size() == kRoundSize);
47
+ assert(chunk->prev() == nullptr);
48
+ assert(chunk->next()->ptr() == tail_of_head->ptr());
49
+ assert(chunk->stream_ptr() == stream_ptr_);
50
+ assert(tail_of_head->ptr() == mem->ptr() + kRoundSize);
51
+ assert(tail_of_head->offset() == kRoundSize);
52
+ assert(tail_of_head->size() == kRoundSize);
53
+ assert(tail_of_head->prev()->ptr() == chunk->ptr());
54
+ assert(tail_of_head->next()->ptr() == tail->ptr());
55
+ assert(tail_of_head->stream_ptr() == stream_ptr_);
56
+
57
+ auto tail_of_tail = Split(tail, kRoundSize);
58
+ assert(tail->ptr() == chunk->ptr() + kRoundSize * 2);
59
+ assert(tail->offset() == kRoundSize * 2);
60
+ assert(tail->size() == kRoundSize);
61
+ assert(tail->prev()->ptr() == tail_of_head->ptr());
62
+ assert(tail->next()->ptr() == tail_of_tail->ptr());
63
+ assert(tail->stream_ptr() == stream_ptr_);
64
+ assert(tail_of_tail->ptr() == mem->ptr() + kRoundSize * 3);
65
+ assert(tail_of_tail->offset() == kRoundSize * 3);
66
+ assert(tail_of_tail->size() == kRoundSize);
67
+ assert(tail_of_tail->prev()->ptr() == tail->ptr());
68
+ assert(tail_of_tail->next() == nullptr);
69
+ assert(tail_of_tail->stream_ptr() == stream_ptr_);
70
+ }
71
+
72
+ void TestMerge() {
73
+ auto mem = std::make_shared<Memory>(kRoundSize * 4);
74
+ auto chunk = std::make_shared<Chunk>(mem, 0, mem->size(), stream_ptr_);
75
+
76
+ auto chunk_ptr = chunk->ptr();
77
+ auto chunk_offset = chunk->offset();
78
+ auto chunk_size = chunk->size();
79
+
80
+ auto tail = Split(chunk, kRoundSize * 2);
81
+ auto head = chunk;
82
+ auto head_ptr = head->ptr();
83
+ auto head_offset = head->offset();
84
+ auto head_size = head->size();
85
+ auto tail_ptr = tail->ptr();
86
+ auto tail_offset = tail->offset();
87
+ auto tail_size = tail->size();
88
+
89
+ auto tail_of_head = Split(head, kRoundSize);
90
+ auto tail_of_tail = Split(tail, kRoundSize);
91
+
92
+ Merge(head, tail_of_head);
93
+ assert(head->ptr() == head_ptr);
94
+ assert(head->offset() == head_offset);
95
+ assert(head->size() == head_size);
96
+ assert(head->prev() == nullptr);
97
+ assert(head->next()->ptr() == tail_ptr);
98
+ assert(head->stream_ptr() == stream_ptr_);
99
+
100
+ Merge(tail, tail_of_tail);
101
+ assert(tail->ptr() == tail_ptr);
102
+ assert(tail->offset() == tail_offset);
103
+ assert(tail->size() == tail_size);
104
+ assert(tail->prev()->ptr() == head_ptr);
105
+ assert(tail->next() == nullptr);
106
+ assert(tail->stream_ptr() == stream_ptr_);
107
+
108
+ Merge(head, tail);
109
+ assert(head->ptr() == chunk_ptr);
110
+ assert(head->offset() == chunk_offset);
111
+ assert(head->size() == chunk_size);
112
+ assert(head->prev() == nullptr);
113
+ assert(head->next() == nullptr);
114
+ assert(head->stream_ptr() == stream_ptr_);
115
+ }
116
+ };
117
+
118
+ class TestSingleDeviceMemoryPool {
119
+ private:
120
+ std::shared_ptr<SingleDeviceMemoryPool> pool_;
121
+ cudaStream_t stream_ptr_ = 0;
122
+
123
+ public:
124
+ TestSingleDeviceMemoryPool() {}
125
+
126
+ void SetUp() {
127
+ pool_ = std::make_shared<SingleDeviceMemoryPool>();
128
+ }
129
+
130
+ void TearDown() {
131
+ pool_.reset();
132
+ }
133
+
134
+ void Run() {
135
+ TearDown(); SetUp(); TestGetRoundedSize();
136
+ TearDown(); SetUp(); TestGetBinIndex();
137
+ TearDown(); SetUp(); TestAppendToFreeList();
138
+ TearDown(); SetUp(); TestRemoveFromFreeList();
139
+ TearDown(); SetUp(); TestMalloc();
140
+ TearDown(); SetUp(); TestMallocWithZero();
141
+ TearDown(); SetUp(); TestFree();
142
+ TearDown(); SetUp(); TestFreeDoubly();
143
+ TearDown(); SetUp(); TestMallocSplit();
144
+ TearDown(); SetUp(); TestFreeMerge();
145
+ TearDown(); SetUp(); TestFreeDifferentSize();
146
+ TearDown(); SetUp(); TestFreeAllBlocks();
147
+ TearDown(); SetUp(); TestFreeAllBlocksWithoutMalloc();
148
+ TearDown(); SetUp(); TestFreeAllBlocksSplit();
149
+ TearDown(); SetUp(); TestGetUsedBytes();
150
+ TearDown(); SetUp(); TestGetFreeBytes();
151
+ TearDown(); SetUp(); TestGetTotalBytes();
152
+ TearDown();
153
+ }
154
+
155
+ void TestGetRoundedSize() {
156
+ assert(pool_->GetRoundedSize(kRoundSize - 1) == kRoundSize);
157
+ assert(pool_->GetRoundedSize(kRoundSize) == kRoundSize);
158
+ assert(pool_->GetRoundedSize(kRoundSize + 1) == kRoundSize * 2);
159
+ }
160
+
161
+ void TestGetBinIndex() {
162
+ assert(pool_->GetBinIndex(kRoundSize - 1) == 0);
163
+ assert(pool_->GetBinIndex(kRoundSize) == 0);
164
+ assert(pool_->GetBinIndex(kRoundSize + 1) == 1);
165
+ }
166
+
167
+ void TestAppendToFreeList() {
168
+ Arena& arena = pool_->GetArena(stream_ptr_);
169
+ ArenaIndexMap& arena_index_map = pool_->GetArenaIndexMap(stream_ptr_);
170
+
171
+ {
172
+ auto mem = std::make_shared<Memory>(kRoundSize * 4);
173
+ auto chunk = std::make_shared<Chunk>(mem, 0, mem->size(), stream_ptr_);
174
+ pool_->AppendToFreeList(chunk->size(), chunk, stream_ptr_);
175
+ }
176
+ assert(arena.size() == 1);
177
+ assert(arena[0].size() == 1);
178
+ assert(arena_index_map.size() == 1);
179
+ assert(arena_index_map[0] == 3);
180
+
181
+ // insert to same arena index
182
+ {
183
+ auto mem = std::make_shared<Memory>(kRoundSize * 4);
184
+ auto chunk = std::make_shared<Chunk>(mem, 0, mem->size(), stream_ptr_);
185
+ pool_->AppendToFreeList(chunk->size(), chunk, stream_ptr_);
186
+ }
187
+ assert(arena.size() == 1);
188
+ assert(arena[0].size() == 2);
189
+ assert(arena_index_map.size() == 1);
190
+ assert(arena_index_map[0] == 3);
191
+
192
+ // insert to larger arena index
193
+ {
194
+ auto mem = std::make_shared<Memory>(kRoundSize * 5);
195
+ auto chunk = std::make_shared<Chunk>(mem, 0, mem->size(), stream_ptr_);
196
+ pool_->AppendToFreeList(chunk->size(), chunk, stream_ptr_);
197
+ }
198
+ assert(arena.size() == 2);
199
+ assert(arena[0].size() == 2);
200
+ assert(arena[1].size() == 1);
201
+ assert(arena_index_map.size() == 2);
202
+ assert(arena_index_map[0] == 3);
203
+ assert(arena_index_map[1] == 4);
204
+
205
+ // insert to smaller arena index
206
+ {
207
+ auto mem = std::make_shared<Memory>(kRoundSize * 3);
208
+ auto chunk = std::make_shared<Chunk>(mem, 0, mem->size(), stream_ptr_);
209
+ pool_->AppendToFreeList(chunk->size(), chunk, stream_ptr_);
210
+ }
211
+ assert(arena.size() == 3);
212
+ assert(arena[0].size() == 1);
213
+ assert(arena[1].size() == 2);
214
+ assert(arena[2].size() == 1);
215
+ assert(arena_index_map.size() == 3);
216
+ assert(arena_index_map[0] == 2);
217
+ assert(arena_index_map[1] == 3);
218
+ assert(arena_index_map[2] == 4);
219
+ }
220
+
221
+ // TODO(sonots): Fix after implementing compaction
222
+ void TestRemoveFromFreeList() {
223
+ Arena& arena = pool_->GetArena(stream_ptr_);
224
+ ArenaIndexMap& arena_index_map = pool_->GetArenaIndexMap(stream_ptr_);
225
+
226
+ auto mem1 = std::make_shared<Memory>(kRoundSize * 4);
227
+ auto chunk1 = std::make_shared<Chunk>(mem1, 0, mem1->size(), stream_ptr_);
228
+ pool_->AppendToFreeList(chunk1->size(), chunk1, stream_ptr_);
229
+
230
+ auto mem2 = std::make_shared<Memory>(kRoundSize * 4);
231
+ auto chunk2 = std::make_shared<Chunk>(mem2, 0, mem2->size(), stream_ptr_);
232
+ pool_->AppendToFreeList(chunk2->size(), chunk2, stream_ptr_);
233
+
234
+ auto mem3 = std::make_shared<Memory>(kRoundSize * 5);
235
+ auto chunk3 = std::make_shared<Chunk>(mem3, 0, mem3->size(), stream_ptr_);
236
+ pool_->AppendToFreeList(chunk3->size(), chunk3, stream_ptr_);
237
+
238
+ auto mem4 = std::make_shared<Memory>(kRoundSize * 3);
239
+ auto chunk4 = std::make_shared<Chunk>(mem4, 0, mem4->size(), stream_ptr_);
240
+ pool_->AppendToFreeList(chunk4->size(), chunk4, stream_ptr_);
241
+
242
+ // remove one from two
243
+ pool_->RemoveFromFreeList(chunk1->size(), chunk1, stream_ptr_);
244
+ assert(arena.size() == 3);
245
+ assert(arena[0].size() == 1);
246
+ assert(arena[1].size() == 1);
247
+ assert(arena[2].size() == 1);
248
+ assert(arena_index_map.size() == 3);
249
+ assert(arena_index_map[0] == 2);
250
+ assert(arena_index_map[1] == 3);
251
+ assert(arena_index_map[2] == 4);
252
+
253
+ // remove two from two
254
+ pool_->RemoveFromFreeList(chunk2->size(), chunk2, stream_ptr_);
255
+ assert(arena.size() == 3);
256
+ assert(arena[0].size() == 1);
257
+ assert(arena[1].size() == 0);
258
+ assert(arena[2].size() == 1);
259
+ assert(arena_index_map.size() == 3);
260
+ assert(arena_index_map[0] == 2);
261
+ assert(arena_index_map[1] == 3);
262
+ assert(arena_index_map[2] == 4);
263
+
264
+ pool_->RemoveFromFreeList(chunk3->size(), chunk3, stream_ptr_);
265
+ assert(arena.size() == 3);
266
+ assert(arena[0].size() == 1);
267
+ assert(arena[1].size() == 0);
268
+ assert(arena[2].size() == 0);
269
+ assert(arena_index_map.size() == 3);
270
+ assert(arena_index_map[0] == 2);
271
+ assert(arena_index_map[1] == 3);
272
+ assert(arena_index_map[2] == 4);
273
+
274
+ pool_->RemoveFromFreeList(chunk4->size(), chunk4, stream_ptr_);
275
+ assert(arena.size() == 3);
276
+ assert(arena[0].size() == 0);
277
+ assert(arena[1].size() == 0);
278
+ assert(arena[2].size() == 0);
279
+ assert(arena_index_map.size() == 3);
280
+ assert(arena_index_map[0] == 2);
281
+ assert(arena_index_map[1] == 3);
282
+ assert(arena_index_map[2] == 4);
283
+ }
284
+
285
+ void TestMalloc() {
286
+ intptr_t p1 = pool_->Malloc(kRoundSize * 4);
287
+ intptr_t p2 = pool_->Malloc(kRoundSize * 4);
288
+ intptr_t p3 = pool_->Malloc(kRoundSize * 8);
289
+ assert(p1 != p2);
290
+ assert(p1 != p3);
291
+ assert(p2 != p3);
292
+ }
293
+
294
+ void TestMallocWithZero() {
295
+ pool_->Malloc(0); // actually, cuda returns 0
296
+ }
297
+
298
+ void TestFree() {
299
+ intptr_t p1 = pool_->Malloc(kRoundSize * 4);
300
+ pool_->Free(p1);
301
+ intptr_t p2 = pool_->Malloc(kRoundSize * 4);
302
+ assert(p1 == p2);
303
+ }
304
+
305
+ void TestFreeDoubly() {
306
+ intptr_t p1 = pool_->Malloc(kRoundSize * 4);
307
+ pool_->Free(p1);
308
+ // pool_->Free(p1); // will abort
309
+ }
310
+
311
+ void TestMallocSplit() {
312
+ intptr_t p = pool_->Malloc(kRoundSize * 4);
313
+ pool_->Free(p);
314
+ intptr_t head = pool_->Malloc(kRoundSize * 2);
315
+ intptr_t tail = pool_->Malloc(kRoundSize * 2);
316
+ assert(p == head);
317
+ assert(p + kRoundSize * 2 == tail);
318
+ }
319
+
320
+ void TestFreeMerge() {
321
+ intptr_t p1 = pool_->Malloc(kRoundSize * 4);
322
+ pool_->Free(p1);
323
+
324
+ // merge head into tail
325
+ {
326
+ intptr_t head = pool_->Malloc(kRoundSize * 2);
327
+ intptr_t tail = pool_->Malloc(kRoundSize * 2);
328
+ assert(p1 == head);
329
+ pool_->Free(tail);
330
+ pool_->Free(head);
331
+ intptr_t p2 = pool_->Malloc(kRoundSize * 4);
332
+ assert(p1 == p2);
333
+ pool_->Free(p2);
334
+ }
335
+
336
+ // merge tail into head
337
+ {
338
+ intptr_t head = pool_->Malloc(kRoundSize * 2);
339
+ intptr_t tail = pool_->Malloc(kRoundSize * 2);
340
+ assert(p1 == head);
341
+ pool_->Free(head);
342
+ pool_->Free(tail);
343
+ intptr_t p2 = pool_->Malloc(kRoundSize * 4);
344
+ assert(p1 == p2);
345
+ pool_->Free(p2);
346
+ }
347
+ }
348
+
349
+ void TestFreeDifferentSize() {
350
+ intptr_t p1 = pool_->Malloc(kRoundSize * 4);
351
+ pool_->Free(p1);
352
+ intptr_t p2 = pool_->Malloc(kRoundSize * 8);
353
+ assert(p1 != p2);
354
+ }
355
+
356
+ void TestFreeAllBlocks() {
357
+ intptr_t p1 = pool_->Malloc(kRoundSize * 4);
358
+ pool_->Free(p1);
359
+ pool_->FreeAllBlocks();
360
+ intptr_t p2 = pool_->Malloc(kRoundSize * 4);
361
+ // assert(p1 != p2); // cudaMalloc gets same address ...
362
+ pool_->Free(p2);
363
+ }
364
+
365
+ void TestFreeAllBlocksWithoutMalloc() {
366
+ pool_->FreeAllBlocks();
367
+ }
368
+
369
+ void TestFreeAllBlocksSplit() {
370
+ // do not free splitted blocks
371
+ intptr_t p = pool_->Malloc(kRoundSize * 4);
372
+ pool_->Free(p);
373
+ intptr_t head = pool_->Malloc(kRoundSize * 2);
374
+ intptr_t tail = pool_->Malloc(kRoundSize * 2);
375
+ pool_->Free(tail);
376
+ pool_->FreeAllBlocks();
377
+ intptr_t p2 = pool_->Malloc(kRoundSize * 2);
378
+ assert(tail == p2);
379
+ pool_->Free(head);
380
+ }
381
+
382
+ // void TestFreeAllBlocksStream() {
383
+ // intptr_t p1 = pool_->Malloc(kRoundSize * 4);
384
+ // pool_->Free(p1);
385
+ // with self.stream:
386
+ // p2 = pool_->Malloc(kRoundSize * 4)
387
+ // ptr2 = p2.ptr
388
+ // del p2
389
+ // pool_->free_all_blocks(stream=stream_module.Stream.null)
390
+ // p3 = pool_->Malloc(kRoundSize * 4)
391
+ // self.assertNotEqual(ptr1, p3.ptr)
392
+ // self.assertNotEqual(ptr2, p3.ptr)
393
+ // with self.stream:
394
+ // p4 = pool_->Malloc(kRoundSize * 4)
395
+ // self.assertNotEqual(ptr1, p4.ptr)
396
+ // assert(ptr2, p4.ptr)
397
+
398
+ // def test_free_all_blocks_all_streams(self):
399
+ // p1 = pool_.Malloc(kRoundSize * 4)
400
+ // ptr1 = p1.ptr
401
+ // del p1
402
+ // with self.stream:
403
+ // p2 = pool_.Malloc(kRoundSize * 4)
404
+ // ptr2 = p2.ptr
405
+ // del p2
406
+ // pool_.free_all_blocks()
407
+ // p3 = pool_.Malloc(kRoundSize * 4)
408
+ // self.assertNotEqual(ptr1, p3.ptr)
409
+ // self.assertNotEqual(ptr2, p3.ptr)
410
+ // with self.stream:
411
+ // p4 = pool_.Malloc(kRoundSize * 4)
412
+ // self.assertNotEqual(ptr1, p4.ptr)
413
+ // self.assertNotEqual(ptr2, p4.ptr)
414
+
415
+ void TestGetUsedBytes() {
416
+ intptr_t p1 = pool_->Malloc(kRoundSize * 2);
417
+ assert(kRoundSize * 2 == pool_->GetUsedBytes());
418
+ intptr_t p2 = pool_->Malloc(kRoundSize * 4);
419
+ assert(kRoundSize * 6 == pool_->GetUsedBytes());
420
+ pool_->Free(p2);
421
+ assert(kRoundSize * 2 == pool_->GetUsedBytes());
422
+ pool_->Free(p1);
423
+ assert(kRoundSize * 0 == pool_->GetUsedBytes());
424
+ intptr_t p3 = pool_->Malloc(kRoundSize * 1);
425
+ assert(kRoundSize * 1 == pool_->GetUsedBytes());
426
+ pool_->Free(p3);
427
+ }
428
+
429
+ // def test_used_bytes_stream(self):
430
+ // p1 = pool_.Malloc(kRoundSize * 4)
431
+ // del p1
432
+ // with self.stream:
433
+ // p2 = pool_.Malloc(kRoundSize * 2)
434
+ // assert(kRoundSize * 2, pool_.used_bytes())
435
+ // del p2
436
+
437
+ void TestGetFreeBytes() {
438
+ intptr_t p1 = pool_->Malloc(kRoundSize * 2);
439
+ assert(kRoundSize * 0 == pool_->GetFreeBytes());
440
+ intptr_t p2 = pool_->Malloc(kRoundSize * 4);
441
+ assert(kRoundSize * 0 == pool_->GetFreeBytes());
442
+ pool_->Free(p2);
443
+ assert(kRoundSize * 4 == pool_->GetFreeBytes());
444
+ pool_->Free(p1);
445
+ assert(kRoundSize * 6 == pool_->GetFreeBytes());
446
+ intptr_t p3 = pool_->Malloc(kRoundSize * 1);
447
+ assert(kRoundSize * 5 == pool_->GetFreeBytes());
448
+ pool_->Free(p3);
449
+ }
450
+
451
+ // def test_free_bytes_stream(self):
452
+ // p1 = pool_.Malloc(kRoundSize * 4)
453
+ // del p1
454
+ // with self.stream:
455
+ // p2 = pool_.Malloc(kRoundSize * 2)
456
+ // assert(kRoundSize * 4, pool_.free_bytes())
457
+ // del p2
458
+
459
+ void TestGetTotalBytes() {
460
+ intptr_t p1 = pool_->Malloc(kRoundSize * 2);
461
+ assert(kRoundSize * 2 == pool_->GetTotalBytes());
462
+ intptr_t p2 = pool_->Malloc(kRoundSize * 4);
463
+ assert(kRoundSize * 6 == pool_->GetTotalBytes());
464
+ pool_->Free(p1);
465
+ assert(kRoundSize * 6 == pool_->GetTotalBytes());
466
+ pool_->Free(p2);
467
+ assert(kRoundSize * 6 == pool_->GetTotalBytes());
468
+ intptr_t p3 = pool_->Malloc(kRoundSize * 1);
469
+ assert(kRoundSize * 6 == pool_->GetTotalBytes());
470
+ pool_->Free(p3);
471
+ }
472
+
473
+ // def test_total_bytes_stream(self):
474
+ // p1 = pool_.Malloc(kRoundSize * 4)
475
+ // del p1
476
+ // with self.stream:
477
+ // p2 = pool_.Malloc(kRoundSize * 2)
478
+ // assert(kRoundSize * 6, pool_.total_bytes())
479
+ // del p2
480
+
481
+ };
482
+
483
+ class TestMemoryPool {
484
+ private:
485
+ std::shared_ptr<MemoryPool> pool_;
486
+ cudaStream_t stream_ptr_ = 0;
487
+
488
+ public:
489
+ TestMemoryPool() {}
490
+
491
+ void SetUp() {
492
+ pool_ = std::make_shared<MemoryPool>();
493
+ }
494
+
495
+ void TearDown() {
496
+ pool_.reset();
497
+ }
498
+
499
+ void Run() {
500
+ TearDown(); SetUp(); TestMalloc();
501
+ TearDown(); SetUp(); TestFree();
502
+ TearDown(); SetUp(); TestFreeAllBlocks();
503
+ TearDown(); SetUp(); TestGetNumFreeBlocks();
504
+ TearDown(); SetUp(); TestGetUsedBytes();
505
+ TearDown(); SetUp(); TestGetFreeBytes();
506
+ TearDown(); SetUp(); TestGetTotalBytes();
507
+ TearDown();
508
+ }
509
+
510
+ void TestMalloc() {
511
+ auto p = pool_->Malloc(1);
512
+ assert(0 != p);
513
+ }
514
+
515
+ void TestFree() {
516
+ auto p = pool_->Malloc(1);
517
+ pool_->Free(p);
518
+ }
519
+
520
+ void TestFreeAllBlocks() {
521
+ auto p = pool_->Malloc(1);
522
+ assert(pool_->GetNumFreeBlocks() == 0);
523
+ pool_->Free(p);
524
+ assert(pool_->GetNumFreeBlocks() == 1);
525
+ pool_->FreeAllBlocks();
526
+ assert(pool_->GetNumFreeBlocks() == 0);
527
+ }
528
+
529
+ void TestGetNumFreeBlocks() {
530
+ assert(0 == pool_->GetNumFreeBlocks());
531
+ }
532
+
533
+ void TestGetUsedBytes() {
534
+ assert(0 == pool_->GetUsedBytes());
535
+ }
536
+
537
+ void TestGetFreeBytes() {
538
+ assert(0 == pool_->GetFreeBytes());
539
+ }
540
+
541
+ void TestGetTotalBytes() {
542
+ assert(0 == pool_->GetTotalBytes());
543
+ }
544
+ };
545
+
546
+ } // namespace internal
547
+ } // namespace cumo
548
+
549
+ int main() {
550
+ cumo::internal::TestChunk{}.Run();
551
+ cumo::internal::TestSingleDeviceMemoryPool{}.Run();
552
+ cumo::internal::TestMemoryPool{}.Run();
553
+ return 0;
554
+ }