warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/bvh.cpp CHANGED
@@ -27,35 +27,34 @@ class MedianBVHBuilder
 {
 public:
 
-    void build(BVH& bvh, const bounds3* items, int n);
+    void build(BVH& bvh, const vec3* lowers, const vec3* uppers, int n);
 
 private:
 
-    bounds3 calc_bounds(const bounds3* bounds, const int* indices, int start, int end);
+    bounds3 calc_bounds(const vec3* lowers, const vec3* uppers, const int* indices, int start, int end);
 
-    int partition_median(const bounds3* bounds, int* indices, int start, int end, bounds3 range_bounds);
-    int partition_midpoint(const bounds3* bounds, int* indices, int start, int end, bounds3 range_bounds);
-    int partition_sah(const bounds3* bounds, int* indices, int start, int end, bounds3 range_bounds);
+    int partition_median(const vec3* lowers, const vec3* uppers, int* indices, int start, int end, bounds3 range_bounds);
+    int partition_midpoint(const vec3* lowers, const vec3* uppers, int* indices, int start, int end, bounds3 range_bounds);
+    int partition_sah(const vec3* lowers, const vec3* uppers, int* indices, int start, int end, bounds3 range_bounds);
 
-    int build_recursive(BVH& bvh, const bounds3* bounds, int* indices, int start, int end, int depth, int parent);
+    int build_recursive(BVH& bvh, const vec3* lowers, const vec3* uppers, int* indices, int start, int end, int depth, int parent);
 };
 
 //////////////////////////////////////////////////////////////////////
 
-void MedianBVHBuilder::build(BVH& bvh, const bounds3* items, int n)
+void MedianBVHBuilder::build(BVH& bvh, const vec3* lowers, const vec3* uppers, int n)
 {
-    memset(&bvh, 0, sizeof(BVH));
-
+    bvh.max_depth = 0;
     bvh.max_nodes = 2*n-1;
 
     bvh.node_lowers = new BVHPackedNodeHalf[bvh.max_nodes];
     bvh.node_uppers = new BVHPackedNodeHalf[bvh.max_nodes];
     bvh.node_parents = new int[bvh.max_nodes];
-
-    bvh.num_nodes = 0;
-
+    bvh.node_counts = NULL;
+
     // root is always in first slot for top down builders
-    bvh.root = 0;
+    bvh.root = new int[1];
+    bvh.root[0] = 0;
 
     if (n == 0)
         return;
@@ -64,35 +63,42 @@ void MedianBVHBuilder::build(BVH& bvh, const bounds3* items, int n)
     for (int i=0; i < n; ++i)
         indices[i] = i;
 
-    build_recursive(bvh, items, &indices[0], 0, n, 0, -1);
+    build_recursive(bvh, lowers, uppers, &indices[0], 0, n, 0, -1);
 }
 
 
-bounds3 MedianBVHBuilder::calc_bounds(const bounds3* bounds, const int* indices, int start, int end)
+bounds3 MedianBVHBuilder::calc_bounds(const vec3* lowers, const vec3* uppers, const int* indices, int start, int end)
 {
     bounds3 u;
 
     for (int i=start; i < end; ++i)
-        u = bounds_union(u, bounds[indices[i]]);
+    {
+        u.add_point(lowers[indices[i]]);
+        u.add_point(uppers[indices[i]]);
+    }
 
     return u;
 }
 
 struct PartitionPredicateMedian
 {
-    PartitionPredicateMedian(const bounds3* bounds, int a) : bounds(bounds), axis(a) {}
+    PartitionPredicateMedian(const vec3* lowers, const vec3* uppers, int a) : lowers(lowers), uppers(uppers), axis(a) {}
 
     bool operator()(int a, int b) const
     {
-        return bounds[a].center()[axis] < bounds[b].center()[axis];
+        vec3 a_center = 0.5f*(lowers[a] + uppers[a]);
+        vec3 b_center = 0.5f*(lowers[b] + uppers[b]);
+
+        return a_center[axis] < b_center[axis];
    }
 
-    const bounds3* bounds;
+    const vec3* lowers;
+    const vec3* uppers;
    int axis;
 };
 
 
-int MedianBVHBuilder::partition_median(const bounds3* bounds, int* indices, int start, int end, bounds3 range_bounds)
+int MedianBVHBuilder::partition_median(const vec3* lowers, const vec3* uppers, int* indices, int start, int end, bounds3 range_bounds)
 {
     assert(end-start >= 2);
 
@@ -102,27 +108,31 @@ int MedianBVHBuilder::partition_median(const bounds3* bounds, int* indices, int
 
     const int k = (start+end)/2;
 
-    std::nth_element(&indices[start], &indices[k], &indices[end], PartitionPredicateMedian(&bounds[0], axis));
+    std::nth_element(&indices[start], &indices[k], &indices[end], PartitionPredicateMedian(lowers, uppers, axis));
 
     return k;
 }
 
 struct PartitionPredictateMidPoint
 {
-    PartitionPredictateMidPoint(const bounds3* bounds, int a, float m) : bounds(bounds), axis(a), mid(m) {}
+    PartitionPredictateMidPoint(const vec3* lowers, const vec3* uppers, int a, float m) : lowers(lowers), uppers(uppers), axis(a), mid(m) {}
 
     bool operator()(int index) const
     {
-        return bounds[index].center()[axis] <= mid;
+        vec3 center = 0.5f*(lowers[index] + uppers[index]);
+
+        return center[axis] <= mid;
    }
 
-    const bounds3* bounds;
+    const vec3* lowers;
+    const vec3* uppers;
+
    int axis;
    float mid;
 };
 
 
-int MedianBVHBuilder::partition_midpoint(const bounds3* bounds, int* indices, int start, int end, bounds3 range_bounds)
+int MedianBVHBuilder::partition_midpoint(const vec3* lowers, const vec3* uppers, int* indices, int start, int end, bounds3 range_bounds)
 {
     assert(end-start >= 2);
 
@@ -132,7 +142,7 @@ int MedianBVHBuilder::partition_midpoint(const bounds3* bounds, int* indices, in
     int axis = longest_axis(edges);
     float mid = center[axis];
 
-    int* upper = std::partition(indices+start, indices+end, PartitionPredictateMidPoint(&bounds[0], axis, mid));
+    int* upper = std::partition(indices+start, indices+end, PartitionPredictateMidPoint(lowers, uppers, axis, mid));
 
     int k = upper-indices;
 
@@ -140,7 +150,6 @@ int MedianBVHBuilder::partition_midpoint(const bounds3* bounds, int* indices, in
     if (k == start || k == end)
         k = (start+end)/2;
 
-
     return k;
 }
 
@@ -200,7 +209,7 @@ int MedianBVHBuilder::partition_sah(const bounds3* bounds, int* indices, int sta
 }
 #endif
 
-int MedianBVHBuilder::build_recursive(BVH& bvh, const bounds3* bounds, int* indices, int start, int end, int depth, int parent)
+int MedianBVHBuilder::build_recursive(BVH& bvh, const vec3* lowers, const vec3* uppers, int* indices, int start, int end, int depth, int parent)
 {
     assert(start < end);
 
@@ -212,7 +221,7 @@ int MedianBVHBuilder::build_recursive(BVH& bvh, const bounds3* bounds, int* indi
     if (depth > bvh.max_depth)
         bvh.max_depth = depth;
 
-    bounds3 b = calc_bounds(bounds, indices, start, end);
+    bounds3 b = calc_bounds(lowers, uppers, indices, start, end);
 
     const int kMaxItemsPerLeaf = 1;
 
@@ -225,7 +234,7 @@ int MedianBVHBuilder::build_recursive(BVH& bvh, const bounds3* bounds, int* indi
     else
     {
         //int split = partition_midpoint(bounds, indices, start, end, b);
-        int split = partition_median(bounds, indices, start, end, b);
+        int split = partition_median(lowers, uppers, indices, start, end, b);
         //int split = partition_sah(bounds, indices, start, end, b);
 
         if (split == start || split == end)
@@ -234,8 +243,8 @@ int MedianBVHBuilder::build_recursive(BVH& bvh, const bounds3* bounds, int* indi
             split = (start+end)/2;
         }
 
-        int left_child = build_recursive(bvh, bounds, indices, start, split, depth+1, node_index);
-        int right_child = build_recursive(bvh, bounds, indices, split, end, depth+1, node_index);
+        int left_child = build_recursive(bvh, lowers, uppers, indices, start, split, depth+1, node_index);
+        int right_child = build_recursive(bvh, lowers, uppers, indices, split, end, depth+1, node_index);
 
         bvh.node_lowers[node_index] = make_node(b.lower, left_child, false);
         bvh.node_uppers[node_index] = make_node(b.upper, right_child, false);
@@ -245,218 +254,8 @@ int MedianBVHBuilder::build_recursive(BVH& bvh, const bounds3* bounds, int* indi
     return node_index;
 }
 
-class LinearBVHBuilderCPU
-{
-public:
-
-    void build(BVH& bvh, const bounds3* items, int n);
-
-private:
-
-    // calculate Morton codes
-    struct KeyIndexPair
-    {
-        uint32_t key;
-        int index;
-
-        inline bool operator < (const KeyIndexPair& rhs) const { return key < rhs.key; }
-    };
-
-    bounds3 calc_bounds(const bounds3* bounds, const KeyIndexPair* keys, int start, int end);
-    int find_split(const KeyIndexPair* pairs, int start, int end);
-    int build_recursive(BVH& bvh, const KeyIndexPair* keys, const bounds3* bounds, int start, int end, int depth);
-
-};
-
-
-// disable std::sort workaround for macOS error
-#if 0
-void LinearBVHBuilderCPU::build(BVH& bvh, const bounds3* items, int n)
-{
-    memset(&bvh, 0, sizeof(BVH));
-
-    bvh.max_nodes = 2*n-1;
-
-    bvh.node_lowers = new BVHPackedNodeHalf[bvh.max_nodes];
-    bvh.node_uppers = new BVHPackedNodeHalf[bvh.max_nodes];
-    bvh.num_nodes = 0;
-
-    // root is always in first slot for top down builders
-    bvh.root = 0;
-
-    std::vector<KeyIndexPair> keys;
-    keys.reserve(n);
-
-    bounds3 totalbounds3;
-    for (int i=0; i < n; ++i)
-        totalbounds3 = bounds_union(totalbounds3, items[i]);
-
-    // ensure non-zero edge length in all dimensions
-    totalbounds3.expand(0.001f);
-
-    vec3 edges = totalbounds3.edges();
-    vec3 invEdges = cw_div(vec3(1.0f), edges);
-
-    for (int i=0; i < n; ++i)
-    {
-        vec3 center = items[i].center();
-        vec3 local = cw_mul(center-totalbounds3.lower, invEdges);
 
-        KeyIndexPair l;
-        l.key = morton3<1024>(local.x, local.y, local.z);
-        l.index = i;
-
-        keys.push_back(l);
-    }
-
-    // sort by key
-    std::sort(keys.begin(), keys.end());
-
-    build_recursive(bvh, &keys[0], items, 0, n, 0);
-
-    printf("Created BVH for %d items with %d nodes, max depth of %d\n", n, bvh.num_nodes, bvh.max_depth);
-}
-#endif
-
-inline bounds3 LinearBVHBuilderCPU::calc_bounds(const bounds3* bounds, const KeyIndexPair* keys, int start, int end)
-{
-    bounds3 u;
-
-    for (int i=start; i < end; ++i)
-        u = bounds_union(u, bounds[keys[i].index]);
-
-    return u;
-}
-
-inline int LinearBVHBuilderCPU::find_split(const KeyIndexPair* pairs, int start, int end)
-{
-    if (pairs[start].key == pairs[end-1].key)
-        return (start+end)/2;
-
-    // find split point between keys, xor here means all bits
-    // of the result are zero up until the first differing bit
-    int common_prefix = clz(pairs[start].key ^ pairs[end-1].key);
-
-    // use binary search to find the point at which this bit changes
-    // from zero to a 1
-    const int mask = 1 << (31-common_prefix);
-
-    while (end-start > 0)
-    {
-        int index = (start+end)/2;
-
-        if (pairs[index].key&mask)
-        {
-            end = index;
-        }
-        else
-            start = index+1;
-    }
-
-    assert(start == end);
-
-    return start;
-}
-
-int LinearBVHBuilderCPU::build_recursive(BVH& bvh, const KeyIndexPair* keys, const bounds3* bounds, int start, int end, int depth)
-{
-    assert(start < end);
-
-    const int n = end-start;
-    const int nodeIndex = bvh.num_nodes++;
-
-    assert(nodeIndex < bvh.max_nodes);
-
-    if (depth > bvh.max_depth)
-        bvh.max_depth = depth;
-
-    bounds3 b = calc_bounds(bounds, keys, start, end);
-
-    const int kMaxItemsPerLeaf = 1;
-
-    if (n <= kMaxItemsPerLeaf)
-    {
-        bvh.node_lowers[nodeIndex] = make_node(b.lower, keys[start].index, true);
-        bvh.node_uppers[nodeIndex] = make_node(b.upper, keys[start].index, false);
-    }
-    else
-    {
-        int split = find_split(keys, start, end);
-
-        int leftChild = build_recursive(bvh, keys, bounds, start, split, depth+1);
-        int rightChild = build_recursive(bvh, keys, bounds, split, end, depth+1);
-
-        bvh.node_lowers[nodeIndex] = make_node(b.lower, leftChild, false);
-        bvh.node_uppers[nodeIndex] = make_node(b.upper, rightChild, false);
-    }
-
-    return nodeIndex;
-}
-
-
-
-// create only happens on host currently, use bvh_clone() to transfer BVH To device
-BVH bvh_create(const bounds3* bounds, int num_bounds)
-{
-    BVH bvh;
-    memset(&bvh, 0, sizeof(bvh));
-
-    MedianBVHBuilder builder;
-    //LinearBVHBuilderCPU builder;
-    builder.build(bvh, bounds, num_bounds);
-
-    return bvh;
-}
-
-void bvh_destroy_host(BVH& bvh)
-{
-    delete[] bvh.node_lowers;
-    delete[] bvh.node_uppers;
-    delete[] bvh.node_parents;
-    delete[] bvh.bounds;
-
-    bvh.node_lowers = NULL;
-    bvh.node_uppers = NULL;
-    bvh.max_nodes = 0;
-    bvh.num_nodes = 0;
-    bvh.num_bounds = 0;
-}
-
-void bvh_destroy_device(BVH& bvh)
-{
-    ContextGuard guard(bvh.context);
-
-    free_device(WP_CURRENT_CONTEXT, bvh.node_lowers); bvh.node_lowers = NULL;
-    free_device(WP_CURRENT_CONTEXT, bvh.node_uppers); bvh.node_uppers = NULL;
-    free_device(WP_CURRENT_CONTEXT, bvh.node_parents); bvh.node_parents = NULL;
-    free_device(WP_CURRENT_CONTEXT, bvh.node_counts); bvh.node_counts = NULL;
-    free_device(WP_CURRENT_CONTEXT, bvh.bounds); bvh.bounds = NULL;
-}
-
-BVH bvh_clone(void* context, const BVH& bvh_host)
-{
-    ContextGuard guard(context);
-
-    BVH bvh_device = bvh_host;
-
-    bvh_device.context = context ? context : cuda_context_get_current();
-
-    bvh_device.node_lowers = (BVHPackedNodeHalf*)alloc_device(WP_CURRENT_CONTEXT, sizeof(BVHPackedNodeHalf)*bvh_host.max_nodes);
-    bvh_device.node_uppers = (BVHPackedNodeHalf*)alloc_device(WP_CURRENT_CONTEXT, sizeof(BVHPackedNodeHalf)*bvh_host.max_nodes);
-    bvh_device.node_parents = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh_host.max_nodes);
-    bvh_device.node_counts = (int*)alloc_device(WP_CURRENT_CONTEXT, sizeof(int)*bvh_host.max_nodes);
-    bvh_device.bounds = (bounds3*)alloc_device(WP_CURRENT_CONTEXT, sizeof(bounds3)*bvh_host.num_bounds);
-
-    // copy host data to device
-    memcpy_h2d(WP_CURRENT_CONTEXT, bvh_device.node_lowers, bvh_host.node_lowers, sizeof(BVHPackedNodeHalf)*bvh_host.max_nodes);
-    memcpy_h2d(WP_CURRENT_CONTEXT, bvh_device.node_uppers, bvh_host.node_uppers, sizeof(BVHPackedNodeHalf)*bvh_host.max_nodes);
-    memcpy_h2d(WP_CURRENT_CONTEXT, bvh_device.node_parents, bvh_host.node_parents, sizeof(int)*bvh_host.max_nodes);
-    memcpy_h2d(WP_CURRENT_CONTEXT, bvh_device.bounds, bvh_host.bounds, sizeof(bounds3)*bvh_host.num_bounds);
-
-    return bvh_device;
-}
-
-void bvh_refit_recursive(BVH& bvh, int index, const bounds3* bounds)
+void bvh_refit_recursive(BVH& bvh, int index)
 {
     BVHPackedNodeHalf& lower = bvh.node_lowers[index];
     BVHPackedNodeHalf& upper = bvh.node_uppers[index];
@@ -465,16 +264,17 @@ void bvh_refit_recursive(BVH& bvh, int index, const bounds3* bounds)
     {
         const int leaf_index = lower.i;
 
-        (vec3&)lower = bounds[leaf_index].lower;
-        (vec3&)upper = bounds[leaf_index].upper;
+        // update leaf from items
+        (vec3&)lower = bvh.item_lowers[leaf_index];
+        (vec3&)upper = bvh.item_uppers[leaf_index];
     }
     else
     {
         int left_index = lower.i;
         int right_index = upper.i;
 
-        bvh_refit_recursive(bvh, left_index, bounds);
-        bvh_refit_recursive(bvh, right_index, bounds);
+        bvh_refit_recursive(bvh, left_index);
+        bvh_refit_recursive(bvh, right_index);
 
         // compute union of children
        const vec3& left_lower = (vec3&)bvh.node_lowers[left_index];
@@ -493,9 +293,9 @@ void bvh_refit_recursive(BVH& bvh, int index, const bounds3* bounds)
    }
 }
 
-void bvh_refit_host(BVH& bvh, const bounds3* b)
+void bvh_refit_host(BVH& bvh)
 {
-    bvh_refit_recursive(bvh, 0, b);
+    bvh_refit_recursive(bvh, 0);
 }
 
 
@@ -538,87 +338,46 @@ void bvh_rem_descriptor(uint64_t id)
 
 }
 
+
+void bvh_destroy_host(BVH& bvh)
+{
+    delete[] bvh.node_lowers;
+    delete[] bvh.node_uppers;
+    delete[] bvh.node_parents;
+    delete[] bvh.root;
+
+    bvh.node_lowers = NULL;
+    bvh.node_uppers = NULL;
+    bvh.node_parents = NULL;
+    bvh.root = NULL;
+
+    bvh.max_nodes = 0;
+    bvh.num_items = 0;
+}
+
 } // namespace wp
 
-uint64_t bvh_create_host(vec3* lowers, vec3* uppers, int num_bounds)
+uint64_t bvh_create_host(vec3* lowers, vec3* uppers, int num_items)
 {
     BVH* bvh = new BVH();
     memset(bvh, 0, sizeof(BVH));
 
     bvh->context = NULL;
 
-    bvh->lowers = lowers;
-    bvh->uppers = uppers;
-    bvh->num_bounds = num_bounds;
-
-    bvh->bounds = new bounds3[num_bounds];
-
-    for (int i=0; i < num_bounds; ++i)
-    {
-        bvh->bounds[i].lower = lowers[i];
-        bvh->bounds[i].upper = uppers[i];
-    }
+    bvh->item_lowers = lowers;
+    bvh->item_uppers = uppers;
+    bvh->num_items = num_items;
 
     MedianBVHBuilder builder;
-    builder.build(*bvh, bvh->bounds, num_bounds);
+    builder.build(*bvh, lowers, uppers, num_items);
 
     return (uint64_t)bvh;
 }
 
-uint64_t bvh_create_device(void* context, vec3* lowers, vec3* uppers, int num_bounds)
-{
-    ContextGuard guard(context);
-
-    // todo: BVH creation only on CPU at the moment so temporarily bring all the data back to host
-    vec3* lowers_host = (vec3*)alloc_host(sizeof(vec3)*num_bounds);
-    vec3* uppers_host = (vec3*)alloc_host(sizeof(vec3)*num_bounds);
-    bounds3* bounds_host = (bounds3*)alloc_host(sizeof(bounds3)*num_bounds);
-
-    memcpy_d2h(WP_CURRENT_CONTEXT, lowers_host, lowers, sizeof(vec3)*num_bounds);
-    memcpy_d2h(WP_CURRENT_CONTEXT, uppers_host, uppers, sizeof(vec3)*num_bounds);
-    cuda_context_synchronize(WP_CURRENT_CONTEXT);
-
-    for (int i=0; i < num_bounds; ++i)
-    {
-        bounds_host[i] = bounds3();
-        bounds_host[i].lower = lowers_host[i];
-        bounds_host[i].upper = uppers_host[i];
-    }
-
-    BVH bvh_host = bvh_create(bounds_host, num_bounds);
-    bvh_host.context = context ? context : cuda_context_get_current();
-    bvh_host.bounds = bounds_host;
-    bvh_host.num_bounds = num_bounds;
-    BVH bvh_device_clone = bvh_clone(WP_CURRENT_CONTEXT, bvh_host);
-
-    bvh_device_clone.lowers = lowers; // managed by the user
-    bvh_device_clone.uppers = uppers; // managed by the user
-
-    BVH* bvh_device = (BVH*)alloc_device(WP_CURRENT_CONTEXT, sizeof(BVH));
-    memcpy_h2d(WP_CURRENT_CONTEXT, bvh_device, &bvh_device_clone, sizeof(BVH));
-
-    bvh_destroy_host(bvh_host);
-    free_host(lowers_host);
-    free_host(uppers_host);
-
-    uint64_t bvh_id = (uint64_t)bvh_device;
-    bvh_add_descriptor(bvh_id, bvh_device_clone);
-
-    return bvh_id;
-}
-
 void bvh_refit_host(uint64_t id)
 {
     BVH* bvh = (BVH*)(id);
-
-    for (int i=0; i < bvh->num_bounds; ++i)
-    {
-        bvh->bounds[i] = bounds3();
-        bvh->bounds[i].lower = bvh->lowers[i];
-        bvh->bounds[i].upper = bvh->uppers[i];
-    }
-
-    bvh_refit_host(*bvh, bvh->bounds);
+    bvh_refit_host(*bvh);
 }
 
 void bvh_destroy_host(uint64_t id)
@@ -629,23 +388,11 @@ void bvh_destroy_host(uint64_t id)
 }
 
 
-void bvh_destroy_device(uint64_t id)
-{
-    BVH bvh;
-    if (bvh_get_descriptor(id, bvh))
-    {
-        bvh_destroy_device(bvh);
-        mesh_rem_descriptor(id);
-    }
-}
-
 // stubs for non-CUDA platforms
 #if !WP_ENABLE_CUDA
 
-void bvh_refit_device(uint64_t id)
-{
-}
-
-
+uint64_t bvh_create_device(void* context, wp::vec3* lowers, wp::vec3* uppers, int num_items) { return 0; }
+void bvh_refit_device(uint64_t id) {}
+void bvh_destroy_device(uint64_t id) {}
 
 #endif // !WP_ENABLE_CUDA
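
The hunks above change the host-side BVH entry points to take separate lower/upper bound arrays instead of a packed bounds3 array, and refitting now reads the item bounds stored on the BVH itself (item_lowers/item_uppers). A minimal caller-side sketch of the new API follows; the three prototypes are copied from the diff, but the placeholder vec3 definition and the example data are assumptions for illustration only, and the real declarations and type come from the package's native headers rather than this snippet.

    #include <cstdint>
    #include <vector>

    // placeholder stand-in for wp::vec3 (the real type lives in warp/native/vec.h)
    namespace wp { struct vec3 { float x, y, z; }; }

    // prototypes as they appear in the 0.11.0 diff above
    uint64_t bvh_create_host(wp::vec3* lowers, wp::vec3* uppers, int num_items);
    void bvh_refit_host(uint64_t id);
    void bvh_destroy_host(uint64_t id);

    int main()
    {
        // per-item bounds; the BVH keeps pointers to these arrays (item_lowers/item_uppers),
        // so they must outlive the BVH and can be mutated in place before a refit
        std::vector<wp::vec3> lowers = {{0, 0, 0}, {1, 0, 0}};
        std::vector<wp::vec3> uppers = {{1, 1, 1}, {2, 1, 1}};

        uint64_t id = bvh_create_host(lowers.data(), uppers.data(), (int)lowers.size());

        uppers[0] = {1.5f, 1.5f, 1.5f};  // move an item...
        bvh_refit_host(id);              // ...and refit; no bounds3 array is passed any more

        bvh_destroy_host(id);
        return 0;
    }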