warp-lang 0.10.1-py3-none-win_amd64.whl → 0.11.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's advisory page for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/quat.h CHANGED
@@ -19,6 +19,15 @@ struct quat_t
19
19
  // zero constructor for adjoint variable initialization
20
20
  inline CUDA_CALLABLE quat_t(Type x=Type(0), Type y=Type(0), Type z=Type(0), Type w=Type(0)) : x(x), y(y), z(z), w(w) {}
21
21
  explicit inline CUDA_CALLABLE quat_t(const vec_t<3,Type>& v, Type w=Type(0)) : x(v[0]), y(v[1]), z(v[2]), w(w) {}
22
+
23
+ template<typename OtherType>
24
+ explicit inline CUDA_CALLABLE quat_t(const quat_t<OtherType>& other)
25
+ {
26
+ x = static_cast<Type>(other.x);
27
+ y = static_cast<Type>(other.y);
28
+ z = static_cast<Type>(other.z);
29
+ w = static_cast<Type>(other.w);
30
+ }
22
31
 
23
32
  // imaginary part
24
33
  Type x;
@@ -73,7 +82,17 @@ inline CUDA_CALLABLE void adj_quat_t(const vec_t<3,Type>& v, Type w, vec_t<3,Typ
73
82
  adj_v[0] += adj_ret.x;
74
83
  adj_v[1] += adj_ret.y;
75
84
  adj_v[2] += adj_ret.z;
76
- adj_w += adj_ret.w;
85
+ adj_w += adj_ret.w;
86
+ }
87
+
88
+ // casting constructor adjoint
89
+ template<typename Type, typename OtherType>
90
+ inline CUDA_CALLABLE void adj_quat_t(const quat_t<OtherType>& other, quat_t<OtherType>& adj_other, const quat_t<Type>& adj_ret)
91
+ {
92
+ adj_other.x += static_cast<OtherType>(adj_ret.x);
93
+ adj_other.y += static_cast<OtherType>(adj_ret.y);
94
+ adj_other.z += static_cast<OtherType>(adj_ret.z);
95
+ adj_other.w += static_cast<OtherType>(adj_ret.w);
77
96
  }
78
97
 
79
98
  // forward methods
@@ -206,12 +225,24 @@ inline CUDA_CALLABLE quat_t<Type> div(quat_t<Type> q, Type s)
206
225
  return quat_t<Type>(q.x/s, q.y/s, q.z/s, q.w/s);
207
226
  }
208
227
 
228
+ template<typename Type>
229
+ inline CUDA_CALLABLE quat_t<Type> div(Type s, quat_t<Type> q)
230
+ {
231
+ return quat_t<Type>(s/q.x, s/q.y, s/q.z, s/q.w);
232
+ }
233
+
209
234
  template<typename Type>
210
235
  inline CUDA_CALLABLE quat_t<Type> operator / (quat_t<Type> a, Type s)
211
236
  {
212
237
  return div(a,s);
213
238
  }
214
239
 
240
+ template<typename Type>
241
+ inline CUDA_CALLABLE quat_t<Type> operator / (Type s, quat_t<Type> a)
242
+ {
243
+ return div(s,a);
244
+ }
245
+
215
246
  template<typename Type>
216
247
  inline CUDA_CALLABLE quat_t<Type> operator*(Type s, const quat_t<Type>& a)
217
248
  {
@@ -321,7 +352,7 @@ inline CUDA_CALLABLE quat_t<Type> quat_from_matrix(const mat_t<3,3,Type>& m)
321
352
  }
322
353
 
323
354
  template<typename Type>
324
- inline CUDA_CALLABLE Type index(const quat_t<Type>& a, int idx)
355
+ inline CUDA_CALLABLE Type extract(const quat_t<Type>& a, int idx)
325
356
  {
326
357
  #if FP_CHECK
327
358
  if (idx < 0 || idx > 3)
@@ -357,7 +388,7 @@ CUDA_CALLABLE inline void adj_lerp(const quat_t<Type>& a, const quat_t<Type>& b,
357
388
  }
358
389
 
359
390
  template<typename Type>
360
- inline CUDA_CALLABLE void adj_index(const quat_t<Type>& a, int idx, quat_t<Type>& adj_a, int & adj_idx, Type & adj_ret)
391
+ inline CUDA_CALLABLE void adj_extract(const quat_t<Type>& a, int idx, quat_t<Type>& adj_a, int & adj_idx, Type & adj_ret)
361
392
  {
362
393
  #if FP_CHECK
363
394
  if (idx < 0 || idx > 3)
@@ -367,7 +398,7 @@ inline CUDA_CALLABLE void adj_index(const quat_t<Type>& a, int idx, quat_t<Type>
367
398
  }
368
399
  #endif
369
400
 
370
- // See wp::index(const quat_t<Type>& a, int idx) note
401
+ // See wp::extract(const quat_t<Type>& a, int idx) note
371
402
  if (idx == 0) {adj_a.x += adj_ret;}
372
403
  else if (idx == 1) {adj_a.y += adj_ret;}
373
404
  else if (idx == 2) {adj_a.z += adj_ret;}
@@ -504,9 +535,14 @@ inline CUDA_CALLABLE void tensordot(const quat_t<Type>& a, const quat_t<Type>& b
504
535
  }
505
536
 
506
537
  template<typename Type>
507
- inline CUDA_CALLABLE void adj_length(const quat_t<Type>& a, quat_t<Type>& adj_a, const Type adj_ret)
538
+ inline CUDA_CALLABLE void adj_length(const quat_t<Type>& a, Type ret, quat_t<Type>& adj_a, const Type adj_ret)
508
539
  {
509
- adj_a += normalize(a)*adj_ret;
540
+ if (ret > Type(kEps))
541
+ {
542
+ Type inv_l = Type(1)/ret;
543
+
544
+ adj_a += quat_t<Type>(a.x*inv_l, a.y*inv_l, a.z*inv_l, a.w*inv_l) * adj_ret;
545
+ }
510
546
  }
511
547
 
512
548
  template<typename Type>
@@ -589,6 +625,13 @@ inline CUDA_CALLABLE void adj_div(quat_t<Type> a, Type s, quat_t<Type>& adj_a, T
589
625
  adj_a += adj_ret / s;
590
626
  }
591
627
 
628
+ template<typename Type>
629
+ inline CUDA_CALLABLE void adj_div(Type s, quat_t<Type> a, Type& adj_s, quat_t<Type>& adj_a, const quat_t<Type>& adj_ret)
630
+ {
631
+ adj_s -= dot(a, adj_ret)/ (s * s); // - a / s^2
632
+ adj_a += s / adj_ret;
633
+ }
634
+
592
635
  template<typename Type>
593
636
  inline CUDA_CALLABLE void adj_quat_rotate(const quat_t<Type>& q, const vec_t<3,Type>& p, quat_t<Type>& adj_q, vec_t<3,Type>& adj_p, const vec_t<3,Type>& adj_ret)
594
637
  {
@@ -658,7 +701,7 @@ inline CUDA_CALLABLE void adj_quat_rotate_inv(const quat_t<Type>& q, const vec_t
658
701
  }
659
702
 
660
703
  template<typename Type>
661
- inline CUDA_CALLABLE void adj_quat_slerp(const quat_t<Type>& q0, const quat_t<Type>& q1, Type t, quat_t<Type>& adj_q0, quat_t<Type>& adj_q1, Type& adj_t, const quat_t<Type>& adj_ret)
704
+ inline CUDA_CALLABLE void adj_quat_slerp(const quat_t<Type>& q0, const quat_t<Type>& q1, Type t, quat_t<Type>& ret, quat_t<Type>& adj_q0, quat_t<Type>& adj_q1, Type& adj_t, const quat_t<Type>& adj_ret)
662
705
  {
663
706
  vec_t<3,Type> axis;
664
707
  Type angle;
@@ -669,7 +712,7 @@ inline CUDA_CALLABLE void adj_quat_slerp(const quat_t<Type>& q0, const quat_t<Ty
669
712
  angle = angle * 0.5;
670
713
 
671
714
  // adj_t
672
- adj_t += dot(mul(quat_slerp(q0, q1, t), quat_t<Type>(angle*axis[0], angle*axis[1], angle*axis[2], Type(0))), adj_ret);
715
+ adj_t += dot(mul(ret, quat_t<Type>(angle*axis[0], angle*axis[1], angle*axis[2], Type(0))), adj_ret);
673
716
 
674
717
  // adj_q0
675
718
  quat_t<Type> q_inc_x_q0;
warp/native/rand.h CHANGED
@@ -9,8 +9,8 @@
9
9
  # pragma once
10
10
  #include "array.h"
11
11
 
12
- #ifndef M_PI
13
- #define M_PI 3.14159265358979323846f
12
+ #ifndef M_PI_F
13
+ #define M_PI_F 3.14159265358979323846f
14
14
  #endif
15
15
 
16
16
  namespace wp
@@ -33,7 +33,7 @@ inline CUDA_CALLABLE float randf(uint32& state) { state = rand_pcg(state); retur
33
33
  inline CUDA_CALLABLE float randf(uint32& state, float min, float max) { return (max - min) * randf(state) + min; }
34
34
 
35
35
  // Box-Muller method
36
- inline CUDA_CALLABLE float randn(uint32& state) { return sqrt(-2.f * log(randf(state))) * cos(2.f * M_PI * randf(state)); }
36
+ inline CUDA_CALLABLE float randn(uint32& state) { return sqrt(-2.f * log(randf(state))) * cos(2.f * M_PI_F * randf(state)); }
37
37
 
38
38
  inline CUDA_CALLABLE void adj_rand_init(int seed, int& adj_seed, float adj_ret) {}
39
39
  inline CUDA_CALLABLE void adj_rand_init(int seed, int offset, int& adj_seed, int& adj_offset, float adj_ret) {}
@@ -55,14 +55,14 @@ inline CUDA_CALLABLE int sample_cdf(uint32& state, const array_t<float>& cdf)
55
55
  inline CUDA_CALLABLE vec2 sample_triangle(uint32& state)
56
56
  {
57
57
  float r = sqrt(randf(state));
58
- float u = 1.0 - r;
58
+ float u = 1.f - r;
59
59
  float v = randf(state) * r;
60
60
  return vec2(u, v);
61
61
  }
62
62
 
63
63
  inline CUDA_CALLABLE vec2 sample_unit_ring(uint32& state)
64
64
  {
65
- float theta = randf(state, 0.f, 2.f*M_PI);
65
+ float theta = randf(state, 0.f, 2.f*M_PI_F);
66
66
  float x = cos(theta);
67
67
  float y = sin(theta);
68
68
  return vec2(x, y);
@@ -71,7 +71,7 @@ inline CUDA_CALLABLE vec2 sample_unit_ring(uint32& state)
71
71
  inline CUDA_CALLABLE vec2 sample_unit_disk(uint32& state)
72
72
  {
73
73
  float r = sqrt(randf(state));
74
- float theta = randf(state, 0.f, 2.f*M_PI);
74
+ float theta = randf(state, 0.f, 2.f*M_PI_F);
75
75
  float x = r * cos(theta);
76
76
  float y = r * sin(theta);
77
77
  return vec2(x, y);
@@ -80,7 +80,7 @@ inline CUDA_CALLABLE vec2 sample_unit_disk(uint32& state)
80
80
  inline CUDA_CALLABLE vec3 sample_unit_sphere_surface(uint32& state)
81
81
  {
82
82
  float phi = acos(1.f - 2.f * randf(state));
83
- float theta = randf(state, 0.f, 2.f*M_PI);
83
+ float theta = randf(state, 0.f, 2.f*M_PI_F);
84
84
  float x = cos(theta) * sin(phi);
85
85
  float y = sin(theta) * sin(phi);
86
86
  float z = cos(phi);
@@ -90,7 +90,7 @@ inline CUDA_CALLABLE vec3 sample_unit_sphere_surface(uint32& state)
90
90
  inline CUDA_CALLABLE vec3 sample_unit_sphere(uint32& state)
91
91
  {
92
92
  float phi = acos(1.f - 2.f * randf(state));
93
- float theta = randf(state, 0.f, 2.f*M_PI);
93
+ float theta = randf(state, 0.f, 2.f*M_PI_F);
94
94
  float r = pow(randf(state), 1.f/3.f);
95
95
  float x = r * cos(theta) * sin(phi);
96
96
  float y = r * sin(theta) * sin(phi);
@@ -101,7 +101,7 @@ inline CUDA_CALLABLE vec3 sample_unit_sphere(uint32& state)
101
101
  inline CUDA_CALLABLE vec3 sample_unit_hemisphere_surface(uint32& state)
102
102
  {
103
103
  float phi = acos(1.f - randf(state));
104
- float theta = randf(state, 0.f, 2.f*M_PI);
104
+ float theta = randf(state, 0.f, 2.f*M_PI_F);
105
105
  float x = cos(theta) * sin(phi);
106
106
  float y = sin(theta) * sin(phi);
107
107
  float z = cos(phi);
@@ -111,7 +111,7 @@ inline CUDA_CALLABLE vec3 sample_unit_hemisphere_surface(uint32& state)
111
111
  inline CUDA_CALLABLE vec3 sample_unit_hemisphere(uint32& state)
112
112
  {
113
113
  float phi = acos(1.f - randf(state));
114
- float theta = randf(state, 0.f, 2.f*M_PI);
114
+ float theta = randf(state, 0.f, 2.f*M_PI_F);
115
115
  float r = pow(randf(state), 1.f/3.f);
116
116
  float x = r * cos(theta) * sin(phi);
117
117
  float y = r * sin(theta) * sin(phi);
@@ -134,6 +134,15 @@ inline CUDA_CALLABLE vec3 sample_unit_cube(uint32& state)
134
134
  return vec3(x, y, z);
135
135
  }
136
136
 
137
+ inline CUDA_CALLABLE vec4 sample_unit_hypercube(uint32& state)
138
+ {
139
+ float a = randf(state) - 0.5f;
140
+ float b = randf(state) - 0.5f;
141
+ float c = randf(state) - 0.5f;
142
+ float d = randf(state) - 0.5f;
143
+ return vec4(a, b, c, d);
144
+ }
145
+
137
146
  inline CUDA_CALLABLE void adj_sample_cdf(uint32& state, const array_t<float>& cdf, uint32& adj_state, array_t<float>& adj_cdf, const int& adj_ret) {}
138
147
  inline CUDA_CALLABLE void adj_sample_triangle(uint32& state, uint32& adj_state, const vec2& adj_ret) {}
139
148
  inline CUDA_CALLABLE void adj_sample_unit_ring(uint32& state, uint32& adj_state, const vec2& adj_ret) {}
@@ -144,6 +153,7 @@ inline CUDA_CALLABLE void adj_sample_unit_hemisphere_surface(uint32& state, uint
144
153
  inline CUDA_CALLABLE void adj_sample_unit_hemisphere(uint32& state, uint32& adj_state, const vec3& adj_ret) {}
145
154
  inline CUDA_CALLABLE void adj_sample_unit_square(uint32& state, uint32& adj_state, const vec2& adj_ret) {}
146
155
  inline CUDA_CALLABLE void adj_sample_unit_cube(uint32& state, uint32& adj_state, const vec3& adj_ret) {}
156
+ inline CUDA_CALLABLE void adj_sample_unit_hypercube(uint32& state, uint32& adj_state, const vec3& adj_ret) {}
147
157
 
148
158
  /*
149
159
  * log-gamma function to support some of these distributions. The
@@ -158,17 +168,17 @@ inline CUDA_CALLABLE float random_loggam(float x)
158
168
  float x0, x2, lg2pi, gl, gl0;
159
169
  uint32 n;
160
170
 
161
- const float a[10] = {8.333333333333333e-02, -2.777777777777778e-03,
162
- 7.936507936507937e-04, -5.952380952380952e-04,
163
- 8.417508417508418e-04, -1.917526917526918e-03,
164
- 6.410256410256410e-03, -2.955065359477124e-02,
165
- 1.796443723688307e-01, -1.39243221690590e+00};
171
+ const float a[10] = {8.333333333333333e-02f, -2.777777777777778e-03f,
172
+ 7.936507936507937e-04f, -5.952380952380952e-04f,
173
+ 8.417508417508418e-04f, -1.917526917526918e-03f,
174
+ 6.410256410256410e-03f, -2.955065359477124e-02f,
175
+ 1.796443723688307e-01f, -1.39243221690590e+00f};
166
176
 
167
- if ((x == 1.0) || (x == 2.0))
177
+ if ((x == 1.f) || (x == 2.f))
168
178
  {
169
- return 0.0;
179
+ return 0.f;
170
180
  }
171
- else if (x < 7.0)
181
+ else if (x < 7.f)
172
182
  {
173
183
  n = uint32((7 - x));
174
184
  }
@@ -178,8 +188,8 @@ inline CUDA_CALLABLE float random_loggam(float x)
178
188
  }
179
189
 
180
190
  x0 = x + float(n);
181
- x2 = (1.0 / x0) * (1.0 / x0);
182
- // log(2 * M_PI)
191
+ x2 = (1.f / x0) * (1.f / x0);
192
+ // log(2 * M_PI_F)
183
193
  lg2pi = 1.8378770664093453f;
184
194
  gl0 = a[9];
185
195
  for (int i = 8; i >= 0; i--)
@@ -187,13 +197,13 @@ inline CUDA_CALLABLE float random_loggam(float x)
187
197
  gl0 *= x2;
188
198
  gl0 += a[i];
189
199
  }
190
- gl = gl0 / x0 + 0.5 * lg2pi + (x0 - 0.5) * log(x0) - x0;
191
- if (x < 7.0)
200
+ gl = gl0 / x0 + 0.5f * lg2pi + (x0 - 0.5f) * log(x0) - x0;
201
+ if (x < 7.f)
192
202
  {
193
203
  for (uint32 k = 1; k <= n; k++)
194
204
  {
195
- gl -= log(x0 - 1.0);
196
- x0 -= 1.0;
205
+ gl -= log(x0 - 1.f);
206
+ x0 -= 1.f;
197
207
  }
198
208
  }
199
209
  return gl;
@@ -205,7 +215,7 @@ inline CUDA_CALLABLE uint32 random_poisson_mult(uint32& state, float lam) {
205
215
 
206
216
  enlam = exp(-lam);
207
217
  X = 0;
208
- prod = 1.0;
218
+ prod = 1.f;
209
219
 
210
220
  while (1)
211
221
  {
@@ -234,22 +244,22 @@ inline CUDA_CALLABLE uint32 random_poisson(uint32& state, float lam)
234
244
 
235
245
  slam = sqrt(lam);
236
246
  loglam = log(lam);
237
- b = 0.931 + 2.53 * slam;
238
- a = -0.059 + 0.02483 * b;
239
- invalpha = 1.1239 + 1.1328 / (b - 3.4);
240
- vr = 0.9277 - 3.6224 / (b - 2.0);
247
+ b = 0.931f + 2.53f * slam;
248
+ a = -0.059f + 0.02483f * b;
249
+ invalpha = 1.1239f + 1.1328f / (b - 3.4f);
250
+ vr = 0.9277f - 3.6224f / (b - 2.f);
241
251
 
242
252
  while (1)
243
253
  {
244
- U = randf(state) - 0.5;
254
+ U = randf(state) - 0.5f;
245
255
  V = randf(state);
246
- us = 0.5 - abs(U);
247
- k = uint32(floor((2 * a / us + b) * U + lam + 0.43));
248
- if ((us >= 0.07) && (V <= vr))
256
+ us = 0.5f - abs(U);
257
+ k = uint32(floor((2.f * a / us + b) * U + lam + 0.43f));
258
+ if ((us >= 0.07f) && (V <= vr))
249
259
  {
250
260
  return k;
251
261
  }
252
- if ((us < 0.013) && (V > us))
262
+ if ((us < 0.013f) && (V > us))
253
263
  {
254
264
  continue;
255
265
  }
@@ -261,7 +271,7 @@ inline CUDA_CALLABLE uint32 random_poisson(uint32& state, float lam)
261
271
  }
262
272
 
263
273
  /*
264
- * Adpated from NumPy's implementation
274
+ * Adapted from NumPy's implementation
265
275
  * Warp's state variable is half the precision of NumPy's so
266
276
  * poisson implementation uses half the precision used in NumPy's implementation
267
277
  * both precisions appear to converge in the statistical limit
warp/native/range.h CHANGED
@@ -15,8 +15,12 @@ namespace wp
15
15
  // represents a built-in Python range() loop
16
16
  struct range_t
17
17
  {
18
- CUDA_CALLABLE range_t() {}
19
- CUDA_CALLABLE range_t(int) {} // for backward pass
18
+ CUDA_CALLABLE range_t()
19
+ : start(0),
20
+ end(0),
21
+ step(0),
22
+ i(0)
23
+ {}
20
24
 
21
25
  int start;
22
26
  int end;
warp/native/reduce.cpp CHANGED
@@ -97,7 +97,7 @@ template <typename T> void array_sum_host(const T *ptr_a, T *ptr_out, int count,
97
97
  accumulate_func = dyn_len_sum<T>;
98
98
  }
99
99
 
100
- *ptr_out = 0.0f;
100
+ memset(ptr_out, 0, sizeof(T)*type_length);
101
101
  for (int i = 0; i < count; ++i)
102
102
  accumulate_func(ptr_a + i * stride, ptr_out, type_length);
103
103
  }
warp/native/reduce.cu CHANGED
@@ -103,23 +103,22 @@ template <typename T> void array_sum_device(const T *ptr_a, T *ptr_out, int coun
103
103
  assert((byte_stride % sizeof(T)) == 0);
104
104
  const int stride = byte_stride / sizeof(T);
105
105
 
106
- void *context = cuda_context_get_current();
107
- TemporaryBuffer &cub_temp = g_temp_buffer_map[context];
108
-
109
- ContextGuard guard(context);
106
+ ContextGuard guard(cuda_context_get_current());
110
107
  cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
111
108
 
112
109
  cub_strided_iterator<const T> ptr_strided{ptr_a, stride};
113
110
 
114
111
  size_t buff_size = 0;
115
112
  check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, ptr_strided, ptr_out, count, stream));
116
- cub_temp.ensure_fits(buff_size);
113
+ void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
117
114
 
118
115
  for (int k = 0; k < type_length; ++k)
119
116
  {
120
117
  cub_strided_iterator<const T> ptr_strided{ptr_a + k, stride};
121
- check_cuda(cub::DeviceReduce::Sum(cub_temp.buffer, buff_size, ptr_strided, ptr_out + k, count, stream));
118
+ check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, ptr_strided, ptr_out + k, count, stream));
122
119
  }
120
+
121
+ free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
123
122
  }
124
123
 
125
124
  template <typename T>
@@ -265,19 +264,18 @@ void array_inner_device(const ElemT *ptr_a, const ElemT *ptr_b, ScalarT *ptr_out
265
264
  const int stride_a = byte_stride_a / sizeof(ElemT);
266
265
  const int stride_b = byte_stride_b / sizeof(ElemT);
267
266
 
268
- void *context = cuda_context_get_current();
269
- TemporaryBuffer &cub_temp = g_temp_buffer_map[context];
270
-
271
- ContextGuard guard(context);
267
+ ContextGuard guard(cuda_context_get_current());
272
268
  cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
273
269
 
274
270
  cub_inner_product_iterator<ElemT, ScalarT> inner_iterator{ptr_a, ptr_b, stride_a, stride_b, type_length};
275
271
 
276
272
  size_t buff_size = 0;
277
273
  check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, inner_iterator, ptr_out, count, stream));
278
- cub_temp.ensure_fits(buff_size);
274
+ void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
275
+
276
+ check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, inner_iterator, ptr_out, count, stream));
279
277
 
280
- check_cuda(cub::DeviceReduce::Sum(cub_temp.buffer, buff_size, inner_iterator, ptr_out, count, stream));
278
+ free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
281
279
  }
282
280
 
283
281
  template <typename T>
@@ -3,8 +3,6 @@
3
3
  #include "warp.h"
4
4
  #include "cuda_util.h"
5
5
 
6
- #include "temp_buffer.h"
7
-
8
6
  #define THRUST_IGNORE_CUB_VERSION_CHECK
9
7
  #include <cub/device/device_run_length_encode.cuh>
10
8
 
@@ -15,11 +13,7 @@ void runlength_encode_device(int n,
15
13
  int *run_lengths,
16
14
  int *run_count)
17
15
  {
18
- void *context = cuda_context_get_current();
19
-
20
- TemporaryBuffer &cub_temp = g_temp_buffer_map[context];
21
-
22
- ContextGuard guard(context);
16
+ ContextGuard guard(cuda_context_get_current());
23
17
  cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
24
18
 
25
19
  size_t buff_size = 0;
@@ -27,11 +21,13 @@ void runlength_encode_device(int n,
27
21
  nullptr, buff_size, values, run_values, run_lengths, run_count,
28
22
  n, stream));
29
23
 
30
- cub_temp.ensure_fits(buff_size);
24
+ void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, buff_size);
31
25
 
32
26
  check_cuda(cub::DeviceRunLengthEncode::Encode(
33
- cub_temp.buffer, buff_size, values, run_values, run_lengths, run_count,
27
+ temp_buffer, buff_size, values, run_values, run_lengths, run_count,
34
28
  n, stream));
29
+
30
+ free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
35
31
  }
36
32
 
37
33
  void runlength_encode_int_device(
@@ -47,4 +43,4 @@ void runlength_encode_int_device(
47
43
  reinterpret_cast<int *>(run_values),
48
44
  reinterpret_cast<int *>(run_lengths),
49
45
  reinterpret_cast<int *>(run_count));
50
- }
46
+ }
warp/native/scan.cu CHANGED
@@ -1,8 +1,6 @@
1
1
  #include "warp.h"
2
2
  #include "scan.h"
3
3
 
4
- #include "temp_buffer.h"
5
-
6
4
  #define THRUST_IGNORE_CUB_VERSION_CHECK
7
5
 
8
6
  #include <cub/device/device_scan.cuh>
@@ -10,29 +8,28 @@
10
8
  template<typename T>
11
9
  void scan_device(const T* values_in, T* values_out, int n, bool inclusive)
12
10
  {
13
- void *context = cuda_context_get_current();
14
- TemporaryBuffer &cub_temp = g_temp_buffer_map[context];
15
-
16
- ContextGuard guard(context);
11
+ ContextGuard guard(cuda_context_get_current());
17
12
 
18
13
  cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
19
14
 
20
15
  // compute temporary memory required
21
16
  size_t scan_temp_size;
22
17
  if (inclusive) {
23
- cub::DeviceScan::InclusiveSum(NULL, scan_temp_size, values_in, values_out, n);
18
+ check_cuda(cub::DeviceScan::InclusiveSum(NULL, scan_temp_size, values_in, values_out, n));
24
19
  } else {
25
- cub::DeviceScan::ExclusiveSum(NULL, scan_temp_size, values_in, values_out, n);
20
+ check_cuda(cub::DeviceScan::ExclusiveSum(NULL, scan_temp_size, values_in, values_out, n));
26
21
  }
27
22
 
28
- cub_temp.ensure_fits(scan_temp_size);
23
+ void* temp_buffer = alloc_temp_device(WP_CURRENT_CONTEXT, scan_temp_size);
29
24
 
30
25
  // scan
31
26
  if (inclusive) {
32
- cub::DeviceScan::InclusiveSum(cub_temp.buffer, scan_temp_size, values_in, values_out, n, (cudaStream_t)cuda_stream_get_current());
27
+ check_cuda(cub::DeviceScan::InclusiveSum(temp_buffer, scan_temp_size, values_in, values_out, n, stream));
33
28
  } else {
34
- cub::DeviceScan::ExclusiveSum(cub_temp.buffer, scan_temp_size, values_in, values_out, n, (cudaStream_t)cuda_stream_get_current());
29
+ check_cuda(cub::DeviceScan::ExclusiveSum(temp_buffer, scan_temp_size, values_in, values_out, n, stream));
35
30
  }
31
+
32
+ free_temp_device(WP_CURRENT_CONTEXT, temp_buffer);
36
33
  }
37
34
 
38
35
  template void scan_device(const int*, int*, int, bool);
warp/native/sparse.cpp CHANGED
@@ -179,10 +179,10 @@ void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, i
179
179
  const int block_size = rows_per_block * cols_per_block;
180
180
 
181
181
  void (*block_transpose_func)(const T *, T *, int, int) = bsr_dyn_block_transpose<T>;
182
- switch (row_count)
182
+ switch (rows_per_block)
183
183
  {
184
184
  case 1:
185
- switch (col_count)
185
+ switch (cols_per_block)
186
186
  {
187
187
  case 1:
188
188
  block_transpose_func = bsr_fixed_block_transpose<1, 1, T>;
@@ -196,7 +196,7 @@ void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, i
196
196
  }
197
197
  break;
198
198
  case 2:
199
- switch (col_count)
199
+ switch (cols_per_block)
200
200
  {
201
201
  case 1:
202
202
  block_transpose_func = bsr_fixed_block_transpose<2, 1, T>;
@@ -210,7 +210,7 @@ void bsr_transpose_host(int rows_per_block, int cols_per_block, int row_count, i
210
210
  }
211
211
  break;
212
212
  case 3:
213
- switch (col_count)
213
+ switch (cols_per_block)
214
214
  {
215
215
  case 1:
216
216
  block_transpose_func = bsr_fixed_block_transpose<3, 1, T>;