warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/types.py CHANGED
@@ -5,19 +5,17 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
+ from __future__ import annotations
9
+
10
+ import builtins
8
11
  import ctypes
9
12
  import hashlib
13
+ import inspect
10
14
  import struct
11
15
  import zlib
12
- import numpy as np
16
+ from typing import Any, Callable, Generic, List, Tuple, TypeVar, Union
13
17
 
14
- from typing import Any
15
- from typing import Tuple
16
- from typing import TypeVar
17
- from typing import Generic
18
- from typing import List
19
- from typing import Callable
20
- from typing import Union
18
+ import numpy as np
21
19
 
22
20
  import warp
23
21
 
@@ -54,12 +52,14 @@ def constant(x):
54
52
  global _constant_hash
55
53
 
56
54
  # hash the constant value
57
- if isinstance(x, int):
55
+ if isinstance(x, builtins.bool):
56
+ # This needs to come before the check for `int` since all boolean
57
+ # values are also instances of `int`.
58
+ _constant_hash.update(struct.pack("?", x))
59
+ elif isinstance(x, int):
58
60
  _constant_hash.update(struct.pack("<q", x))
59
61
  elif isinstance(x, float):
60
62
  _constant_hash.update(struct.pack("<d", x))
61
- elif isinstance(x, bool):
62
- _constant_hash.update(struct.pack("?", x))
63
63
  elif isinstance(x, float16):
64
64
  # float16 is a special case
65
65
  p = ctypes.pointer(ctypes.c_float(x.value))
@@ -75,6 +75,14 @@ def constant(x):
75
75
  return x
76
76
 
77
77
 
78
+ def float_to_half_bits(value):
79
+ return warp.context.runtime.core.float_to_half_bits(value)
80
+
81
+
82
+ def half_bits_to_float(value):
83
+ return warp.context.runtime.core.half_bits_to_float(value)
84
+
85
+
78
86
  # ----------------------
79
87
  # built-in types
80
88
 
@@ -98,19 +106,15 @@ def vector(length, dtype):
98
106
  _wp_generic_type_str_ = "vec_t"
99
107
  _wp_constructor_ = "vector"
100
108
 
101
- def __init__(self, *args):
102
- if self._wp_scalar_type_ == float16:
103
- # special case for float16 type: in this case, data is stored
104
- # as uint16 but it's actually half precision floating point
105
- # data. This means we need to convert each of the arguments
106
- # to uint16s containing half float bits before storing them in
107
- # the array:
108
- from warp.context import runtime
109
-
110
- scalar_value = runtime.core.float_to_half_bits
111
- else:
112
- scalar_value = lambda x: x
109
+ # special handling for float16 type: in this case, data is stored
110
+ # as uint16 but it's actually half precision floating point
111
+ # data. This means we need to convert each of the arguments
112
+ # to uint16s containing half float bits before storing them in
113
+ # the array:
114
+ scalar_import = float_to_half_bits if _wp_scalar_type_ == float16 else lambda x: x
115
+ scalar_export = half_bits_to_float if _wp_scalar_type_ == float16 else lambda x: x
113
116
 
117
+ def __init__(self, *args):
114
118
  num_args = len(args)
115
119
  if num_args == 0:
116
120
  super().__init__()
@@ -120,29 +124,99 @@ def vector(length, dtype):
120
124
  self.__init__(*args[0])
121
125
  else:
122
126
  # set all elements to the same value
123
- value = scalar_value(args[0])
127
+ value = vec_t.scalar_import(args[0])
124
128
  for i in range(self._length_):
125
129
  super().__setitem__(i, value)
126
130
  elif num_args == self._length_:
127
131
  # set all scalar elements
128
132
  for i in range(self._length_):
129
- super().__setitem__(i, scalar_value(args[i]))
133
+ super().__setitem__(i, vec_t.scalar_import(args[i]))
130
134
  else:
131
135
  raise ValueError(
132
136
  f"Invalid number of arguments in vector constructor, expected {self._length_} elements, got {num_args}"
133
137
  )
134
138
 
139
+ def __getitem__(self, key):
140
+ if isinstance(key, int):
141
+ return vec_t.scalar_export(super().__getitem__(key))
142
+ elif isinstance(key, slice):
143
+ if self._wp_scalar_type_ == float16:
144
+ return [vec_t.scalar_export(x) for x in super().__getitem__(key)]
145
+ else:
146
+ return super().__getitem__(key)
147
+ else:
148
+ raise KeyError(f"Invalid key {key}, expected int or slice")
149
+
150
+ def __setitem__(self, key, value):
151
+ if isinstance(key, int):
152
+ try:
153
+ return super().__setitem__(key, vec_t.scalar_import(value))
154
+ except (TypeError, ctypes.ArgumentError):
155
+ raise TypeError(
156
+ f"Expected to assign a `{self._wp_scalar_type_.__name__}` value "
157
+ f"but got `{type(value).__name__}` instead"
158
+ ) from None
159
+ elif isinstance(key, slice):
160
+ try:
161
+ iter(value)
162
+ except TypeError:
163
+ raise TypeError(
164
+ f"Expected to assign a slice from a sequence of values "
165
+ f"but got `{type(value).__name__}` instead"
166
+ ) from None
167
+
168
+ if self._wp_scalar_type_ == float16:
169
+ converted = []
170
+ try:
171
+ for x in value:
172
+ converted.append(vec_t.scalar_import(x))
173
+ except ctypes.ArgumentError:
174
+ raise TypeError(
175
+ f"Expected to assign a slice from a sequence of `float16` values "
176
+ f"but got `{type(x).__name__}` instead"
177
+ ) from None
178
+
179
+ value = converted
180
+
181
+ try:
182
+ return super().__setitem__(key, value)
183
+ except TypeError:
184
+ for x in value:
185
+ try:
186
+ self._type_(x)
187
+ except TypeError:
188
+ raise TypeError(
189
+ f"Expected to assign a slice from a sequence of `{self._wp_scalar_type_.__name__}` values "
190
+ f"but got `{type(x).__name__}` instead"
191
+ ) from None
192
+ else:
193
+ raise KeyError(f"Invalid key {key}, expected int or slice")
194
+
195
+ def __getattr__(self, name):
196
+ idx = "xyzw".find(name)
197
+ if idx != -1:
198
+ return self.__getitem__(idx)
199
+
200
+ return self.__getattribute__(name)
201
+
202
+ def __setattr__(self, name, value):
203
+ idx = "xyzw".find(name)
204
+ if idx != -1:
205
+ return self.__setitem__(idx, value)
206
+
207
+ return super().__setattr__(name, value)
208
+
135
209
  def __add__(self, y):
136
210
  return warp.add(self, y)
137
211
 
138
212
  def __radd__(self, y):
139
- return warp.add(self, y)
213
+ return warp.add(y, self)
140
214
 
141
215
  def __sub__(self, y):
142
216
  return warp.sub(self, y)
143
217
 
144
- def __rsub__(self, x):
145
- return warp.sub(x, self)
218
+ def __rsub__(self, y):
219
+ return warp.sub(y, self)
146
220
 
147
221
  def __mul__(self, y):
148
222
  return warp.mul(self, y)
@@ -150,17 +224,17 @@ def vector(length, dtype):
150
224
  def __rmul__(self, x):
151
225
  return warp.mul(x, self)
152
226
 
153
- def __div__(self, y):
227
+ def __truediv__(self, y):
154
228
  return warp.div(self, y)
155
229
 
156
- def __rdiv__(self, x):
230
+ def __rtruediv__(self, x):
157
231
  return warp.div(x, self)
158
232
 
159
- def __pos__(self, y):
160
- return warp.pos(self, y)
233
+ def __pos__(self):
234
+ return warp.pos(self)
161
235
 
162
- def __neg__(self, y):
163
- return warp.neg(self, y)
236
+ def __neg__(self):
237
+ return warp.neg(self)
164
238
 
165
239
  def __str__(self):
166
240
  return f"[{', '.join(map(str, self))}]"
@@ -171,6 +245,17 @@ def vector(length, dtype):
171
245
  return False
172
246
  return True
173
247
 
248
+ @classmethod
249
+ def from_ptr(cls, ptr):
250
+ if ptr:
251
+ # create a new vector instance and initialize the contents from the binary data
252
+ # this skips float16 conversions, assuming that float16 data is already encoded as uint16
253
+ value = cls()
254
+ ctypes.memmove(ctypes.byref(value), ptr, ctypes.sizeof(cls._type_) * cls._length_)
255
+ return value
256
+ else:
257
+ raise RuntimeError("NULL pointer exception")
258
+
174
259
  return vec_t
175
260
 
176
261
 
@@ -197,19 +282,15 @@ def matrix(shape, dtype):
197
282
 
198
283
  _wp_row_type_ = vector(0 if shape[1] == Any else shape[1], dtype)
199
284
 
200
- def __init__(self, *args):
201
- if self._wp_scalar_type_ == float16:
202
- # special case for float16 type: in this case, data is stored
203
- # as uint16 but it's actually half precision floating point
204
- # data. This means we need to convert each of the arguments
205
- # to uint16s containing half float bits before storing them in
206
- # the array:
207
- from warp.context import runtime
208
-
209
- scalar_value = runtime.core.float_to_half_bits
210
- else:
211
- scalar_value = lambda x: x
285
+ # special handling for float16 type: in this case, data is stored
286
+ # as uint16 but it's actually half precision floating point
287
+ # data. This means we need to convert each of the arguments
288
+ # to uint16s containing half float bits before storing them in
289
+ # the array:
290
+ scalar_import = float_to_half_bits if _wp_scalar_type_ == float16 else lambda x: x
291
+ scalar_export = half_bits_to_float if _wp_scalar_type_ == float16 else lambda x: x
212
292
 
293
+ def __init__(self, *args):
213
294
  num_args = len(args)
214
295
  if num_args == 0:
215
296
  super().__init__()
@@ -219,13 +300,13 @@ def matrix(shape, dtype):
219
300
  self.__init__(*args[0])
220
301
  else:
221
302
  # set all elements to the same value
222
- value = scalar_value(args[0])
303
+ value = mat_t.scalar_import(args[0])
223
304
  for i in range(self._length_):
224
305
  super().__setitem__(i, value)
225
306
  elif num_args == self._length_:
226
307
  # set all scalar elements
227
308
  for i in range(self._length_):
228
- super().__setitem__(i, scalar_value(args[i]))
309
+ super().__setitem__(i, mat_t.scalar_import(args[i]))
229
310
  elif num_args == self._shape_[0]:
230
311
  # row vectors
231
312
  for i, row in enumerate(args):
@@ -235,7 +316,7 @@ def matrix(shape, dtype):
235
316
  )
236
317
  offset = i * self._shape_[1]
237
318
  for i in range(self._shape_[1]):
238
- super().__setitem__(offset + i, scalar_value(row[i]))
319
+ super().__setitem__(offset + i, mat_t.scalar_import(row[i]))
239
320
  else:
240
321
  raise ValueError(
241
322
  f"Invalid number of arguments in matrix constructor, expected {self._length_} elements, got {num_args}"
@@ -245,13 +326,13 @@ def matrix(shape, dtype):
245
326
  return warp.add(self, y)
246
327
 
247
328
  def __radd__(self, y):
248
- return warp.add(self, y)
329
+ return warp.add(y, self)
249
330
 
250
331
  def __sub__(self, y):
251
332
  return warp.sub(self, y)
252
333
 
253
- def __rsub__(self, x):
254
- return warp.sub(x, self)
334
+ def __rsub__(self, y):
335
+ return warp.sub(y, self)
255
336
 
256
337
  def __mul__(self, y):
257
338
  return warp.mul(self, y)
@@ -265,17 +346,17 @@ def matrix(shape, dtype):
265
346
  def __rmatmul__(self, x):
266
347
  return warp.mul(x, self)
267
348
 
268
- def __div__(self, y):
349
+ def __truediv__(self, y):
269
350
  return warp.div(self, y)
270
351
 
271
- def __rdiv__(self, x):
352
+ def __rtruediv__(self, x):
272
353
  return warp.div(x, self)
273
354
 
274
- def __pos__(self, y):
275
- return warp.pos(self, y)
355
+ def __pos__(self):
356
+ return warp.pos(self)
276
357
 
277
- def __neg__(self, y):
278
- return warp.neg(self, y)
358
+ def __neg__(self):
359
+ return warp.neg(self)
279
360
 
280
361
  def __str__(self):
281
362
  row_str = []
@@ -286,48 +367,96 @@ def matrix(shape, dtype):
286
367
  return "[" + ",\n ".join(row_str) + "]"
287
368
 
288
369
  def __eq__(self, other):
289
- for i in range(self._length_):
290
- if self[i] != other[i]:
291
- return False
370
+ for i in range(self._shape_[0]):
371
+ for j in range(self._shape_[1]):
372
+ if self[i][j] != other[i][j]:
373
+ return False
292
374
  return True
293
375
 
294
-
295
376
  def get_row(self, r):
296
377
  if r < 0 or r >= self._shape_[0]:
297
378
  raise IndexError("Invalid row index")
298
379
  row_start = r * self._shape_[1]
299
380
  row_end = row_start + self._shape_[1]
300
- return self._wp_row_type_(*super().__getitem__(slice(row_start, row_end)))
381
+ row_data = super().__getitem__(slice(row_start, row_end))
382
+ if self._wp_scalar_type_ == float16:
383
+ return self._wp_row_type_(*[mat_t.scalar_export(x) for x in row_data])
384
+ else:
385
+ return self._wp_row_type_(row_data)
301
386
 
302
387
  def set_row(self, r, v):
303
388
  if r < 0 or r >= self._shape_[0]:
304
389
  raise IndexError("Invalid row index")
390
+ try:
391
+ iter(v)
392
+ except TypeError:
393
+ raise TypeError(
394
+ f"Expected to assign a slice from a sequence of values "
395
+ f"but got `{type(v).__name__}` instead"
396
+ ) from None
397
+
305
398
  row_start = r * self._shape_[1]
306
399
  row_end = row_start + self._shape_[1]
400
+ if self._wp_scalar_type_ == float16:
401
+ converted = []
402
+ try:
403
+ for x in v:
404
+ converted.append(mat_t.scalar_import(x))
405
+ except ctypes.ArgumentError:
406
+ raise TypeError(
407
+ f"Expected to assign a slice from a sequence of `float16` values "
408
+ f"but got `{type(x).__name__}` instead"
409
+ ) from None
410
+
411
+ v = converted
307
412
  super().__setitem__(slice(row_start, row_end), v)
308
413
 
309
414
  def __getitem__(self, key):
310
415
  if isinstance(key, Tuple):
311
416
  # element indexing m[i,j]
312
- return super().__getitem__(key[1] * self._shape_[0] + key[1])
417
+ if len(key) != 2:
418
+ raise KeyError(f"Invalid key, expected one or two indices, got {len(key)}")
419
+ if any(isinstance(x, slice) for x in key):
420
+ raise KeyError(f"Slices are not supported when indexing matrices using the `m[i, j]` notation")
421
+ return mat_t.scalar_export(super().__getitem__(key[0] * self._shape_[1] + key[1]))
313
422
  elif isinstance(key, int):
314
423
  # row vector indexing m[r]
315
424
  return self.get_row(key)
316
425
  else:
317
- # slice etc.
318
- return super().__getitem__(key)
426
+ raise KeyError(f"Invalid key {key}, expected int or pair of ints")
319
427
 
320
428
  def __setitem__(self, key, value):
321
429
  if isinstance(key, Tuple):
322
430
  # element indexing m[i,j] = x
323
- return super().__setitem__(key[1] * self._shape_[0] + key[1], value)
431
+ if len(key) != 2:
432
+ raise KeyError(f"Invalid key, expected one or two indices, got {len(key)}")
433
+ if any(isinstance(x, slice) for x in key):
434
+ raise KeyError(f"Slices are not supported when indexing matrices using the `m[i, j]` notation")
435
+ try:
436
+ return super().__setitem__(key[0] * self._shape_[1] + key[1], mat_t.scalar_import(value))
437
+ except (TypeError, ctypes.ArgumentError):
438
+ raise TypeError(
439
+ f"Expected to assign a `{self._wp_scalar_type_.__name__}` value "
440
+ f"but got `{type(value).__name__}` instead"
441
+ ) from None
324
442
  elif isinstance(key, int):
325
443
  # row vector indexing m[r] = v
326
- self.set_row(key, value)
444
+ return self.set_row(key, value)
445
+ elif isinstance(key, slice):
446
+ raise KeyError(f"Slices are not supported when indexing matrices using the `m[start:end]` notation")
447
+ else:
448
+ raise KeyError(f"Invalid key {key}, expected int or pair of ints")
449
+
450
+ @classmethod
451
+ def from_ptr(cls, ptr):
452
+ if ptr:
453
+ # create a new matrix instance and initialize the contents from the binary data
454
+ # this skips float16 conversions, assuming that float16 data is already encoded as uint16
455
+ value = cls()
456
+ ctypes.memmove(ctypes.byref(value), ptr, ctypes.sizeof(cls._type_) * cls._length_)
327
457
  return value
328
458
  else:
329
- # slice etc.
330
- return super().__setitem__(key, value)
459
+ raise RuntimeError("NULL pointer exception")
331
460
 
332
461
  return mat_t
333
462
 
@@ -337,6 +466,23 @@ class void:
337
466
  pass
338
467
 
339
468
 
469
class bool:
    """Boolean scalar type backed by ``ctypes.c_bool``.

    Note that this class shadows the builtin ``bool`` at module level; code
    in this module that needs the builtin refers to ``builtins.bool``.
    """

    # number of scalar components
    _length_ = 1
    # underlying C storage type
    _type_ = ctypes.c_bool

    def __init__(self, x=False):
        # the value is stored as given, without conversion
        self.value = x

    def __bool__(self) -> bool:
        is_set = self.value != 0
        return is_set

    def __float__(self) -> float:
        is_set = self.value != 0
        return float(is_set)

    def __int__(self) -> int:
        is_set = self.value != 0
        return int(is_set)
484
+
485
+
340
486
  class float16:
341
487
  _length_ = 1
342
488
  _type_ = ctypes.c_uint16
@@ -344,6 +490,15 @@ class float16:
344
490
  def __init__(self, x=0.0):
345
491
  self.value = x
346
492
 
493
+ def __bool__(self) -> bool:
494
+ return self.value != 0.0
495
+
496
+ def __float__(self) -> float:
497
+ return float(self.value)
498
+
499
+ def __int__(self) -> int:
500
+ return int(self.value)
501
+
347
502
 
348
503
  class float32:
349
504
  _length_ = 1
@@ -352,6 +507,15 @@ class float32:
352
507
  def __init__(self, x=0.0):
353
508
  self.value = x
354
509
 
510
+ def __bool__(self) -> bool:
511
+ return self.value != 0.0
512
+
513
+ def __float__(self) -> float:
514
+ return float(self.value)
515
+
516
+ def __int__(self) -> int:
517
+ return int(self.value)
518
+
355
519
 
356
520
  class float64:
357
521
  _length_ = 1
@@ -360,6 +524,15 @@ class float64:
360
524
  def __init__(self, x=0.0):
361
525
  self.value = x
362
526
 
527
+ def __bool__(self) -> bool:
528
+ return self.value != 0.0
529
+
530
+ def __float__(self) -> float:
531
+ return float(self.value)
532
+
533
+ def __int__(self) -> int:
534
+ return int(self.value)
535
+
363
536
 
364
537
  class int8:
365
538
  _length_ = 1
@@ -368,6 +541,18 @@ class int8:
368
541
  def __init__(self, x=0):
369
542
  self.value = x
370
543
 
544
+ def __bool__(self) -> bool:
545
+ return self.value != 0
546
+
547
+ def __float__(self) -> float:
548
+ return float(self.value)
549
+
550
+ def __int__(self) -> int:
551
+ return int(self.value)
552
+
553
+ def __index__(self) -> int:
554
+ return int(self.value)
555
+
371
556
 
372
557
  class uint8:
373
558
  _length_ = 1
@@ -376,6 +561,18 @@ class uint8:
376
561
  def __init__(self, x=0):
377
562
  self.value = x
378
563
 
564
+ def __bool__(self) -> bool:
565
+ return self.value != 0
566
+
567
+ def __float__(self) -> float:
568
+ return float(self.value)
569
+
570
+ def __int__(self) -> int:
571
+ return int(self.value)
572
+
573
+ def __index__(self) -> int:
574
+ return int(self.value)
575
+
379
576
 
380
577
  class int16:
381
578
  _length_ = 1
@@ -384,6 +581,18 @@ class int16:
384
581
  def __init__(self, x=0):
385
582
  self.value = x
386
583
 
584
+ def __bool__(self) -> bool:
585
+ return self.value != 0
586
+
587
+ def __float__(self) -> float:
588
+ return float(self.value)
589
+
590
+ def __int__(self) -> int:
591
+ return int(self.value)
592
+
593
+ def __index__(self) -> int:
594
+ return int(self.value)
595
+
387
596
 
388
597
  class uint16:
389
598
  _length_ = 1
@@ -392,6 +601,18 @@ class uint16:
392
601
  def __init__(self, x=0):
393
602
  self.value = x
394
603
 
604
+ def __bool__(self) -> bool:
605
+ return self.value != 0
606
+
607
+ def __float__(self) -> float:
608
+ return float(self.value)
609
+
610
+ def __int__(self) -> int:
611
+ return int(self.value)
612
+
613
+ def __index__(self) -> int:
614
+ return int(self.value)
615
+
395
616
 
396
617
  class int32:
397
618
  _length_ = 1
@@ -400,6 +621,18 @@ class int32:
400
621
  def __init__(self, x=0):
401
622
  self.value = x
402
623
 
624
+ def __bool__(self) -> bool:
625
+ return self.value != 0
626
+
627
+ def __float__(self) -> float:
628
+ return float(self.value)
629
+
630
+ def __int__(self) -> int:
631
+ return int(self.value)
632
+
633
+ def __index__(self) -> int:
634
+ return int(self.value)
635
+
403
636
 
404
637
  class uint32:
405
638
  _length_ = 1
@@ -408,6 +641,18 @@ class uint32:
408
641
  def __init__(self, x=0):
409
642
  self.value = x
410
643
 
644
+ def __bool__(self) -> bool:
645
+ return self.value != 0
646
+
647
+ def __float__(self) -> float:
648
+ return float(self.value)
649
+
650
+ def __int__(self) -> int:
651
+ return int(self.value)
652
+
653
+ def __index__(self) -> int:
654
+ return int(self.value)
655
+
411
656
 
412
657
  class int64:
413
658
  _length_ = 1
@@ -416,6 +661,18 @@ class int64:
416
661
  def __init__(self, x=0):
417
662
  self.value = x
418
663
 
664
+ def __bool__(self) -> bool:
665
+ return self.value != 0
666
+
667
+ def __float__(self) -> float:
668
+ return float(self.value)
669
+
670
+ def __int__(self) -> int:
671
+ return int(self.value)
672
+
673
+ def __index__(self) -> int:
674
+ return int(self.value)
675
+
419
676
 
420
677
  class uint64:
421
678
  _length_ = 1
@@ -424,6 +681,18 @@ class uint64:
424
681
  def __init__(self, x=0):
425
682
  self.value = x
426
683
 
684
+ def __bool__(self) -> bool:
685
+ return self.value != 0
686
+
687
+ def __float__(self) -> float:
688
+ return float(self.value)
689
+
690
+ def __int__(self) -> int:
691
+ return int(self.value)
692
+
693
+ def __index__(self) -> int:
694
+ return int(self.value)
695
+
427
696
 
428
697
  def quaternion(dtype=Any):
429
698
  class quat_t(vector(length=4, dtype=dtype)):
@@ -453,23 +722,63 @@ class quatd(quaternion(dtype=float64)):
453
722
 
454
723
  def transformation(dtype=Any):
455
724
  class transform_t(vector(length=7, dtype=dtype)):
725
+ _wp_init_from_components_sig_ = inspect.Signature(
726
+ (
727
+ inspect.Parameter(
728
+ "p",
729
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
730
+ default=(0.0, 0.0, 0.0),
731
+ ),
732
+ inspect.Parameter(
733
+ "q",
734
+ inspect.Parameter.POSITIONAL_OR_KEYWORD,
735
+ default=(0.0, 0.0, 0.0, 1.0),
736
+ ),
737
+ ),
738
+ )
456
739
  _wp_type_params_ = [dtype]
457
740
  _wp_generic_type_str_ = "transform_t"
458
741
  _wp_constructor_ = "transformation"
459
742
 
460
- def __init__(self, p=(0.0, 0.0, 0.0), q=(0.0, 0.0, 0.0, 1.0)):
461
- super().__init__()
743
+ def __init__(self, *args, **kwargs):
744
+ if len(args) == 1 and len(kwargs) == 0:
745
+ if getattr(args[0], "_wp_generic_type_str_") == self._wp_generic_type_str_:
746
+ # Copy constructor.
747
+ super().__init__(*args[0])
748
+ return
749
+
750
+ try:
751
+ # For backward compatibility, try to check if the arguments
752
+ # match the original signature that'd allow initializing
753
+ # the `p` and `q` components separately.
754
+ bound_args = self._wp_init_from_components_sig_.bind(*args, **kwargs)
755
+ bound_args.apply_defaults()
756
+ p, q = bound_args.args
757
+ except (TypeError, ValueError):
758
+ # Fallback to the vector's constructor.
759
+ super().__init__(*args)
760
+ return
761
+
762
+ # Even if the arguments match the original “from components”
763
+ # signature, we still need to make sure that they represent
764
+ # sequences that can be unpacked.
765
+ if hasattr(p, "__len__") and hasattr(q, "__len__"):
766
+ # Initialize from the `p` and `q` components.
767
+ super().__init__()
768
+ self[0:3] = vector(length=3, dtype=dtype)(*p)
769
+ self[3:7] = quaternion(dtype=dtype)(*q)
770
+ return
462
771
 
463
- self[0:3] = vector(length=3, dtype=dtype)(*p)
464
- self[3:7] = quaternion(dtype=dtype)(*q)
772
+ # Fallback to the vector's constructor.
773
+ super().__init__(*args)
465
774
 
466
775
  @property
467
776
  def p(self):
468
- return self[0:3]
777
+ return vec3(self[0:3])
469
778
 
470
779
  @property
471
780
  def q(self):
472
- return self[3:7]
781
+ return quat(self[3:7])
473
782
 
474
783
  return transform_t
475
784
 
@@ -753,6 +1062,7 @@ vector_types = [
753
1062
  ]
754
1063
 
755
1064
  np_dtype_to_warp_type = {
1065
+ np.dtype(np.bool_): bool,
756
1066
  np.dtype(np.int8): int8,
757
1067
  np.dtype(np.uint8): uint8,
758
1068
  np.dtype(np.int16): int16,
@@ -768,6 +1078,21 @@ np_dtype_to_warp_type = {
768
1078
  np.dtype(np.float64): float64,
769
1079
  }
770
1080
 
1081
# Lookup table from Warp scalar types to the corresponding NumPy scalar types.
# The values are NumPy scalar classes (e.g. ``np.int8``), not ``np.dtype`` instances.
warp_type_to_np_dtype = {
    bool: np.bool_,
    int8: np.int8,
    int16: np.int16,
    int32: np.int32,
    int64: np.int64,
    uint8: np.uint8,
    uint16: np.uint16,
    uint32: np.uint32,
    uint64: np.uint64,
    float16: np.float16,
    float32: np.float32,
    float64: np.float64,
}
1095
+
771
1096
 
772
1097
  # represent a Python range iterator
773
1098
  class range_t:
@@ -777,18 +1102,21 @@ class range_t:
777
1102
 
778
1103
  # definition just for kernel type (cannot be a parameter), see bvh.h
779
1104
  class bvh_query_t:
      """Object used to track state during BVH traversal."""

      # Intentionally empty: no attributes are set on the Python side.
      def __init__(self):
          pass
782
1108
 
783
1109
 
784
1110
  # definition just for kernel type (cannot be a parameter), see mesh.h
785
1111
  class mesh_query_aabb_t:
      """Object used to track state during mesh traversal."""

      # Intentionally empty: no attributes are set on the Python side.
      def __init__(self):
          pass
788
1115
 
789
1116
 
790
1117
  # definition just for kernel type (cannot be a parameter), see hash_grid.h
791
1118
  class hash_grid_query_t:
      """Object used to track state during neighbor traversal."""

      # Intentionally empty: no attributes are set on the Python side.
      def __init__(self):
          pass
794
1122
 
@@ -800,6 +1128,8 @@ LAUNCH_MAX_DIMS = 4
800
1128
  # must match array.h
801
1129
  ARRAY_TYPE_REGULAR = 0
802
1130
  ARRAY_TYPE_INDEXED = 1
1131
+ ARRAY_TYPE_FABRIC = 2
1132
+ ARRAY_TYPE_FABRIC_INDEXED = 3
803
1133
 
804
1134
 
805
1135
  # represents bounds for kernel launch (number of threads across multiple dimensions)
@@ -851,6 +1181,30 @@ class array_t(ctypes.Structure):
851
1181
  self.shape[i] = shape[i]
852
1182
  self.strides[i] = strides[i]
853
1183
 
1184
+ # structured type description used when array_t is packed in a struct and shared via numpy structured array.
1185
@classmethod
def numpy_dtype(cls):
    """Return the structured NumPy dtype description stored on the class."""
    layout = cls._numpy_dtype_
    return layout
+ return cls._numpy_dtype_
1188
+
1189
+ # structured value used when array_t is packed in a struct and shared via a numpy structured array
1190
def numpy_value(self):
    """Return this descriptor as a tuple ``(data, grad, shape, strides, ndim)``.

    ``shape`` and ``strides`` are converted to plain lists.
    """
    data_ptr = self.data
    grad_ptr = self.grad
    shape_list = list(self.shape)
    strides_list = list(self.strides)
    return (data_ptr, grad_ptr, shape_list, strides_list, self.ndim)
1192
+
1193
+
1194
+ # NOTE: must match array_t._fields_
1195
# Dict-style specification of a NumPy structured dtype describing array_t.
# Field offsets and the total item size are read from the ctypes structure
# itself rather than being hard-coded.
array_t._numpy_dtype_ = {
    "names": ["data", "grad", "shape", "strides", "ndim"],
    # data/grad are described as unsigned 8-byte integers, shape/strides as
    # fixed-length runs of ARRAY_MAX_DIMS 4-byte integers
    # NOTE(review): presumably mirrors the ctypes field types — confirm against array_t._fields_
    "formats": ["u8", "u8", f"{ARRAY_MAX_DIMS}i4", f"{ARRAY_MAX_DIMS}i4", "i4"],
    "offsets": [
        array_t.data.offset,
        array_t.grad.offset,
        array_t.shape.offset,
        array_t.strides.offset,
        array_t.ndim.offset,
    ],
    "itemsize": ctypes.sizeof(array_t),
}
1207
+
854
1208
 
855
1209
  class indexedarray_t(ctypes.Structure):
856
1210
  _fields_ = [
@@ -892,16 +1246,20 @@ def type_length(dtype):
892
1246
  return dtype._length_
893
1247
 
894
1248
 
1249
def type_scalar_type(dtype):
    """Return ``dtype._wp_scalar_type_`` if present, otherwise ``dtype`` itself."""
    try:
        return dtype._wp_scalar_type_
    except AttributeError:
        # no `_wp_scalar_type_` attribute: the type is returned unchanged
        return dtype
1251
+
1252
+
895
1253
  def type_size_in_bytes(dtype):
896
1254
  if dtype.__module__ == "ctypes":
897
1255
  return ctypes.sizeof(dtype)
898
- elif type_is_struct(dtype):
1256
+ elif isinstance(dtype, warp.codegen.Struct):
899
1257
  return ctypes.sizeof(dtype.ctype)
900
1258
  elif dtype == float or dtype == int:
901
1259
  return 4
902
1260
  elif hasattr(dtype, "_type_"):
903
1261
  return getattr(dtype, "_length_", 1) * ctypes.sizeof(dtype._type_)
904
-
1262
+
905
1263
  else:
906
1264
  return 0
907
1265
 
@@ -916,9 +1274,9 @@ def type_to_warp(dtype):
916
1274
 
917
1275
 
918
1276
  def type_typestr(dtype):
919
- from warp.codegen import Struct
920
-
921
- if dtype == float16:
1277
+ if dtype == bool:
1278
+ return "?"
1279
+ elif dtype == float16:
922
1280
  return "<f2"
923
1281
  elif dtype == float32:
924
1282
  return "<f4"
@@ -940,8 +1298,8 @@ def type_typestr(dtype):
940
1298
  return "<i8"
941
1299
  elif dtype == uint64:
942
1300
  return "<u8"
943
- elif isinstance(dtype, Struct):
944
- return f"|V{ctypes.sizeof(dtype.ctype)}"
1301
+ elif isinstance(dtype, warp.codegen.Struct):
1302
+ return f"|V{ctypes.sizeof(dtype.ctype)}"
945
1303
  elif issubclass(dtype, ctypes.Array):
946
1304
  return type_typestr(dtype._wp_scalar_type_)
947
1305
  else:
@@ -954,9 +1312,16 @@ def type_repr(t):
954
1312
  return str(f"array(ndim={t.ndim}, dtype={t.dtype})")
955
1313
  if type_is_vector(t):
956
1314
  return str(f"vector(length={t._shape_[0]}, dtype={t._wp_scalar_type_})")
957
- elif type_is_matrix(t):
1315
+ if type_is_matrix(t):
958
1316
  return str(f"matrix(shape=({t._shape_[0]}, {t._shape_[1]}), dtype={t._wp_scalar_type_})")
959
- else:
1317
+ if isinstance(t, warp.codegen.Struct):
1318
+ return type_repr(t.cls)
1319
+ if t in scalar_types:
1320
+ return t.__name__
1321
+
1322
+ try:
1323
+ return t.__module__ + "." + t.__qualname__
1324
+ except AttributeError:
960
1325
  return str(t)
961
1326
 
962
1327
 
@@ -974,14 +1339,6 @@ def type_is_float(t):
974
1339
  return t in float_types
975
1340
 
976
1341
 
977
- def type_is_struct(dtype):
978
- from warp.codegen import Struct
979
-
980
- if isinstance(dtype, Struct):
981
- return True
982
- else:
983
- return False
984
-
985
1342
  # returns True if the passed *type* is a vector
986
1343
  def type_is_vector(t):
987
1344
  if hasattr(t, "_wp_generic_type_str_") and t._wp_generic_type_str_ == "vec_t":
@@ -1000,7 +1357,7 @@ def type_is_matrix(t):
1000
1357
 
1001
1358
  # returns true for all value types (int, float, bool, scalars, vectors, matrices)
1002
1359
  def type_is_value(x):
1003
- if (x == int) or (x == float) or (x == bool) or (x in scalar_types) or issubclass(x, ctypes.Array):
1360
+ if (x == int) or (x == float) or (x == builtins.bool) or (x in scalar_types) or issubclass(x, ctypes.Array):
1004
1361
  return True
1005
1362
  else:
1006
1363
  return False
@@ -1028,14 +1385,16 @@ def types_equal(a, b, match_generic=False):
1028
1385
  # convert to canonical types
1029
1386
  if a == float:
1030
1387
  a = float32
1031
- if a == int:
1388
+ elif a == int:
1032
1389
  a = int32
1033
1390
 
1034
1391
  if b == float:
1035
1392
  b = float32
1036
- if b == int:
1393
+ elif b == int:
1037
1394
  b = int32
1038
1395
 
1396
+ compatible_bool_types = [builtins.bool, bool]
1397
+
1039
1398
  def are_equal(p1, p2):
1040
1399
  if match_generic:
1041
1400
  if p1 == Any or p2 == Any:
@@ -1052,7 +1411,22 @@ def types_equal(a, b, match_generic=False):
1052
1411
  return True
1053
1412
  if p1 == Float and p2 == Float:
1054
1413
  return True
1055
- return p1 == p2
1414
+
1415
+ # convert to canonical types
1416
+ if p1 == float:
1417
+ p1 = float32
1418
+ elif p1 == int:
1419
+ p1 = int32
1420
+
1421
+ if p2 == float:
1422
+ p2 = float32
1423
+ elif p2 == int:
1424
+ p2 = int32
1425
+
1426
+ if p1 in compatible_bool_types and p2 in compatible_bool_types:
1427
+ return True
1428
+ else:
1429
+ return p1 == p2
1056
1430
 
1057
1431
  if (
1058
1432
  hasattr(a, "_wp_generic_type_str_")
@@ -1060,9 +1434,7 @@ def types_equal(a, b, match_generic=False):
1060
1434
  and a._wp_generic_type_str_ == b._wp_generic_type_str_
1061
1435
  ):
1062
1436
  return all([are_equal(p1, p2) for p1, p2 in zip(a._wp_type_params_, b._wp_type_params_)])
1063
- if isinstance(a, array) and isinstance(b, array):
1064
- return True
1065
- if isinstance(a, indexedarray) and isinstance(b, indexedarray):
1437
+ if is_array(a) and type(a) is type(b):
1066
1438
  return True
1067
1439
  else:
1068
1440
  return are_equal(a, b)
@@ -1093,18 +1465,18 @@ class array(Array):
1093
1465
  dtype: DType = Any,
1094
1466
  shape=None,
1095
1467
  strides=None,
1096
- length=0,
1468
+ length=None,
1097
1469
  ptr=None,
1098
- grad_ptr=None,
1099
- capacity=0,
1470
+ capacity=None,
1100
1471
  device=None,
1472
+ pinned=False,
1101
1473
  copy=True,
1102
- owner=True,
1474
+ owner=True, # TODO: replace with deleter=None
1103
1475
  ndim=None,
1476
+ grad=None,
1104
1477
  requires_grad=False,
1105
- pinned=False,
1106
1478
  ):
1107
- """Constructs a new Warp array object from existing data.
1479
+ """Constructs a new Warp array object
1108
1480
 
1109
1481
  When the ``data`` argument is a valid list, tuple, or ndarray the array will be constructed from this object's data.
1110
1482
  For objects that are not stored sequentially in memory (e.g.: a list), then the data will first
@@ -1115,39 +1487,38 @@ class array(Array):
1115
1487
  allocation should reside on the same device given by the device argument, and the user should set the length
1116
1488
  and dtype parameter appropriately.
1117
1489
 
1490
+ If neither ``data`` nor ``ptr`` are specified, the ``shape`` or ``length`` arguments are checked next.
1491
+ This construction path can be used to create new uninitialized arrays, but users are encouraged to call
1492
+ ``wp.empty()``, ``wp.zeros()``, or ``wp.full()`` instead to create new arrays.
1493
+
1494
+ If none of the above arguments are specified, a simple type annotation is constructed. This is used when annotating
1495
+ kernel arguments or struct members (e.g., ``arr: wp.array(dtype=float)``). In this case, only ``dtype`` and ``ndim``
1496
+ are taken into account and no memory is allocated for the array.
1497
+
1118
1498
  Args:
1119
1499
  data (Union[list, tuple, ndarray]): An object to construct the array from, can be a Tuple, List, or generally any type convertible to an np.array
1120
1500
  dtype (Union): One of the built-in types, e.g.: :class:`warp.mat33`, if dtype is Any and data an ndarray then it will be inferred from the array data type
1121
1501
  shape (tuple): Dimensions of the array
1122
1502
  strides (tuple): Number of bytes in each dimension between successive elements of the array
1123
- length (int): Number of elements (rows) of the data type (deprecated, users should use `shape` argument)
1503
+ length (int): Number of elements of the data type (deprecated, users should use `shape` argument)
1124
1504
  ptr (uint64): Address of an external memory address to alias (data should be None)
1125
- grad_ptr (uint64): Address of an external memory address to alias for the gradient array
1126
1505
  capacity (int): Maximum size in bytes of the ptr allocation (data should be None)
1127
1506
  device (Devicelike): Device the array lives on
1128
1507
  copy (bool): Whether the incoming data will be copied or aliased, this is only possible when the incoming `data` already lives on the device specified and types match
1129
1508
  owner (bool): Should the array object try to deallocate memory when it is deleted
1130
1509
  requires_grad (bool): Whether or not gradients will be tracked for this array, see :class:`warp.Tape` for details
1510
+ grad (array): The gradient array to use
1131
1511
  pinned (bool): Whether to allocate pinned host memory, which allows asynchronous host-device transfers (only applicable with device="cpu")
1132
1512
 
1133
1513
  """
1134
1514
 
1135
1515
  self.owner = False
1136
-
1137
- # convert shape to Tuple
1138
- if shape is None:
1139
- shape = tuple(length for _ in range(ndim or 1))
1140
- elif isinstance(shape, int):
1141
- shape = (shape,)
1142
- elif isinstance(shape, List):
1143
- shape = tuple(shape)
1144
-
1145
- self.shape = shape
1146
-
1147
- if len(shape) > ARRAY_MAX_DIMS:
1148
- raise RuntimeError(
1149
- f"Arrays may only have {ARRAY_MAX_DIMS} dimensions maximum, trying to create array with {len(shape)} dims."
1150
- )
1516
+ self.ctype = None
1517
+ self._requires_grad = False
1518
+ self._grad = None
1519
+ # __array_interface__ or __cuda_array_interface__, evaluated lazily and cached
1520
+ self._array_interface = None
1521
+ self.is_transposed = False
1151
1522
 
1152
1523
  # canonicalize dtype
1153
1524
  if dtype == int:
@@ -1155,20 +1526,78 @@ class array(Array):
1155
1526
  elif dtype == float:
1156
1527
  dtype = float32
1157
1528
 
1158
- if data is not None or ptr is not None:
1159
- from .context import runtime
1160
-
1161
- device = runtime.get_device(device)
1529
+ # convert shape to tuple (or leave shape=None if neither shape nor length were specified)
1530
+ if shape is not None:
1531
+ if isinstance(shape, int):
1532
+ shape = (shape,)
1533
+ else:
1534
+ shape = tuple(shape)
1535
+ if len(shape) > ARRAY_MAX_DIMS:
1536
+ raise RuntimeError(
1537
+ f"Failed to create array with shape {shape}, the maximum number of dimensions is {ARRAY_MAX_DIMS}"
1538
+ )
1539
+ elif length is not None:
1540
+ # backward compatibility
1541
+ shape = (length,)
1162
1542
 
1543
+ # determine the construction path from the given arguments
1163
1544
  if data is not None:
1164
- if device.is_capturing:
1165
- raise RuntimeError(f"Cannot allocate memory on device {device} while graph capture is active")
1166
-
1545
+ # data or ptr, not both
1167
1546
  if ptr is not None:
1168
- # data or ptr, not both
1169
- raise RuntimeError("Should only construct arrays with either data or ptr arguments, not both")
1547
+ raise RuntimeError("Can only construct arrays with either `data` or `ptr` arguments, not both")
1548
+ self._init_from_data(data, dtype, shape, device, copy, pinned)
1549
+ elif ptr is not None:
1550
+ self._init_from_ptr(ptr, dtype, shape, strides, capacity, device, owner, pinned)
1551
+ elif shape is not None:
1552
+ self._init_new(dtype, shape, strides, device, pinned)
1553
+ else:
1554
+ self._init_annotation(dtype, ndim or 1)
1170
1555
 
1171
- if isinstance(dtype, warp.codegen.Struct):
1556
+ # initialize gradient, if needed
1557
+ if self.device is not None:
1558
+ if grad is not None:
1559
+ # this will also check whether the gradient array is compatible
1560
+ self.grad = grad
1561
+ else:
1562
+ # allocate gradient if needed
1563
+ self._requires_grad = requires_grad
1564
+ if requires_grad:
1565
+ with warp.ScopedStream(self.device.null_stream):
1566
+ self._alloc_grad()
1567
+
1568
+ def _init_from_data(self, data, dtype, shape, device, copy, pinned):
1569
+ if not hasattr(data, "__len__"):
1570
+ raise RuntimeError(f"Data must be a sequence or array, got scalar {data}")
1571
+
1572
+ if hasattr(dtype, "_wp_scalar_type_"):
1573
+ dtype_shape = dtype._shape_
1574
+ dtype_ndim = len(dtype_shape)
1575
+ scalar_dtype = dtype._wp_scalar_type_
1576
+ else:
1577
+ dtype_shape = ()
1578
+ dtype_ndim = 0
1579
+ scalar_dtype = dtype
1580
+
1581
+ # convert input data to ndarray (handles lists, tuples, etc.) and determine dtype
1582
+ if dtype == Any:
1583
+ # infer dtype from data
1584
+ try:
1585
+ arr = np.array(data, copy=False, ndmin=1)
1586
+ except Exception as e:
1587
+ raise RuntimeError(f"Failed to convert input data to an array: {e}")
1588
+ dtype = np_dtype_to_warp_type.get(arr.dtype)
1589
+ if dtype is None:
1590
+ raise RuntimeError(f"Unsupported input data dtype: {arr.dtype}")
1591
+ elif isinstance(dtype, warp.codegen.Struct):
1592
+ if isinstance(data, np.ndarray):
1593
+ # construct from numpy structured array
1594
+ if data.dtype != dtype.numpy_dtype():
1595
+ raise RuntimeError(
1596
+ f"Invalid source data type for array of structs, expected {dtype.numpy_dtype()}, got {data.dtype}"
1597
+ )
1598
+ arr = data
1599
+ elif isinstance(data, (list, tuple)):
1600
+ # construct from a sequence of structs
1172
1601
  try:
1173
1602
  # convert each struct instance to its corresponding ctype
1174
1603
  ctype_list = [v.__ctype__() for v in data]
@@ -1176,156 +1605,227 @@ class array(Array):
1176
1605
  ctype_arr = (dtype.ctype * len(ctype_list))(*ctype_list)
1177
1606
  # convert to numpy
1178
1607
  arr = np.frombuffer(ctype_arr, dtype=dtype.ctype)
1179
- #arr = np.array(ctype_arr, copy=False)
1180
-
1181
- except Exception as e:
1182
- raise RuntimeError(
1183
- "Error while trying to construct Warp array from a Python list of Warp structs." + str(e))
1184
-
1185
- else:
1186
- try:
1187
- # convert tuples and lists of numeric types to ndarray
1188
- arr = np.array(data, copy=False)
1189
1608
  except Exception as e:
1190
1609
  raise RuntimeError(
1191
- "When constructing an array the data argument must be convertible to ndarray type type. Encountered an error while converting:"
1192
- + str(e)
1193
- )
1194
-
1195
- if dtype == Any:
1196
- # infer dtype from the source data array
1197
- dtype = np_dtype_to_warp_type[arr.dtype]
1198
-
1199
- # try to convert numeric src array to destination type
1200
- if not isinstance(dtype, warp.codegen.Struct):
1201
- try:
1202
- arr = arr.astype(dtype=type_typestr(dtype), copy=False)
1203
- except:
1204
- raise RuntimeError(
1205
- f"Could not convert input data with type {arr.dtype} to array with type {dtype._type_}"
1610
+ f"Error while trying to construct Warp array from a sequence of Warp structs: {e}"
1206
1611
  )
1612
+ else:
1613
+ raise RuntimeError(
1614
+ "Invalid data argument for array of structs, expected a sequence of structs or a NumPy structured array"
1615
+ )
1616
+ else:
1617
+ # convert input data to the given dtype
1618
+ npdtype = warp_type_to_np_dtype.get(scalar_dtype)
1619
+ if npdtype is None:
1620
+ raise RuntimeError(
1621
+ f"Failed to convert input data to an array with Warp type {warp.context.type_str(dtype)}"
1622
+ )
1623
+ try:
1624
+ arr = np.array(data, dtype=npdtype, copy=False, ndmin=1)
1625
+ except Exception as e:
1626
+ raise RuntimeError(f"Failed to convert input data to an array with type {npdtype}: {e}")
1627
+
1628
+ # determine whether the input needs reshaping
1629
+ target_npshape = None
1630
+ if shape is not None:
1631
+ target_npshape = (*shape, *dtype_shape)
1632
+ elif dtype_ndim > 0:
1633
+ # prune inner dimensions of length 1
1634
+ while arr.ndim > 1 and arr.shape[-1] == 1:
1635
+ arr = np.squeeze(arr, axis=-1)
1636
+ # if the inner dims don't match exactly, check if the innermost dim is a multiple of type length
1637
+ if arr.ndim < dtype_ndim or arr.shape[-dtype_ndim:] != dtype_shape:
1638
+ if arr.shape[-1] == dtype._length_:
1639
+ target_npshape = (*arr.shape[:-1], *dtype_shape)
1640
+ elif arr.shape[-1] % dtype._length_ == 0:
1641
+ target_npshape = (*arr.shape[:-1], arr.shape[-1] // dtype._length_, *dtype_shape)
1642
+ else:
1643
+ if dtype_ndim == 1:
1644
+ raise RuntimeError(
1645
+ f"The inner dimensions of the input data are not compatible with the requested vector type {warp.context.type_str(dtype)}: expected an inner dimension that is a multiple of {dtype._length_}"
1646
+ )
1647
+ else:
1648
+ raise RuntimeError(
1649
+ f"The inner dimensions of the input data are not compatible with the requested matrix type {warp.context.type_str(dtype)}: expected inner dimensions {dtype._shape_} or a multiple of {dtype._length_}"
1650
+ )
1207
1651
 
1208
- # ensure contiguous
1209
- arr = np.ascontiguousarray(arr)
1652
+ if target_npshape is not None:
1653
+ try:
1654
+ arr = arr.reshape(target_npshape)
1655
+ except Exception as e:
1656
+ raise RuntimeError(
1657
+ f"Failed to reshape the input data to the given shape {shape} and type {warp.context.type_str(dtype)}: {e}"
1658
+ )
1210
1659
 
1211
- # remove any trailing dimensions of length 1
1212
- if arr.ndim > 1 and arr.shape[-1] == 1:
1213
- arr = np.squeeze(arr, axis=len(arr.shape) - 1)
1660
+ # determine final shape and strides
1661
+ if dtype_ndim > 0:
1662
+ # make sure the inner dims are contiguous for vector/matrix types
1663
+ scalar_size = type_size_in_bytes(dtype._wp_scalar_type_)
1664
+ inner_contiguous = arr.strides[-1] == scalar_size
1665
+ if inner_contiguous and dtype_ndim > 1:
1666
+ inner_contiguous = arr.strides[-2] == scalar_size * dtype_shape[-1]
1214
1667
 
1215
- ptr = arr.__array_interface__["data"][0]
1216
- shape = arr.__array_interface__["shape"]
1217
- strides = arr.__array_interface__.get("strides", None)
1668
+ if not inner_contiguous:
1669
+ arr = np.ascontiguousarray(arr)
1218
1670
 
1219
- # Convert input shape to Warp
1220
- if type_length(dtype) > 1:
1221
- # if we are constructing an array of vectors/matrices, but input
1222
- # is one dimensional (i.e.: flattened) then try and reshape to
1223
- # to match target dtype, inferring the first dimension
1224
- if arr.ndim == 1:
1225
- arr = arr.reshape((-1, *dtype._shape_))
1671
+ shape = arr.shape[:-dtype_ndim] or (1,)
1672
+ strides = arr.strides[:-dtype_ndim] or (type_size_in_bytes(dtype),)
1673
+ else:
1674
+ shape = arr.shape or (1,)
1675
+ strides = arr.strides or (type_size_in_bytes(dtype),)
1226
1676
 
1227
- # last dimension should match dtype shape when using vector types,
1228
- # e.g.: array of mat22 objects should have shape (n, 2, 2)
1229
- dtype_ndim = len(dtype._shape_)
1677
+ device = warp.get_device(device)
1230
1678
 
1231
- trailing_shape = arr.shape[-dtype_ndim:]
1232
- leading_shape = arr.shape[0:-dtype_ndim]
1679
+ if device.is_cpu and not copy and not pinned:
1680
+ # reference numpy memory directly
1681
+ self._init_from_ptr(arr.ctypes.data, dtype, shape, strides, None, device, False, False)
1682
+ # keep a ref to the source array to keep allocation alive
1683
+ self._ref = arr
1684
+ else:
1685
+ # copy data into a new array
1686
+ self._init_new(dtype, shape, None, device, pinned)
1687
+ src = array(
1688
+ ptr=arr.ctypes.data,
1689
+ dtype=dtype,
1690
+ shape=shape,
1691
+ strides=strides,
1692
+ device="cpu",
1693
+ copy=False,
1694
+ owner=False,
1695
+ )
1696
+ warp.copy(self, src)
1233
1697
 
1234
- if dtype._shape_ != trailing_shape:
1235
- raise RuntimeError(
1236
- f"Last dimensions of input array should match the specified data type, given shape {arr.shape}, expected last dimensions to match dtype shape {dtype._shape_}"
1237
- )
1698
def _init_from_ptr(self, ptr, dtype, shape, strides, capacity, device, owner, pinned):
    """Initialize this array as a view of existing memory at ``ptr``.

    No memory is allocated here. When ``strides`` is None the strides
    returned by ``strides_from_shape`` are used; when ``capacity`` is None
    it is derived from the shape and strides.

    Raises:
        RuntimeError: If ``dtype`` is ``Any``.
    """
    if dtype == Any:
        raise RuntimeError("A concrete data type is required to create the array")

    device = warp.get_device(device)

    # total number of elements
    element_count = 1
    for extent in shape:
        element_count *= extent

    packed_strides = strides_from_shape(shape, dtype)

    use_packed_layout = strides is None
    if use_packed_layout:
        strides = packed_strides
        is_contiguous = True
    else:
        is_contiguous = strides == packed_strides

    if capacity is None:
        # no explicit capacity: derive it from the layout
        if use_packed_layout:
            capacity = element_count * type_size_in_bytes(dtype)
        else:
            capacity = shape[0] * strides[0]

    self.dtype = dtype
    self.ndim = len(shape)
    self.size = element_count
    self.capacity = capacity
    self.shape = shape
    self.strides = strides
    self.ptr = ptr
    self.device = device
    self.owner = owner
    # the pinned flag is only kept for CPU devices
    self.pinned = pinned if device.is_cpu else False
    self.is_contiguous = is_contiguous
1258
1731
 
1259
- else:
1260
- # otherwise, we must transfer to device memory
1261
- # create a host wrapper around the numpy array
1262
- # and a new destination array to copy it to
1263
- src = array(
1264
- dtype=dtype,
1265
- shape=shape,
1266
- strides=strides,
1267
- capacity=arr.size * type_size_in_bytes(dtype),
1268
- ptr=ptr,
1269
- device="cpu",
1270
- copy=False,
1271
- owner=False,
1272
- )
1273
- dest = warp.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad, pinned=pinned)
1274
- dest.owner = False
1732
def _init_new(self, dtype, shape, strides, device, pinned):
    """Initialize this array with newly allocated memory.

    Memory is requested from ``device.allocator`` when the computed capacity
    is positive; otherwise ``ptr`` is left as None. The array is marked as
    owning its memory.

    Raises:
        RuntimeError: If ``dtype`` is ``Any`` or the allocator returns None.
    """
    if dtype == Any:
        raise RuntimeError("A concrete data type is required to create the array")

    device = warp.get_device(device)

    # total number of elements
    element_count = 1
    for extent in shape:
        element_count *= extent

    packed_strides = strides_from_shape(shape, dtype)

    if strides is not None:
        # caller-provided layout
        is_contiguous = strides == packed_strides
        capacity = shape[0] * strides[0]
    else:
        strides = packed_strides
        is_contiguous = True
        capacity = element_count * type_size_in_bytes(dtype)

    ptr = None
    if capacity > 0:
        ptr = device.allocator.alloc(capacity, pinned=pinned)
        if ptr is None:
            raise RuntimeError(f"Array allocation failed on device: {device} for {capacity} bytes")

    self.dtype = dtype
    self.ndim = len(shape)
    self.size = element_count
    self.capacity = capacity
    self.shape = shape
    self.strides = strides
    self.ptr = ptr
    self.device = device
    self.owner = True
    # the pinned flag is only kept for CPU devices
    self.pinned = pinned if device.is_cpu else False
    self.is_contiguous = is_contiguous
1770
+
1771
+ def _init_annotation(self, dtype, ndim):
1772
+ self.dtype = dtype
1773
+ self.ndim = ndim
1774
+ self.size = 0
1775
+ self.capacity = 0
1776
+ self.shape = (0,) * ndim
1777
+ self.strides = (0,) * ndim
1778
+ self.ptr = None
1779
+ self.device = None
1780
+ self.owner = False
1781
+ self.pinned = False
1782
+ self.is_contiguous = False
1312
1783
 
1313
- self._grad = None
1784
+ @property
1785
+ def __array_interface__(self):
1786
+ # raising an AttributeError here makes hasattr() return False
1787
+ if self.device is None or not self.device.is_cpu:
1788
+ raise AttributeError(f"__array_interface__ not supported because device is {self.device}")
1314
1789
 
1315
- # set up array interface access so we can treat this object as a numpy array
1316
- if self.ptr:
1317
- # update byte strides and contiguous flag
1318
- contiguous_strides = strides_from_shape(self.shape, self.dtype)
1319
- if strides is None:
1320
- self.strides = contiguous_strides
1321
- self.is_contiguous = True
1790
+ if self._array_interface is None:
1791
+ # get flat shape (including type shape)
1792
+ if isinstance(self.dtype, warp.codegen.Struct):
1793
+ # struct
1794
+ arr_shape = self.shape
1795
+ arr_strides = self.strides
1796
+ descr = self.dtype.numpy_dtype()
1797
+ elif issubclass(self.dtype, ctypes.Array):
1798
+ # vector type, flatten the dimensions into one tuple
1799
+ arr_shape = (*self.shape, *self.dtype._shape_)
1800
+ dtype_strides = strides_from_shape(self.dtype._shape_, self.dtype._type_)
1801
+ arr_strides = (*self.strides, *dtype_strides)
1802
+ descr = None
1322
1803
  else:
1323
- self.strides = strides
1324
- self.is_contiguous = strides[:ndim] == contiguous_strides[:ndim]
1804
+ # scalar type
1805
+ arr_shape = self.shape
1806
+ arr_strides = self.strides
1807
+ descr = None
1808
+
1809
+ self._array_interface = {
1810
+ "data": (self.ptr if self.ptr is not None else 0, False),
1811
+ "shape": tuple(arr_shape),
1812
+ "strides": tuple(arr_strides),
1813
+ "typestr": type_typestr(self.dtype),
1814
+ "descr": descr, # optional description of structured array layout
1815
+ "version": 3,
1816
+ }
1325
1817
 
1326
- # store flat shape (including type shape)
1818
+ return self._array_interface
1327
1819
 
1328
- if isinstance(dtype, type) and issubclass(dtype, ctypes.Array):
1820
+ @property
1821
+ def __cuda_array_interface__(self):
1822
+ # raising an AttributeError here makes hasattr() return False
1823
+ if self.device is None or not self.device.is_cuda:
1824
+ raise AttributeError(f"__cuda_array_interface__ is not supported because device is {self.device}")
1825
+
1826
+ if self._array_interface is None:
1827
+ # get flat shape (including type shape)
1828
+ if issubclass(self.dtype, ctypes.Array):
1329
1829
  # vector type, flatten the dimensions into one tuple
1330
1830
  arr_shape = (*self.shape, *self.dtype._shape_)
1331
1831
  dtype_strides = strides_from_shape(self.dtype._shape_, self.dtype._type_)
@@ -1335,44 +1835,18 @@ class array(Array):
1335
1835
  arr_shape = self.shape
1336
1836
  arr_strides = self.strides
1337
1837
 
1338
- if device.is_cpu:
1339
- self.__array_interface__ = {
1340
- "data": (self.ptr, False),
1341
- "shape": tuple(arr_shape),
1342
- "strides": tuple(arr_strides),
1343
- "typestr": type_typestr(self.dtype),
1344
- "version": 3,
1345
- }
1346
-
1347
- # set up cuda array interface access so we can treat this object as a Torch tensor
1348
- elif device.is_cuda:
1349
- self.__cuda_array_interface__ = {
1350
- "data": (self.ptr, False),
1351
- "shape": tuple(arr_shape),
1352
- "strides": tuple(arr_strides),
1353
- "typestr": type_typestr(self.dtype),
1354
- "version": 2,
1355
- }
1356
-
1357
- # controls if gradients will be computed by wp.Tape
1358
- # this will trigger allocation of a gradient array if it doesn't exist already
1359
- self.requires_grad = requires_grad
1360
-
1361
- else:
1362
- # array has no data
1363
- self.strides = (0,) * self.ndim
1364
- self.is_contiguous = False
1365
- self.requires_grad = False
1838
+ self._array_interface = {
1839
+ "data": (self.ptr if self.ptr is not None else 0, False),
1840
+ "shape": tuple(arr_shape),
1841
+ "strides": tuple(arr_strides),
1842
+ "typestr": type_typestr(self.dtype),
1843
+ "version": 2,
1844
+ }
1366
1845
 
1367
- self.ctype = None
1846
+ return self._array_interface
1368
1847
 
1369
1848
  def __del__(self):
1370
- if self.owner and self.device is not None and self.ptr is not None:
1371
- # TODO: ill-timed gc could trigger superfluous context switches here
1372
- # Delegate to a separate thread? (e.g., device_free_async)
1373
- if self.device.is_capturing:
1374
- raise RuntimeError(f"Cannot free memory on device {self.device} while graph capture is active")
1375
-
1849
+ if self.owner:
1376
1850
  # use CUDA context guard to avoid side effects during garbage collection
1377
1851
  with self.device.context_guard:
1378
1852
  self.device.allocator.free(self.ptr, self.capacity, self.pinned)
@@ -1385,7 +1859,7 @@ class array(Array):
1385
1859
  # for 'empty' arrays we just return the type information, these are used in kernel function signatures
1386
1860
  return f"array{self.dtype}"
1387
1861
  else:
1388
- return str(self.to("cpu").numpy())
1862
+ return str(self.numpy())
1389
1863
 
1390
1864
  def __getitem__(self, key):
1391
1865
  if isinstance(key, int):
@@ -1436,7 +1910,7 @@ class array(Array):
1436
1910
  if stop < 0:
1437
1911
  stop = self.shape[idx] + stop
1438
1912
 
1439
- if start < 0 or start > self.shape[idx] - 1:
1913
+ if start < 0 or start >= self.shape[idx]:
1440
1914
  raise RuntimeError(f"Invalid indexing in slice: {start}:{stop}:{step}")
1441
1915
  if stop < 1 or stop > self.shape[idx]:
1442
1916
  raise RuntimeError(f"Invalid indexing in slice: {start}:{stop}:{step}")
@@ -1460,23 +1934,37 @@ class array(Array):
1460
1934
  start = k
1461
1935
  if start < 0:
1462
1936
  start = self.shape[idx] + start
1463
- if start < 0 or start > self.shape[idx] - 1:
1937
+ if start < 0 or start >= self.shape[idx]:
1464
1938
  raise RuntimeError(f"Invalid indexing in slice: {k}")
1465
1939
  new_dim -= 1
1466
1940
 
1467
1941
  ptr_offset += self.strides[idx] * start
1468
1942
 
1943
+ # handle grad
1944
+ if self.grad is not None:
1945
+ new_grad = array(
1946
+ ptr=self.grad.ptr + ptr_offset if self.grad.ptr is not None else None,
1947
+ dtype=self.grad.dtype,
1948
+ shape=tuple(new_shape),
1949
+ strides=tuple(new_strides),
1950
+ device=self.grad.device,
1951
+ pinned=self.grad.pinned,
1952
+ owner=False,
1953
+ )
1954
+ # store back-ref to stop data being destroyed
1955
+ new_grad._ref = self.grad
1956
+ else:
1957
+ new_grad = None
1958
+
1469
1959
  a = array(
1960
+ ptr=self.ptr + ptr_offset if self.ptr is not None else None,
1470
1961
  dtype=self.dtype,
1471
1962
  shape=tuple(new_shape),
1472
1963
  strides=tuple(new_strides),
1473
- ptr=self.ptr + ptr_offset,
1474
- grad_ptr=(self.grad_ptr + ptr_offset if self.grad_ptr is not None else None),
1475
- capacity=self.capacity,
1476
1964
  device=self.device,
1965
+ pinned=self.pinned,
1477
1966
  owner=False,
1478
- ndim=new_dim,
1479
- requires_grad=self.requires_grad,
1967
+ grad=new_grad,
1480
1968
  )
1481
1969
 
1482
1970
  # store back-ref to stop data being destroyed
@@ -1494,7 +1982,7 @@ class array(Array):
1494
1982
  def __ctype__(self):
1495
1983
  if self.ctype is None:
1496
1984
  data = 0 if self.ptr is None else ctypes.c_uint64(self.ptr)
1497
- grad = 0 if self.grad_ptr is None else ctypes.c_uint64(self.grad_ptr)
1985
+ grad = 0 if self.grad is None or self.grad.ptr is None else ctypes.c_uint64(self.grad.ptr)
1498
1986
  self.ctype = array_t(data=data, grad=grad, ndim=self.ndim, shape=self.shape, strides=self.strides)
1499
1987
 
1500
1988
  return self.ctype
@@ -1522,25 +2010,31 @@ class array(Array):
1522
2010
  return self._grad
1523
2011
 
1524
2012
  @grad.setter
1525
- def grad(self, value):
1526
- # trigger re-creation of C-representation
1527
- self.ctype = None
1528
- if value is None:
1529
- self.grad_ptr = None
2013
+ def grad(self, grad):
2014
+ if grad is None:
1530
2015
  self._grad = None
1531
- return
1532
- if self._grad is None:
1533
- self.grad_ptr = value.ptr
1534
- self._grad = value
2016
+ self._requires_grad = False
1535
2017
  else:
1536
- self._grad.assign(value)
2018
+ # make sure the given gradient array is compatible
2019
+ if (
2020
+ grad.dtype != self.dtype
2021
+ or grad.shape != self.shape
2022
+ or grad.strides != self.strides
2023
+ or grad.device != self.device
2024
+ ):
2025
+ raise ValueError("The given gradient array is incompatible")
2026
+ self._grad = grad
2027
+ self._requires_grad = True
2028
+
2029
+ # trigger re-creation of C-representation
2030
+ self.ctype = None
1537
2031
 
1538
2032
  @property
1539
2033
  def requires_grad(self):
1540
2034
  return self._requires_grad
1541
2035
 
1542
2036
  @requires_grad.setter
1543
- def requires_grad(self, value: bool):
2037
+ def requires_grad(self, value: builtins.bool):
1544
2038
  if value and self._grad is None:
1545
2039
  self._alloc_grad()
1546
2040
  elif not value:
@@ -1548,18 +2042,15 @@ class array(Array):
1548
2042
 
1549
2043
  self._requires_grad = value
1550
2044
 
1551
- def _alloc_grad(self):
1552
- if self.grad_ptr is None:
1553
- num_bytes = self.size * type_size_in_bytes(self.dtype)
1554
- self.grad_ptr = self.device.allocator.alloc(num_bytes, pinned=self.pinned)
1555
- if self.grad_ptr is None:
1556
- raise RuntimeError("Memory allocation failed on device: {} for {} bytes".format(self.device, num_bytes))
1557
- with warp.ScopedStream(self.device.null_stream):
1558
- self.device.memset(self.grad_ptr, 0, num_bytes)
2045
+ # trigger re-creation of C-representation
2046
+ self.ctype = None
1559
2047
 
2048
+ def _alloc_grad(self):
1560
2049
  self._grad = array(
1561
- ptr=self.grad_ptr, shape=self.shape, dtype=self.dtype, device=self.device, requires_grad=False, owner=False
2050
+ dtype=self.dtype, shape=self.shape, strides=self.strides, device=self.device, pinned=self.pinned
1562
2051
  )
2052
+ self._grad.zero_()
2053
+
1563
2054
  # trigger re-creation of C-representation
1564
2055
  self.ctype = None
1565
2056
 
@@ -1568,171 +2059,195 @@ class array(Array):
1568
2059
  # member attributes available during code-gen (e.g.: d = array.shape[0])
1569
2060
  # Note: we use a shared dict for all array instances
1570
2061
  if array._vars is None:
1571
- from warp.codegen import Var
1572
-
1573
- array._vars = {"shape": Var("shape", shape_t)}
2062
+ array._vars = {"shape": warp.codegen.Var("shape", shape_t)}
1574
2063
  return array._vars
1575
2064
 
1576
2065
  def zero_(self):
1577
- if not self.is_contiguous:
1578
- raise RuntimeError("Assigning to non-contiguous arrays is unsupported.")
1579
-
1580
- if self.device is not None and self.ptr is not None:
1581
- self.device.memset(
1582
- ctypes.c_void_p(self.ptr), ctypes.c_int(0), ctypes.c_size_t(self.size * type_size_in_bytes(self.dtype))
1583
- )
2066
+ """Zeroes-out the array entries."""
2067
+ if self.is_contiguous:
2068
+ # simple memset is usually faster than generic fill
2069
+ self.device.memset(self.ptr, 0, self.size * type_size_in_bytes(self.dtype))
2070
+ else:
2071
+ self.fill_(0)
1584
2072
 
1585
2073
  def fill_(self, value):
1586
- if not self.is_contiguous:
1587
- raise RuntimeError("Assigning to non-contiguous arrays is unsupported.")
1588
-
1589
- if self.device is not None and self.ptr is not None:
1590
- if isinstance(value, ctypes.Array):
1591
- # in this case we're filling the array with a vector or
1592
- # something similar, eg arr.fill_(wp.vec3(1.0,2.0,3.0)).
1593
-
1594
- # check input type:
1595
- value_type_ok = False
1596
- if issubclass(self.dtype, ctypes.Array):
1597
- value_type_ok = (self.dtype._length_ == value._length_) and (self.dtype._type_ == value._type_)
1598
- if not value_type_ok:
1599
- raise RuntimeError(
1600
- "wp.array has Array type elements (eg vec, mat etc). Value type must match element type in wp.array.fill_() method"
1601
- )
1602
-
1603
- src = ctypes.cast(value, ctypes.POINTER(ctypes.c_void_p))
1604
-
1605
- srcsize = value._length_ * ctypes.sizeof(value._type_)
1606
- dst = ctypes.cast(self.ptr, ctypes.POINTER(ctypes.c_int))
1607
- self.device.memtile(dst, src, srcsize, self.size)
2074
+ """Set all array entries to `value`
2075
+
2076
+ args:
2077
+ value: The value to set every array entry to. Must be convertible to the array's ``dtype``.
2078
+
2079
+ Raises:
2080
+ ValueError: If `value` cannot be converted to the array's ``dtype``.
2081
+
2082
+ Examples:
2083
+ ``fill_()`` can take lists or other sequences when filling arrays of vectors or matrices.
2084
+
2085
+ >>> arr = wp.zeros(2, dtype=wp.mat22)
2086
+ >>> arr.numpy()
2087
+ array([[[0., 0.],
2088
+ [0., 0.]],
2089
+ <BLANKLINE>
2090
+ [[0., 0.],
2091
+ [0., 0.]]], dtype=float32)
2092
+ >>> arr.fill_([[1, 2], [3, 4]])
2093
+ >>> arr.numpy()
2094
+ array([[[1., 2.],
2095
+ [3., 4.]],
2096
+ <BLANKLINE>
2097
+ [[1., 2.],
2098
+ [3., 4.]]], dtype=float32)
2099
+ """
2100
+ if self.size == 0:
2101
+ return
1608
2102
 
1609
- else:
1610
- # In this case we're just filling the array with a scalar,
1611
- # eg arr.fill_(1.0). If the elements are scalars, we need to
1612
- # set them all to "value", otherwise we need to set all the
1613
- # components of all the vector elements to "value":
1614
-
1615
- # work out array element type:
1616
- elem_type = self.dtype._type_ if issubclass(self.dtype, ctypes.Array) else type_ctype(self.dtype)
1617
- elem_size = ctypes.sizeof(elem_type)
1618
-
1619
- # convert value to array type
1620
- # we need a special case for float16 because it's annoying...
1621
- if types_equal(self.dtype, float16) or (
1622
- hasattr(self.dtype, "_wp_scalar_type_") and types_equal(self.dtype._wp_scalar_type_, float16)
1623
- ):
1624
- # special case for float16:
1625
- # If you just do elem_type(value), it'll just convert "value"
1626
- # to uint16 then interpret the bits as float16, which will
1627
- # mess the data up. Instead, we use float_to_half_bits() to
1628
- # convert "value" to a float16 and return its bits in a uint16:
1629
-
1630
- from warp.context import runtime
1631
-
1632
- src_value = elem_type(runtime.core.float_to_half_bits(ctypes.c_float(value)))
2103
+ # try to convert the given value to the array dtype
2104
+ try:
2105
+ if isinstance(self.dtype, warp.codegen.Struct):
2106
+ if isinstance(value, self.dtype.cls):
2107
+ cvalue = value.__ctype__()
2108
+ elif value == 0:
2109
+ # allow zero-initializing structs using default constructor
2110
+ cvalue = self.dtype().__ctype__()
1633
2111
  else:
1634
- src_value = elem_type(value)
1635
-
1636
- # use memset for these special cases because it's quicker (probably...):
1637
- total_bytes = self.size * type_size_in_bytes(self.dtype)
1638
- if elem_size in [1, 2, 4] and (total_bytes % 4 == 0):
1639
- # interpret as a 4 byte integer:
1640
- dest_value = ctypes.cast(ctypes.pointer(src_value), ctypes.POINTER(ctypes.c_int)).contents
1641
- if elem_size == 1:
1642
- # need to repeat the bits, otherwise we'll get an array interleaved with zeros:
1643
- dest_value.value = dest_value.value & 0x000000FF
1644
- dest_value.value = (
1645
- dest_value.value
1646
- + (dest_value.value << 8)
1647
- + (dest_value.value << 16)
1648
- + (dest_value.value << 24)
1649
- )
1650
- elif elem_size == 2:
1651
- # need to repeat the bits, otherwise we'll get an array interleaved with zeros:
1652
- dest_value.value = dest_value.value & 0x0000FFFF
1653
- dest_value.value = dest_value.value + (dest_value.value << 16)
1654
-
1655
- self.device.memset(
1656
- ctypes.cast(self.ptr, ctypes.POINTER(ctypes.c_int)), dest_value, ctypes.c_size_t(total_bytes)
2112
+ raise ValueError(
2113
+ f"Invalid initializer value for struct {self.dtype.cls.__name__}, expected struct instance or 0"
1657
2114
  )
2115
+ elif issubclass(self.dtype, ctypes.Array):
2116
+ # vector/matrix
2117
+ cvalue = self.dtype(value)
2118
+ else:
2119
+ # scalar
2120
+ if type(value) in warp.types.scalar_types:
2121
+ value = value.value
2122
+ if self.dtype == float16:
2123
+ cvalue = self.dtype._type_(float_to_half_bits(value))
1658
2124
  else:
1659
- num_elems = self.size * self.dtype._length_ if issubclass(self.dtype, ctypes.Array) else self.size
1660
- dst = ctypes.cast(self.ptr, ctypes.POINTER(ctypes.c_int))
1661
- self.device.memtile(dst, ctypes.pointer(src_value), elem_size, num_elems)
2125
+ cvalue = self.dtype._type_(value)
2126
+ except Exception as e:
2127
+ raise ValueError(f"Failed to convert the value to the array data type: {e}")
2128
+
2129
+ cvalue_ptr = ctypes.pointer(cvalue)
2130
+ cvalue_size = ctypes.sizeof(cvalue)
2131
+
2132
+ # prefer using memtile for contiguous arrays, because it should be faster than generic fill
2133
+ if self.is_contiguous:
2134
+ self.device.memtile(self.ptr, cvalue_ptr, cvalue_size, self.size)
2135
+ else:
2136
+ carr = self.__ctype__()
2137
+ carr_ptr = ctypes.pointer(carr)
2138
+
2139
+ if self.device.is_cuda:
2140
+ warp.context.runtime.core.array_fill_device(
2141
+ self.device.context, carr_ptr, ARRAY_TYPE_REGULAR, cvalue_ptr, cvalue_size
2142
+ )
2143
+ else:
2144
+ warp.context.runtime.core.array_fill_host(carr_ptr, ARRAY_TYPE_REGULAR, cvalue_ptr, cvalue_size)
1662
2145
 
1663
- # equivalent to wrapping src data in an array and copying to self
1664
2146
  def assign(self, src):
1665
- if isinstance(src, array):
2147
+ """Wraps ``src`` in an :class:`warp.array` if it is not already one and copies the contents to ``self``."""
2148
+ if is_array(src):
1666
2149
  warp.copy(self, src)
1667
2150
  else:
1668
- warp.copy(self, array(src, dtype=self.dtype, copy=False, device="cpu"))
2151
+ warp.copy(self, array(data=src, dtype=self.dtype, copy=False, device="cpu"))
1669
2152
 
1670
- # convert array to ndarray (alias memory through array interface)
1671
2153
  def numpy(self):
1672
- # use the CUDA default stream for synchronous behaviour with other streams
1673
- with warp.ScopedStream(self.device.null_stream):
1674
- if self.ptr is None:
1675
- return np.empty(shape=self.shape, dtype=self.dtype)
2154
+ """Converts the array to a :class:`numpy.ndarray` (aliasing memory through the array interface protocol)
2155
+ If the array is on the GPU, a synchronous device-to-host copy (on the CUDA default stream) will be
2156
+ automatically performed to ensure that any outstanding work is completed.
2157
+ """
2158
+ if self.ptr:
2159
+ # use the CUDA default stream for synchronous behaviour with other streams
2160
+ with warp.ScopedStream(self.device.null_stream):
2161
+ a = self.to("cpu", requires_grad=False)
2162
+ # convert through __array_interface__
2163
+ # Note: this handles arrays of structs using `descr`, so the result will be a structured NumPy array
2164
+ return np.array(a, copy=False)
2165
+ else:
2166
+ # return an empty numpy array with the correct dtype and shape
2167
+ if isinstance(self.dtype, warp.codegen.Struct):
2168
+ npdtype = self.dtype.numpy_dtype()
2169
+ npshape = self.shape
2170
+ elif issubclass(self.dtype, ctypes.Array):
2171
+ npdtype = warp_type_to_np_dtype[self.dtype._wp_scalar_type_]
2172
+ npshape = (*self.shape, *self.dtype._shape_)
1676
2173
  else:
1677
- a = self.to("cpu")
2174
+ npdtype = warp_type_to_np_dtype[self.dtype]
2175
+ npshape = self.shape
2176
+ return np.empty(npshape, dtype=npdtype)
1678
2177
 
1679
- if isinstance(self.dtype, warp.codegen.Struct):
1680
- # Note: cptr holds a backref to the source array to avoid it being deallocated
1681
- p = a.cptr()
1682
- return np.ctypeslib.as_array(p, self.shape)
1683
- else:
1684
- # convert through array interface
1685
- return np.array(a, copy=False)
1686
-
1687
- # return a ctypes cast of the array address
1688
- # note that accesses to this object are *not* bounds checked
1689
2178
  def cptr(self):
1690
- if self.device != "cpu":
1691
- raise RuntimeError("Accessing array memory through a ctypes ptr is only supported for CPU arrays.")
1692
-
1693
- p = ctypes.cast(self.ptr, ctypes.POINTER(self.dtype.ctype))
2179
+ """Return a ctypes cast of the array address.
2180
+
2181
+ Notes:
2182
+
2183
+ #. Only CPU arrays support this method.
2184
+ #. The array must be contiguous.
2185
+ #. Accesses to this object are **not** bounds checked.
2186
+ #. For ``float16`` types, a pointer to the internal ``uint16`` representation is returned.
2187
+ """
2188
+ if not self.ptr:
2189
+ return None
2190
+
2191
+ if self.device != "cpu" or not self.is_contiguous:
2192
+ raise RuntimeError(
2193
+ "Accessing array memory through a ctypes ptr is only supported for contiguous CPU arrays."
2194
+ )
2195
+
2196
+ if isinstance(self.dtype, warp.codegen.Struct):
2197
+ p = ctypes.cast(self.ptr, ctypes.POINTER(self.dtype.ctype))
2198
+ else:
2199
+ p = ctypes.cast(self.ptr, ctypes.POINTER(self.dtype._type_))
1694
2200
 
1695
2201
  # store backref to the underlying array to avoid it being deallocated
1696
2202
  p._ref = self
1697
2203
 
1698
2204
  return p
1699
2205
 
1700
- # returns a flattened list of items in the array as a Python list
1701
2206
  def list(self):
1702
- a = self.to("cpu").flatten()
1703
-
1704
- # Note: cptr holds a backref to the source array to avoid it being deallocated
1705
- p = a.cptr()
2207
+ """Returns a flattened list of items in the array as a Python list."""
2208
+ a = self.numpy()
2209
+
2210
+ if isinstance(self.dtype, warp.codegen.Struct):
2211
+ # struct
2212
+ a = a.flatten()
2213
+ data = a.ctypes.data
2214
+ stride = a.strides[0]
2215
+ return [self.dtype.from_ptr(data + i * stride) for i in range(self.size)]
2216
+ elif issubclass(self.dtype, ctypes.Array):
2217
+ # vector/matrix - flatten, but preserve inner vector/matrix dimensions
2218
+ a = a.reshape((self.size, *self.dtype._shape_))
2219
+ data = a.ctypes.data
2220
+ stride = a.strides[0]
2221
+ return [self.dtype.from_ptr(data + i * stride) for i in range(self.size)]
2222
+ else:
2223
+ # scalar
2224
+ return list(a.flatten())
1706
2225
 
1707
- return p[:a.size]
1708
-
1709
- # convert data from one device to another, nop if already on device
1710
- def to(self, device):
2226
+ def to(self, device, requires_grad=None):
2227
+ """Returns a Warp array with this array's data moved to the specified device, no-op if already on device."""
1711
2228
  device = warp.get_device(device)
1712
2229
  if self.device == device:
1713
2230
  return self
1714
2231
  else:
1715
- dest = warp.empty(shape=self.shape, dtype=self.dtype, device=device, requires_grad=self.requires_grad)
1716
- # to copy between devices, array must be contiguous
1717
- warp.copy(dest, self.contiguous())
1718
- return dest
2232
+ return warp.clone(self, device=device, requires_grad=requires_grad)
1719
2233
 
1720
2234
  def flatten(self):
2235
+ """Returns a zero-copy view of the array collapsed to 1-D. Only supported for contiguous arrays."""
2236
+ if self.ndim == 1:
2237
+ return self
2238
+
1721
2239
  if not self.is_contiguous:
1722
2240
  raise RuntimeError("Flattening non-contiguous arrays is unsupported.")
1723
2241
 
1724
2242
  a = array(
2243
+ ptr=self.ptr,
1725
2244
  dtype=self.dtype,
1726
2245
  shape=(self.size,),
1727
- strides=(type_size_in_bytes(self.dtype),),
1728
- ptr=self.ptr,
1729
- grad_ptr=self.grad_ptr,
1730
- capacity=self.capacity,
1731
2246
  device=self.device,
2247
+ pinned=self.pinned,
1732
2248
  copy=False,
1733
2249
  owner=False,
1734
- ndim=1,
1735
- requires_grad=self.requires_grad,
2250
+ grad=None if self.grad is None else self.grad.flatten(),
1736
2251
  )
1737
2252
 
1738
2253
  # store back-ref to stop data being destroyed
@@ -1740,6 +2255,11 @@ class array(Array):
1740
2255
  return a
1741
2256
 
1742
2257
  def reshape(self, shape):
2258
+ """Returns a reshaped array. Only supported for contiguous arrays.
2259
+
2260
+ Args:
2261
+ shape : An int or tuple of ints specifying the shape of the returned array.
2262
+ """
1743
2263
  if not self.is_contiguous:
1744
2264
  raise RuntimeError("Reshaping non-contiguous arrays is unsupported.")
1745
2265
 
@@ -1748,7 +2268,7 @@ class array(Array):
1748
2268
  raise RuntimeError("shape parameter is required.")
1749
2269
  if isinstance(shape, int):
1750
2270
  shape = (shape,)
1751
- elif isinstance(shape, List):
2271
+ elif not isinstance(shape, tuple):
1752
2272
  shape = tuple(shape)
1753
2273
 
1754
2274
  if len(shape) > ARRAY_MAX_DIMS:
@@ -1756,6 +2276,23 @@ class array(Array):
1756
2276
  f"Arrays may only have {ARRAY_MAX_DIMS} dimensions maximum, trying to create array with {len(shape)} dims."
1757
2277
  )
1758
2278
 
2279
+ # check for -1 dimension and reformat
2280
+ if -1 in shape:
2281
+ idx = self.size
2282
+ denom = 1
2283
+ minus_one_count = 0
2284
+ for i, d in enumerate(shape):
2285
+ if d == -1:
2286
+ idx = i
2287
+ minus_one_count += 1
2288
+ else:
2289
+ denom *= d
2290
+ if minus_one_count > 1:
2291
+ raise RuntimeError("Cannot infer shape if more than one index is -1.")
2292
+ new_shape = list(shape)
2293
+ new_shape[idx] = int(self.size / denom)
2294
+ shape = tuple(new_shape)
2295
+
1759
2296
  size = 1
1760
2297
  for d in shape:
1761
2298
  size *= d
@@ -1764,17 +2301,15 @@ class array(Array):
1764
2301
  raise RuntimeError("Reshaped array must have the same total size as the original.")
1765
2302
 
1766
2303
  a = array(
2304
+ ptr=self.ptr,
1767
2305
  dtype=self.dtype,
1768
2306
  shape=shape,
1769
2307
  strides=None,
1770
- ptr=self.ptr,
1771
- grad_ptr=self.grad_ptr,
1772
- capacity=self.capacity,
1773
2308
  device=self.device,
2309
+ pinned=self.pinned,
1774
2310
  copy=False,
1775
2311
  owner=False,
1776
- ndim=len(shape),
1777
- requires_grad=self.requires_grad,
2312
+ grad=None if self.grad is None else self.grad.reshape(shape),
1778
2313
  )
1779
2314
 
1780
2315
  # store back-ref to stop data being destroyed
@@ -1782,49 +2317,55 @@ class array(Array):
1782
2317
  return a
1783
2318
 
1784
2319
  def view(self, dtype):
2320
+ """Returns a zero-copy view of this array's memory with a different data type.
2321
+ ``dtype`` must have the same byte size of the array's native ``dtype``.
2322
+ """
1785
2323
  if type_size_in_bytes(dtype) != type_size_in_bytes(self.dtype):
1786
- raise RuntimeError("cannot reinterpret cast dtypes of unequal byte size")
1787
- else:
1788
- # return an alias of the array memory with different type information
1789
- a = array(
1790
- data=None,
1791
- dtype=dtype,
1792
- shape=self.shape,
1793
- strides=self.strides,
1794
- ptr=self.ptr,
1795
- grad_ptr=self.grad_ptr,
1796
- capacity=self.capacity,
1797
- device=self.device,
1798
- copy=False,
1799
- owner=False,
1800
- ndim=self.ndim,
1801
- requires_grad=self.requires_grad,
1802
- )
2324
+ raise RuntimeError("Cannot cast dtypes of unequal byte size")
1803
2325
 
1804
- a._ref = self
1805
- return a
2326
+ # return an alias of the array memory with different type information
2327
+ a = array(
2328
+ ptr=self.ptr,
2329
+ dtype=dtype,
2330
+ shape=self.shape,
2331
+ strides=self.strides,
2332
+ device=self.device,
2333
+ pinned=self.pinned,
2334
+ copy=False,
2335
+ owner=False,
2336
+ grad=None if self.grad is None else self.grad.view(dtype),
2337
+ )
2338
+
2339
+ a._ref = self
2340
+ return a
1806
2341
 
1807
2342
  def contiguous(self):
2343
+ """Returns a contiguous array with this array's data. No-op if array is already contiguous."""
1808
2344
  if self.is_contiguous:
1809
2345
  return self
1810
2346
 
1811
2347
  a = warp.empty_like(self)
1812
2348
  warp.copy(a, self)
1813
-
1814
2349
  return a
1815
2350
 
1816
- # note: transpose operation will return an array with a non-contiguous access pattern
1817
2351
  def transpose(self, axes=None):
2352
+ """Returns an zero-copy view of the array with axes transposed.
2353
+
2354
+ Note: The transpose operation will return an array with a non-contiguous access pattern.
2355
+
2356
+ Args:
2357
+ axes (optional): Specifies the how the axes are permuted. If not specified, the axes order will be reversed.
2358
+ """
1818
2359
  # noop if 1d array
1819
- if len(self.shape) == 1:
2360
+ if self.ndim == 1:
1820
2361
  return self
1821
2362
 
1822
2363
  if axes is None:
1823
2364
  # reverse the order of the axes
1824
2365
  axes = range(self.ndim)[::-1]
1825
-
1826
- if len(axes) != len(self.shape):
2366
+ elif len(axes) != len(self.shape):
1827
2367
  raise RuntimeError("Length of parameter axes must be equal in length to array shape")
2368
+
1828
2369
  shape = []
1829
2370
  strides = []
1830
2371
  for a in axes:
@@ -1836,20 +2377,19 @@ class array(Array):
1836
2377
  strides.append(self.strides[a])
1837
2378
 
1838
2379
  a = array(
1839
- data=None,
2380
+ ptr=self.ptr,
1840
2381
  dtype=self.dtype,
1841
2382
  shape=tuple(shape),
1842
2383
  strides=tuple(strides),
1843
- ptr=self.ptr,
1844
- grad_ptr=self.grad_ptr,
1845
- capacity=self.capacity,
1846
2384
  device=self.device,
2385
+ pinned=self.pinned,
1847
2386
  copy=False,
1848
2387
  owner=False,
1849
- ndim=self.ndim,
1850
- requires_grad=self.requires_grad,
2388
+ grad=None if self.grad is None else self.grad.transpose(axes=axes),
1851
2389
  )
1852
2390
 
2391
+ a.is_transposed = not self.is_transposed
2392
+
1853
2393
  a._ref = self
1854
2394
  return a
1855
2395
 
@@ -1878,12 +2418,13 @@ def array4d(*args, **kwargs):
1878
2418
  return array(*args, **kwargs)
1879
2419
 
1880
2420
 
2421
+ # TODO: Rewrite so that we take only shape, not length and optional shape
1881
2422
  def from_ptr(ptr, length, dtype=None, shape=None, device=None):
1882
2423
  return array(
1883
2424
  dtype=dtype,
1884
2425
  length=length,
1885
2426
  capacity=length * type_size_in_bytes(dtype),
1886
- ptr=ctypes.cast(ptr, ctypes.POINTER(ctypes.c_size_t)).contents.value,
2427
+ ptr=0 if ptr == 0 else ctypes.cast(ptr, ctypes.POINTER(ctypes.c_size_t)).contents.value,
1887
2428
  shape=shape,
1888
2429
  device=device,
1889
2430
  owner=False,
@@ -1891,12 +2432,113 @@ def from_ptr(ptr, length, dtype=None, shape=None, device=None):
1891
2432
  )
1892
2433
 
1893
2434
 
1894
- class indexedarray(Generic[T]):
2435
+ # A base class for non-contiguous arrays, providing the implementation of common methods like
2436
+ # contiguous(), to(), numpy(), list(), assign(), zero_(), and fill_().
2437
+ class noncontiguous_array_base(Generic[T]):
2438
+ def __init__(self, array_type_id):
2439
+ self.type_id = array_type_id
2440
+ self.is_contiguous = False
2441
+
2442
+ # return a contiguous copy
2443
+ def contiguous(self):
2444
+ a = warp.empty_like(self)
2445
+ warp.copy(a, self)
2446
+ return a
2447
+
2448
+ # copy data from one device to another, nop if already on device
2449
+ def to(self, device):
2450
+ device = warp.get_device(device)
2451
+ if self.device == device:
2452
+ return self
2453
+ else:
2454
+ return warp.clone(self, device=device)
2455
+
2456
+ # return a contiguous numpy copy
2457
+ def numpy(self):
2458
+ # use the CUDA default stream for synchronous behaviour with other streams
2459
+ with warp.ScopedStream(self.device.null_stream):
2460
+ return self.contiguous().numpy()
2461
+
2462
+ # returns a flattened list of items in the array as a Python list
2463
+ def list(self):
2464
+ # use the CUDA default stream for synchronous behaviour with other streams
2465
+ with warp.ScopedStream(self.device.null_stream):
2466
+ return self.contiguous().list()
2467
+
2468
+ # equivalent to wrapping src data in an array and copying to self
2469
+ def assign(self, src):
2470
+ if is_array(src):
2471
+ warp.copy(self, src)
2472
+ else:
2473
+ warp.copy(self, array(data=src, dtype=self.dtype, copy=False, device="cpu"))
2474
+
2475
+ def zero_(self):
2476
+ self.fill_(0)
2477
+
2478
+ def fill_(self, value):
2479
+ if self.size == 0:
2480
+ return
2481
+
2482
+ # try to convert the given value to the array dtype
2483
+ try:
2484
+ if isinstance(self.dtype, warp.codegen.Struct):
2485
+ if isinstance(value, self.dtype.cls):
2486
+ cvalue = value.__ctype__()
2487
+ elif value == 0:
2488
+ # allow zero-initializing structs using default constructor
2489
+ cvalue = self.dtype().__ctype__()
2490
+ else:
2491
+ raise ValueError(
2492
+ f"Invalid initializer value for struct {self.dtype.cls.__name__}, expected struct instance or 0"
2493
+ )
2494
+ elif issubclass(self.dtype, ctypes.Array):
2495
+ # vector/matrix
2496
+ cvalue = self.dtype(value)
2497
+ else:
2498
+ # scalar
2499
+ if type(value) in warp.types.scalar_types:
2500
+ value = value.value
2501
+ if self.dtype == float16:
2502
+ cvalue = self.dtype._type_(float_to_half_bits(value))
2503
+ else:
2504
+ cvalue = self.dtype._type_(value)
2505
+ except Exception as e:
2506
+ raise ValueError(f"Failed to convert the value to the array data type: {e}")
2507
+
2508
+ cvalue_ptr = ctypes.pointer(cvalue)
2509
+ cvalue_size = ctypes.sizeof(cvalue)
2510
+
2511
+ ctype = self.__ctype__()
2512
+ ctype_ptr = ctypes.pointer(ctype)
2513
+
2514
+ if self.device.is_cuda:
2515
+ warp.context.runtime.core.array_fill_device(
2516
+ self.device.context, ctype_ptr, self.type_id, cvalue_ptr, cvalue_size
2517
+ )
2518
+ else:
2519
+ warp.context.runtime.core.array_fill_host(ctype_ptr, self.type_id, cvalue_ptr, cvalue_size)
2520
+
2521
+
2522
+ # helper to check index array properties
2523
+ def check_index_array(indices, expected_device):
2524
+ if not isinstance(indices, array):
2525
+ raise ValueError(f"Indices must be a Warp array, got {type(indices)}")
2526
+ if indices.ndim != 1:
2527
+ raise ValueError(f"Index array must be one-dimensional, got {indices.ndim}")
2528
+ if indices.dtype != int32:
2529
+ raise ValueError(f"Index array must use int32, got dtype {indices.dtype}")
2530
+ if indices.device != expected_device:
2531
+ raise ValueError(f"Index array device ({indices.device} does not match data array device ({expected_device}))")
2532
+
2533
+
2534
+ class indexedarray(noncontiguous_array_base[T]):
1895
2535
  # member attributes available during code-gen (e.g.: d = arr.shape[0])
1896
2536
  # (initialized when needed)
1897
2537
  _vars = None
1898
2538
 
1899
2539
  def __init__(self, data: array = None, indices: Union[array, List[array]] = None, dtype=None, ndim=None):
2540
+ super().__init__(ARRAY_TYPE_INDEXED)
2541
+
1900
2542
  # canonicalize types
1901
2543
  if dtype is not None:
1902
2544
  if dtype == int:
@@ -1926,17 +2568,6 @@ class indexedarray(Generic[T]):
1926
2568
  shape = list(data.shape)
1927
2569
 
1928
2570
  if indices is not None:
1929
- # helper to check index array properties
1930
- def check_index_array(inds, data):
1931
- if inds.ndim != 1:
1932
- raise ValueError(f"Index array must be one-dimensional, got {inds.ndim}")
1933
- if inds.dtype != int32:
1934
- raise ValueError(f"Index array must use int32, got dtype {inds.dtype}")
1935
- if inds.device != data.device:
1936
- raise ValueError(
1937
- f"Index array device ({inds.device} does not match data array device ({data.device}))"
1938
- )
1939
-
1940
2571
  if isinstance(indices, (list, tuple)):
1941
2572
  if len(indices) > self.ndim:
1942
2573
  raise ValueError(
@@ -1944,16 +2575,14 @@ class indexedarray(Generic[T]):
1944
2575
  )
1945
2576
 
1946
2577
  for i in range(len(indices)):
1947
- if isinstance(indices[i], array):
1948
- check_index_array(indices[i], data)
2578
+ if indices[i] is not None:
2579
+ check_index_array(indices[i], data.device)
1949
2580
  self.indices[i] = indices[i]
1950
2581
  shape[i] = len(indices[i])
1951
- elif indices[i] is not None:
1952
- raise TypeError(f"Invalid index array type: {type(indices[i])}")
1953
2582
 
1954
2583
  elif isinstance(indices, array):
1955
2584
  # only a single index array was provided
1956
- check_index_array(indices, data)
2585
+ check_index_array(indices, data.device)
1957
2586
  self.indices[0] = indices
1958
2587
  shape[0] = len(indices)
1959
2588
 
@@ -1975,13 +2604,15 @@ class indexedarray(Generic[T]):
1975
2604
  for d in self.shape:
1976
2605
  self.size *= d
1977
2606
 
1978
- self.is_contiguous = False
1979
-
1980
2607
  def __len__(self):
1981
2608
  return self.shape[0]
1982
2609
 
1983
2610
  def __str__(self):
1984
- return f"indexedarray{self.dtype}"
2611
+ if self.device is None:
2612
+ # type annotation
2613
+ return f"indexedarray{self.dtype}"
2614
+ else:
2615
+ return str(self.numpy())
1985
2616
 
1986
2617
  # construct a C-representation of the array for passing to kernels
1987
2618
  def __ctype__(self):
@@ -1992,48 +2623,9 @@ class indexedarray(Generic[T]):
1992
2623
  # member attributes available during code-gen (e.g.: d = arr.shape[0])
1993
2624
  # Note: we use a shared dict for all indexedarray instances
1994
2625
  if indexedarray._vars is None:
1995
- from warp.codegen import Var
1996
-
1997
- indexedarray._vars = {"shape": Var("shape", shape_t)}
2626
+ indexedarray._vars = {"shape": warp.codegen.Var("shape", shape_t)}
1998
2627
  return indexedarray._vars
1999
2628
 
2000
- def contiguous(self):
2001
- a = warp.empty_like(self)
2002
- warp.copy(a, self)
2003
-
2004
- return a
2005
-
2006
- # convert data from one device to another, nop if already on device
2007
- def to(self, device):
2008
- device = warp.get_device(device)
2009
- if self.device == device:
2010
- return self
2011
- else:
2012
- dest = warp.empty(shape=self.shape, dtype=self.dtype, device=device)
2013
- # to copy between devices, array must be contiguous
2014
- warp.copy(dest, self.contiguous())
2015
- return dest
2016
-
2017
- # convert array to ndarray (alias memory through array interface)
2018
- def numpy(self):
2019
- # use the CUDA default stream for synchronous behaviour with other streams
2020
- with warp.ScopedStream(self.device.null_stream):
2021
-
2022
- a = self.contiguous().to("cpu")
2023
-
2024
- if isinstance(self.dtype, warp.codegen.Struct):
2025
- p = ctypes.cast(a.ptr, ctypes.POINTER(a.dtype.ctype))
2026
- np.ctypeslib.as_array(p, self.shape)
2027
- else:
2028
- # convert through array interface
2029
- return np.array(a, copy=False)
2030
-
2031
- # returns a flattened list of items in the array as a Python list
2032
- def list(self):
2033
- a = self.flatten()
2034
- p = ctypes.cast(a.ptr, ctypes.POINTER(a.dtype.ctype))
2035
- return p[:a.size]
2036
-
2037
2629
 
2038
2630
  # aliases for indexedarrays with small dimensions
2039
2631
  def indexedarray1d(*args, **kwargs):
@@ -2059,7 +2651,22 @@ def indexedarray4d(*args, **kwargs):
2059
2651
  return indexedarray(*args, **kwargs)
2060
2652
 
2061
2653
 
2062
- array_types = (array, indexedarray)
2654
+ from warp.fabric import fabricarray, indexedfabricarray # noqa: E402
2655
+
2656
+ array_types = (array, indexedarray, fabricarray, indexedfabricarray)
2657
+
2658
+
2659
+ def array_type_id(a):
2660
+ if isinstance(a, array):
2661
+ return ARRAY_TYPE_REGULAR
2662
+ elif isinstance(a, indexedarray):
2663
+ return ARRAY_TYPE_INDEXED
2664
+ elif isinstance(a, fabricarray):
2665
+ return ARRAY_TYPE_FABRIC
2666
+ elif isinstance(a, indexedfabricarray):
2667
+ return ARRAY_TYPE_FABRIC_INDEXED
2668
+ else:
2669
+ raise ValueError("Invalid array type")
2063
2670
 
2064
2671
 
2065
2672
  class Bvh:
@@ -2117,11 +2724,11 @@ class Bvh:
2117
2724
  with self.device.context_guard:
2118
2725
  runtime.core.bvh_destroy_device(self.id)
2119
2726
 
2120
- except:
2727
+ except Exception:
2121
2728
  pass
2122
2729
 
2123
2730
  def refit(self):
2124
- """Refit the Bvh. This should be called after users modify the `lowers` and `uppers` arrays."""
2731
+ """Refit the BVH. This should be called after users modify the `lowers` and `uppers` arrays."""
2125
2732
 
2126
2733
  from warp.context import runtime
2127
2734
 
@@ -2141,7 +2748,7 @@ class Mesh:
2141
2748
  "indices": Var("indices", array(dtype=int32)),
2142
2749
  }
2143
2750
 
2144
- def __init__(self, points=None, indices=None, velocities=None):
2751
+ def __init__(self, points=None, indices=None, velocities=None, support_winding_number=False):
2145
2752
  """Class representing a triangle mesh.
2146
2753
 
2147
2754
  Attributes:
@@ -2152,6 +2759,7 @@ class Mesh:
2152
2759
  points (:class:`warp.array`): Array of vertex positions of type :class:`warp.vec3`
2153
2760
  indices (:class:`warp.array`): Array of triangle indices of type :class:`warp.int32`, should be a 1d array with shape (num_tris, 3)
2154
2761
  velocities (:class:`warp.array`): Array of vertex velocities of type :class:`warp.vec3` (optional)
2762
+ support_winding_number (bool): If true, the mesh will build additional data structures to support `wp.mesh_query_point_sign_winding_number()` queries
2155
2763
  """
2156
2764
 
2157
2765
  if points.device != indices.device:
@@ -2183,6 +2791,7 @@ class Mesh:
2183
2791
  indices.__ctype__(),
2184
2792
  int(len(points)),
2185
2793
  int(indices.size / 3),
2794
+ int(support_winding_number),
2186
2795
  )
2187
2796
  else:
2188
2797
  self.id = runtime.core.mesh_create_device(
@@ -2192,6 +2801,7 @@ class Mesh:
2192
2801
  indices.__ctype__(),
2193
2802
  int(len(points)),
2194
2803
  int(indices.size / 3),
2804
+ int(support_winding_number),
2195
2805
  )
2196
2806
 
2197
2807
  def __del__(self):
@@ -2204,7 +2814,7 @@ class Mesh:
2204
2814
  # use CUDA context guard to avoid side effects during garbage collection
2205
2815
  with self.device.context_guard:
2206
2816
  runtime.core.mesh_destroy_device(self.id)
2207
- except:
2817
+ except Exception:
2208
2818
  pass
2209
2819
 
2210
2820
  def refit(self):
@@ -2220,16 +2830,14 @@ class Mesh:
2220
2830
 
2221
2831
 
2222
2832
  class Volume:
2833
+ #: Enum value to specify nearest-neighbor interpolation during sampling
2223
2834
  CLOSEST = constant(0)
2835
+ #: Enum value to specify trilinear interpolation during sampling
2224
2836
  LINEAR = constant(1)
2225
2837
 
2226
2838
  def __init__(self, data: array):
2227
2839
  """Class representing a sparse grid.
2228
2840
 
2229
- Attributes:
2230
- CLOSEST (int): Enum value to specify nearest-neighbor interpolation during sampling
2231
- LINEAR (int): Enum value to specify trilinear interpolation during sampling
2232
-
2233
2841
  Args:
2234
2842
  data (:class:`warp.array`): Array of bytes representing the volume in NanoVDB format
2235
2843
  """
@@ -2271,19 +2879,20 @@ class Volume:
2271
2879
  with self.device.context_guard:
2272
2880
  runtime.core.volume_destroy_device(self.id)
2273
2881
 
2274
- except:
2882
+ except Exception:
2275
2883
  pass
2276
2884
 
2277
- def array(self):
2885
+ def array(self) -> array:
2886
+ """Returns the raw memory buffer of the Volume as an array"""
2278
2887
  buf = ctypes.c_void_p(0)
2279
2888
  size = ctypes.c_uint64(0)
2280
2889
  if self.device.is_cpu:
2281
2890
  self.context.core.volume_get_buffer_info_host(self.id, ctypes.byref(buf), ctypes.byref(size))
2282
2891
  else:
2283
2892
  self.context.core.volume_get_buffer_info_device(self.id, ctypes.byref(buf), ctypes.byref(size))
2284
- return array(ptr=buf.value, dtype=uint8, length=size.value, device=self.device, owner=False)
2893
+ return array(ptr=buf.value, dtype=uint8, shape=size.value, device=self.device, owner=False)
2285
2894
 
2286
- def get_tiles(self):
2895
+ def get_tiles(self) -> array:
2287
2896
  if self.id == 0:
2288
2897
  raise RuntimeError("Invalid Volume")
2289
2898
 
@@ -2294,11 +2903,9 @@ class Volume:
2294
2903
  else:
2295
2904
  self.context.core.volume_get_tiles_device(self.id, ctypes.byref(buf), ctypes.byref(size))
2296
2905
  num_tiles = size.value // (3 * 4)
2297
- return array(
2298
- ptr=buf.value, dtype=int32, shape=(num_tiles, 3), length=size.value, device=self.device, owner=True
2299
- )
2906
+ return array(ptr=buf.value, dtype=int32, shape=(num_tiles, 3), device=self.device, owner=True)
2300
2907
 
2301
- def get_voxel_size(self):
2908
+ def get_voxel_size(self) -> Tuple[float, float, float]:
2302
2909
  if self.id == 0:
2303
2910
  raise RuntimeError("Invalid Volume")
2304
2911
 
@@ -2307,7 +2914,13 @@ class Volume:
2307
2914
  return (dx.value, dy.value, dz.value)
2308
2915
 
2309
2916
  @classmethod
2310
- def load_from_nvdb(cls, file_or_buffer, device=None):
2917
+ def load_from_nvdb(cls, file_or_buffer, device=None) -> Volume:
2918
+ """Creates a Volume object from a NanoVDB file or in-memory buffer.
2919
+
2920
+ Returns:
2921
+
2922
+ A ``warp.Volume`` object.
2923
+ """
2311
2924
  try:
2312
2925
  data = file_or_buffer.read()
2313
2926
  except AttributeError:
@@ -2336,6 +2949,90 @@ class Volume:
2336
2949
  data_array = array(np.frombuffer(grid_data, dtype=np.byte), device=device)
2337
2950
  return cls(data_array)
2338
2951
 
2952
+ @classmethod
2953
+ def load_from_numpy(
2954
+ cls, ndarray: np.array, min_world=(0.0, 0.0, 0.0), voxel_size=1.0, bg_value=0.0, device=None
2955
+ ) -> Volume:
2956
+ """Creates a Volume object from a dense 3D NumPy array.
2957
+
2958
+ This function is only supported for CUDA devices.
2959
+
2960
+ Args:
2961
+ min_world: The 3D coordinate of the lower corner of the volume.
2962
+ voxel_size: The size of each voxel in spatial coordinates.
2963
+ bg_value: Background value
2964
+ device: The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
2965
+
2966
+ Returns:
2967
+
2968
+ A ``warp.Volume`` object.
2969
+ """
2970
+
2971
+ import math
2972
+
2973
+ target_shape = (
2974
+ math.ceil(ndarray.shape[0] / 8) * 8,
2975
+ math.ceil(ndarray.shape[1] / 8) * 8,
2976
+ math.ceil(ndarray.shape[2] / 8) * 8,
2977
+ )
2978
+ if hasattr(bg_value, "__len__"):
2979
+ # vec3, assuming the numpy array is 4D
2980
+ padded_array = np.array((target_shape[0], target_shape[1], target_shape[2], 3), dtype=np.single)
2981
+ padded_array[:, :, :, :] = np.array(bg_value)
2982
+ padded_array[0 : ndarray.shape[0], 0 : ndarray.shape[1], 0 : ndarray.shape[2], :] = ndarray
2983
+ else:
2984
+ padded_amount = (
2985
+ math.ceil(ndarray.shape[0] / 8) * 8 - ndarray.shape[0],
2986
+ math.ceil(ndarray.shape[1] / 8) * 8 - ndarray.shape[1],
2987
+ math.ceil(ndarray.shape[2] / 8) * 8 - ndarray.shape[2],
2988
+ )
2989
+ padded_array = np.pad(
2990
+ ndarray,
2991
+ ((0, padded_amount[0]), (0, padded_amount[1]), (0, padded_amount[2])),
2992
+ mode="constant",
2993
+ constant_values=bg_value,
2994
+ )
2995
+
2996
+ shape = padded_array.shape
2997
+ volume = warp.Volume.allocate(
2998
+ min_world,
2999
+ [
3000
+ min_world[0] + (shape[0] - 1) * voxel_size,
3001
+ min_world[1] + (shape[1] - 1) * voxel_size,
3002
+ min_world[2] + (shape[2] - 1) * voxel_size,
3003
+ ],
3004
+ voxel_size,
3005
+ bg_value=bg_value,
3006
+ points_in_world_space=True,
3007
+ translation=min_world,
3008
+ device=device,
3009
+ )
3010
+
3011
+ # Populate volume
3012
+ if hasattr(bg_value, "__len__"):
3013
+ warp.launch(
3014
+ warp.utils.copy_dense_volume_to_nano_vdb_v,
3015
+ dim=(shape[0], shape[1], shape[2]),
3016
+ inputs=[volume.id, warp.array(padded_array, dtype=warp.vec3, device=device)],
3017
+ device=device,
3018
+ )
3019
+ elif isinstance(bg_value, int):
3020
+ warp.launch(
3021
+ warp.utils.copy_dense_volume_to_nano_vdb_i,
3022
+ dim=shape,
3023
+ inputs=[volume.id, warp.array(padded_array, dtype=warp.int32, device=device)],
3024
+ device=device,
3025
+ )
3026
+ else:
3027
+ warp.launch(
3028
+ warp.utils.copy_dense_volume_to_nano_vdb_f,
3029
+ dim=shape,
3030
+ inputs=[volume.id, warp.array(padded_array, dtype=warp.float32, device=device)],
3031
+ device=device,
3032
+ )
3033
+
3034
+ return volume
3035
+
2339
3036
  @classmethod
2340
3037
  def allocate(
2341
3038
  cls,
@@ -2346,9 +3043,11 @@ class Volume:
2346
3043
  translation=(0.0, 0.0, 0.0),
2347
3044
  points_in_world_space=False,
2348
3045
  device=None,
2349
- ):
3046
+ ) -> Volume:
2350
3047
  """Allocate a new Volume based on the bounding box defined by min and max.
2351
3048
 
3049
+ This function is only supported for CUDA devices.
3050
+
2352
3051
  Allocate a volume that is large enough to contain voxels [min[0], min[1], min[2]] - [max[0], max[1], max[2]], inclusive.
2353
3052
  If points_in_world_space is true, then min and max are first converted to index space with the given voxel size and
2354
3053
  translation, and the volume is allocated with those.
@@ -2357,12 +3056,12 @@ class Volume:
2357
3056
  the resulting tiles will be available in the new volume.
2358
3057
 
2359
3058
  Args:
2360
- min (array-like): Lower 3D-coordinates of the bounding box in index space or world space, inclusive
2361
- max (array-like): Upper 3D-coordinates of the bounding box in index space or world space, inclusive
2362
- voxel_size (float): Voxel size of the new volume
3059
+ min (array-like): Lower 3D coordinates of the bounding box in index space or world space, inclusive.
3060
+ max (array-like): Upper 3D coordinates of the bounding box in index space or world space, inclusive.
3061
+ voxel_size (float): Voxel size of the new volume.
2363
3062
  bg_value (float or array-like): Value of unallocated voxels of the volume, also defines the volume's type, a :class:`warp.vec3` volume is created if this is `array-like`, otherwise a float volume is created
2364
- translation (array-like): translation between the index and world spaces
2365
- device (Devicelike): Device the array lives on
3063
+ translation (array-like): Translation between the index and world spaces.
3064
+ device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
2366
3065
 
2367
3066
  """
2368
3067
  if points_in_world_space:
@@ -2387,9 +3086,11 @@ class Volume:
2387
3086
  @classmethod
2388
3087
  def allocate_by_tiles(
2389
3088
  cls, tile_points: array, voxel_size: float, bg_value=0.0, translation=(0.0, 0.0, 0.0), device=None
2390
- ):
3089
+ ) -> Volume:
2391
3090
  """Allocate a new Volume with active tiles for each point in tile_points.
2392
3091
 
3092
+ This function is only supported for CUDA devices.
3093
+
2393
3094
  The smallest unit of allocation is a dense tile of 8x8x8 voxels.
2394
3095
  This is the primary method for allocating sparse volumes. It uses an array of points indicating the tiles that must be allocated.
2395
3096
 
@@ -2399,13 +3100,13 @@ class Volume:
2399
3100
 
2400
3101
  Args:
2401
3102
  tile_points (:class:`warp.array`): Array of positions that define the tiles to be allocated.
2402
- The array can be a 2d, N-by-3 array of :class:`warp.int32` values, indicating index space positions,
3103
+ The array can be a 2D, N-by-3 array of :class:`warp.int32` values, indicating index space positions,
2403
3104
  or can be a 1D array of :class:`warp.vec3` values, indicating world space positions.
2404
3105
  Repeated points per tile are allowed and will be efficiently deduplicated.
2405
- voxel_size (float): Voxel size of the new volume
3106
+ voxel_size (float): Voxel size of the new volume.
2406
3107
  bg_value (float or array-like): Value of unallocated voxels of the volume, also defines the volume's type, a :class:`warp.vec3` volume is created if this is `array-like`, otherwise a float volume is created
2407
- translation (array-like): translation between the index and world spaces
2408
- device (Devicelike): Device the array lives on
3108
+ translation (array-like): Translation between the index and world spaces.
3109
+ device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
2409
3110
 
2410
3111
  """
2411
3112
  from warp.context import runtime
@@ -2442,7 +3143,7 @@ class Volume:
2442
3143
  translation[2],
2443
3144
  in_world_space,
2444
3145
  )
2445
- elif type(bg_value) == int:
3146
+ elif isinstance(bg_value, int):
2446
3147
  volume.id = volume.context.core.volume_i_from_tiles_device(
2447
3148
  volume.device.context,
2448
3149
  ctypes.c_void_p(tile_points.ptr),
@@ -2473,6 +3174,67 @@ class Volume:
2473
3174
  return volume
2474
3175
 
2475
3176
 
3177
+ # definition just for kernel type (cannot be a parameter), see mesh.h
3178
+ # NOTE: its layout must match the corresponding struct defined in C.
3179
+ # NOTE: it needs to be defined after `indexedarray` to work around a circular import issue.
3180
+ class mesh_query_point_t:
3181
+ """Output for the mesh query point functions.
3182
+
3183
+ Attributes:
3184
+ result (bool): Whether a point is found within the given constraints.
3185
+ sign (float32): A value < 0 if query point is inside the mesh, >=0 otherwise.
3186
+ Note that mesh must be watertight for this to be robust
3187
+ face (int32): Index of the closest face.
3188
+ u (float32): Barycentric u coordinate of the closest point.
3189
+ v (float32): Barycentric v coordinate of the closest point.
3190
+
3191
+ See Also:
3192
+ :func:`mesh_query_point`, :func:`mesh_query_point_no_sign`,
3193
+ :func:`mesh_query_furthest_point_no_sign`,
3194
+ :func:`mesh_query_point_sign_normal`,
3195
+ and :func:`mesh_query_point_sign_winding_number`.
3196
+ """
3197
+ from warp.codegen import Var
3198
+
3199
+ vars = {
3200
+ "result": Var("result", bool),
3201
+ "sign": Var("sign", float32),
3202
+ "face": Var("face", int32),
3203
+ "u": Var("u", float32),
3204
+ "v": Var("v", float32),
3205
+ }
3206
+
3207
+
3208
+ # definition just for kernel type (cannot be a parameter), see mesh.h
3209
+ # NOTE: its layout must match the corresponding struct defined in C.
3210
+ class mesh_query_ray_t:
3211
+ """Output for the mesh query ray functions.
3212
+
3213
+ Attributes:
3214
+ result (bool): Whether a hit is found within the given constraints.
3215
+ sign (float32): A value > 0 if the ray hit in front of the face, < 0 otherwise.
3216
+ face (int32): Index of the closest face.
3217
+ t (float32): Distance of the closest hit along the ray.
3218
+ u (float32): Barycentric u coordinate of the closest hit.
3219
+ v (float32): Barycentric v coordinate of the closest hit.
3220
+ normal (vec3f): Face normal.
3221
+
3222
+ See Also:
3223
+ :func:`mesh_query_ray`.
3224
+ """
3225
+ from warp.codegen import Var
3226
+
3227
+ vars = {
3228
+ "result": Var("result", bool),
3229
+ "sign": Var("sign", float32),
3230
+ "face": Var("face", int32),
3231
+ "t": Var("t", float32),
3232
+ "u": Var("u", float32),
3233
+ "v": Var("v", float32),
3234
+ "normal": Var("normal", vec3),
3235
+ }
3236
+
3237
+
2476
3238
  def matmul(
2477
3239
  a: array2d,
2478
3240
  b: array2d,
@@ -2480,7 +3242,7 @@ def matmul(
2480
3242
  d: array2d,
2481
3243
  alpha: float = 1.0,
2482
3244
  beta: float = 0.0,
2483
- allow_tf32x3_arith: bool = False,
3245
+ allow_tf32x3_arith: builtins.bool = False,
2484
3246
  device=None,
2485
3247
  ):
2486
3248
  """Computes a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -2509,6 +3271,11 @@ def matmul(
2509
3271
  "wp.matmul currently only supports operation between {A, B, C, D} matrices of the same type."
2510
3272
  )
2511
3273
 
3274
+ if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
3275
+ raise RuntimeError(
3276
+ "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
3277
+ )
3278
+
2512
3279
  m = a.shape[0]
2513
3280
  n = b.shape[1]
2514
3281
  k = a.shape[1]
@@ -2543,13 +3310,13 @@ def matmul(
2543
3310
  ctypes.c_void_p(d.ptr),
2544
3311
  alpha,
2545
3312
  beta,
2546
- True,
2547
- True,
3313
+ not a.is_transposed,
3314
+ not b.is_transposed,
2548
3315
  allow_tf32x3_arith,
2549
3316
  1,
2550
3317
  )
2551
3318
  if not ret:
2552
- raise RuntimeError("Matmul failed.")
3319
+ raise RuntimeError("matmul failed.")
2553
3320
 
2554
3321
 
2555
3322
  def adj_matmul(
@@ -2562,7 +3329,7 @@ def adj_matmul(
2562
3329
  adj_d: array2d,
2563
3330
  alpha: float = 1.0,
2564
3331
  beta: float = 0.0,
2565
- allow_tf32x3_arith: bool = False,
3332
+ allow_tf32x3_arith: builtins.bool = False,
2566
3333
  device=None,
2567
3334
  ):
2568
3335
  """Computes the adjoint of a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -2613,6 +3380,19 @@ def adj_matmul(
2613
3380
  "wp.adj_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
2614
3381
  )
2615
3382
 
3383
+ if (
3384
+ (not a.is_contiguous and not a.is_transposed)
3385
+ or (not b.is_contiguous and not b.is_transposed)
3386
+ or (not c.is_contiguous)
3387
+ or (not adj_a.is_contiguous and not adj_a.is_transposed)
3388
+ or (not adj_b.is_contiguous and not adj_b.is_transposed)
3389
+ or (not adj_c.is_contiguous)
3390
+ or (not adj_d.is_contiguous)
3391
+ ):
3392
+ raise RuntimeError(
3393
+ "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
3394
+ )
3395
+
2616
3396
  m = a.shape[0]
2617
3397
  n = b.shape[1]
2618
3398
  k = a.shape[1]
@@ -2633,75 +3413,105 @@ def adj_matmul(
2633
3413
 
2634
3414
  # cpu fallback if no cuda devices found
2635
3415
  if device == "cpu":
2636
- adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()))
2637
- adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()))
2638
- adj_c.assign(beta * adj_d.numpy())
3416
+ adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()) + adj_a.numpy())
3417
+ adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()) + adj_b.numpy())
3418
+ adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
2639
3419
  return
2640
3420
 
2641
3421
  cc = device.arch
2642
3422
 
2643
3423
  # adj_a
2644
- ret = runtime.core.cutlass_gemm(
2645
- cc,
2646
- m,
2647
- k,
2648
- n,
2649
- type_typestr(a.dtype).encode(),
2650
- ctypes.c_void_p(adj_d.ptr),
2651
- ctypes.c_void_p(b.ptr),
2652
- ctypes.c_void_p(a.ptr),
2653
- ctypes.c_void_p(adj_a.ptr),
2654
- alpha,
2655
- 0.0,
2656
- True,
2657
- False,
2658
- allow_tf32x3_arith,
2659
- 1,
2660
- )
2661
- if not ret:
2662
- raise RuntimeError("adj_matmul failed.")
3424
+ if not a.is_transposed:
3425
+ ret = runtime.core.cutlass_gemm(
3426
+ cc,
3427
+ m,
3428
+ k,
3429
+ n,
3430
+ type_typestr(a.dtype).encode(),
3431
+ ctypes.c_void_p(adj_d.ptr),
3432
+ ctypes.c_void_p(b.ptr),
3433
+ ctypes.c_void_p(adj_a.ptr),
3434
+ ctypes.c_void_p(adj_a.ptr),
3435
+ alpha,
3436
+ 1.0,
3437
+ True,
3438
+ b.is_transposed,
3439
+ allow_tf32x3_arith,
3440
+ 1,
3441
+ )
3442
+ if not ret:
3443
+ raise RuntimeError("adj_matmul failed.")
3444
+ else:
3445
+ ret = runtime.core.cutlass_gemm(
3446
+ cc,
3447
+ k,
3448
+ m,
3449
+ n,
3450
+ type_typestr(a.dtype).encode(),
3451
+ ctypes.c_void_p(b.ptr),
3452
+ ctypes.c_void_p(adj_d.ptr),
3453
+ ctypes.c_void_p(adj_a.ptr),
3454
+ ctypes.c_void_p(adj_a.ptr),
3455
+ alpha,
3456
+ 1.0,
3457
+ not b.is_transposed,
3458
+ False,
3459
+ allow_tf32x3_arith,
3460
+ 1,
3461
+ )
3462
+ if not ret:
3463
+ raise RuntimeError("adj_matmul failed.")
2663
3464
 
2664
3465
  # adj_b
2665
- ret = runtime.core.cutlass_gemm(
2666
- cc,
2667
- k,
2668
- n,
2669
- m,
2670
- type_typestr(a.dtype).encode(),
2671
- ctypes.c_void_p(a.ptr),
2672
- ctypes.c_void_p(adj_d.ptr),
2673
- ctypes.c_void_p(b.ptr),
2674
- ctypes.c_void_p(adj_b.ptr),
2675
- alpha,
2676
- 0.0,
2677
- False,
2678
- True,
2679
- allow_tf32x3_arith,
2680
- 1,
2681
- )
2682
- if not ret:
2683
- raise RuntimeError("adj_matmul failed.")
3466
+ if not b.is_transposed:
3467
+ ret = runtime.core.cutlass_gemm(
3468
+ cc,
3469
+ k,
3470
+ n,
3471
+ m,
3472
+ type_typestr(a.dtype).encode(),
3473
+ ctypes.c_void_p(a.ptr),
3474
+ ctypes.c_void_p(adj_d.ptr),
3475
+ ctypes.c_void_p(adj_b.ptr),
3476
+ ctypes.c_void_p(adj_b.ptr),
3477
+ alpha,
3478
+ 1.0,
3479
+ a.is_transposed,
3480
+ True,
3481
+ allow_tf32x3_arith,
3482
+ 1,
3483
+ )
3484
+ if not ret:
3485
+ raise RuntimeError("adj_matmul failed.")
3486
+ else:
3487
+ ret = runtime.core.cutlass_gemm(
3488
+ cc,
3489
+ n,
3490
+ k,
3491
+ m,
3492
+ type_typestr(a.dtype).encode(),
3493
+ ctypes.c_void_p(adj_d.ptr),
3494
+ ctypes.c_void_p(a.ptr),
3495
+ ctypes.c_void_p(adj_b.ptr),
3496
+ ctypes.c_void_p(adj_b.ptr),
3497
+ alpha,
3498
+ 1.0,
3499
+ False,
3500
+ not a.is_transposed,
3501
+ allow_tf32x3_arith,
3502
+ 1,
3503
+ )
3504
+ if not ret:
3505
+ raise RuntimeError("adj_matmul failed.")
2684
3506
 
2685
3507
  # adj_c
2686
- ret = runtime.core.cutlass_gemm(
2687
- cc,
2688
- m,
2689
- n,
2690
- k,
2691
- type_typestr(a.dtype).encode(),
2692
- ctypes.c_void_p(a.ptr),
2693
- ctypes.c_void_p(b.ptr),
2694
- ctypes.c_void_p(adj_d.ptr),
2695
- ctypes.c_void_p(adj_c.ptr),
2696
- 0.0,
2697
- beta,
2698
- True,
2699
- True,
2700
- allow_tf32x3_arith,
2701
- 1,
3508
+ warp.launch(
3509
+ kernel=warp.utils.add_kernel_2d,
3510
+ dim=adj_c.shape,
3511
+ inputs=[adj_c, adj_d, adj_d.dtype(beta)],
3512
+ device=device,
3513
+ record_tape=False
2702
3514
  )
2703
- if not ret:
2704
- raise RuntimeError("adj_matmul failed.")
2705
3515
 
2706
3516
 
2707
3517
  def batched_matmul(
@@ -2711,7 +3521,7 @@ def batched_matmul(
2711
3521
  d: array3d,
2712
3522
  alpha: float = 1.0,
2713
3523
  beta: float = 0.0,
2714
- allow_tf32x3_arith: bool = False,
3524
+ allow_tf32x3_arith: builtins.bool = False,
2715
3525
  device=None,
2716
3526
  ):
2717
3527
  """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -2740,6 +3550,11 @@ def batched_matmul(
2740
3550
  "wp.batched_matmul currently only supports operation between {A, B, C, D} matrices of the same type."
2741
3551
  )
2742
3552
 
3553
+ if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
3554
+ raise RuntimeError(
3555
+ "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
3556
+ )
3557
+
2743
3558
  m = a.shape[1]
2744
3559
  n = b.shape[2]
2745
3560
  k = a.shape[2]
@@ -2751,7 +3566,7 @@ def batched_matmul(
2751
3566
 
2752
3567
  if runtime.tape:
2753
3568
  runtime.tape.record_func(
2754
- backward=lambda: adj_matmul(
3569
+ backward=lambda: adj_batched_matmul(
2755
3570
  a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith, device
2756
3571
  ),
2757
3572
  arrays=[a, b, c, d],
@@ -2762,26 +3577,55 @@ def batched_matmul(
2762
3577
  d.assign(alpha * np.matmul(a.numpy(), b.numpy()) + beta * c.numpy())
2763
3578
  return
2764
3579
 
3580
+ # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
3581
+ max_batch_count = 65535
3582
+ iters = int(batch_count / max_batch_count)
3583
+ remainder = batch_count % max_batch_count
3584
+
2765
3585
  cc = device.arch
3586
+ for i in range(iters):
3587
+ idx_start = i * max_batch_count
3588
+ idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
3589
+ ret = runtime.core.cutlass_gemm(
3590
+ cc,
3591
+ m,
3592
+ n,
3593
+ k,
3594
+ type_typestr(a.dtype).encode(),
3595
+ ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
3596
+ ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
3597
+ ctypes.c_void_p(c[idx_start:idx_end,:,:].ptr),
3598
+ ctypes.c_void_p(d[idx_start:idx_end,:,:].ptr),
3599
+ alpha,
3600
+ beta,
3601
+ not a.is_transposed,
3602
+ not b.is_transposed,
3603
+ allow_tf32x3_arith,
3604
+ max_batch_count,
3605
+ )
3606
+ if not ret:
3607
+ raise RuntimeError("Batched matmul failed.")
3608
+
3609
+ idx_start = iters * max_batch_count
2766
3610
  ret = runtime.core.cutlass_gemm(
2767
3611
  cc,
2768
3612
  m,
2769
3613
  n,
2770
3614
  k,
2771
3615
  type_typestr(a.dtype).encode(),
2772
- ctypes.c_void_p(a.ptr),
2773
- ctypes.c_void_p(b.ptr),
2774
- ctypes.c_void_p(c.ptr),
2775
- ctypes.c_void_p(d.ptr),
3616
+ ctypes.c_void_p(a[idx_start:,:,:].ptr),
3617
+ ctypes.c_void_p(b[idx_start:,:,:].ptr),
3618
+ ctypes.c_void_p(c[idx_start:,:,:].ptr),
3619
+ ctypes.c_void_p(d[idx_start:,:,:].ptr),
2776
3620
  alpha,
2777
3621
  beta,
2778
- True,
2779
- True,
3622
+ not a.is_transposed,
3623
+ not b.is_transposed,
2780
3624
  allow_tf32x3_arith,
2781
- batch_count,
3625
+ remainder,
2782
3626
  )
2783
3627
  if not ret:
2784
- raise RuntimeError("Batched matmul failed.")
3628
+ raise RuntimeError("Batched matmul failed.")
2785
3629
 
2786
3630
 
2787
3631
  def adj_batched_matmul(
@@ -2794,7 +3638,7 @@ def adj_batched_matmul(
2794
3638
  adj_d: array3d,
2795
3639
  alpha: float = 1.0,
2796
3640
  beta: float = 0.0,
2797
- allow_tf32x3_arith: bool = False,
3641
+ allow_tf32x3_arith: builtins.bool = False,
2798
3642
  device=None,
2799
3643
  ):
2800
3644
  """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -2861,78 +3705,215 @@ def adj_batched_matmul(
2861
3705
  )
2862
3706
  )
2863
3707
 
3708
+ if (
3709
+ (not a.is_contiguous and not a.is_transposed)
3710
+ or (not b.is_contiguous and not b.is_transposed)
3711
+ or (not c.is_contiguous)
3712
+ or (not adj_a.is_contiguous and not adj_a.is_transposed)
3713
+ or (not adj_b.is_contiguous and not adj_b.is_transposed)
3714
+ or (not adj_c.is_contiguous)
3715
+ or (not adj_d.is_contiguous)
3716
+ ):
3717
+ raise RuntimeError(
3718
+ "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
3719
+ )
3720
+
2864
3721
  # cpu fallback if no cuda devices found
2865
3722
  if device == "cpu":
2866
- adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))))
2867
- adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()))
2868
- adj_c.assign(beta * adj_d.numpy())
3723
+ adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))) + adj_a.numpy())
3724
+ adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()) + adj_b.numpy())
3725
+ adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
2869
3726
  return
2870
3727
 
3728
+ # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
3729
+ max_batch_count = 65535
3730
+ iters = int(batch_count / max_batch_count)
3731
+ remainder = batch_count % max_batch_count
3732
+
2871
3733
  cc = device.arch
2872
3734
 
3735
+ for i in range(iters):
3736
+ idx_start = i * max_batch_count
3737
+ idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
3738
+
3739
+ # adj_a
3740
+ if not a.is_transposed:
3741
+ ret = runtime.core.cutlass_gemm(
3742
+ cc,
3743
+ m,
3744
+ k,
3745
+ n,
3746
+ type_typestr(a.dtype).encode(),
3747
+ ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
3748
+ ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
3749
+ ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
3750
+ ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
3751
+ alpha,
3752
+ 1.0,
3753
+ True,
3754
+ b.is_transposed,
3755
+ allow_tf32x3_arith,
3756
+ max_batch_count,
3757
+ )
3758
+ if not ret:
3759
+ raise RuntimeError("adj_matmul failed.")
3760
+ else:
3761
+ ret = runtime.core.cutlass_gemm(
3762
+ cc,
3763
+ k,
3764
+ m,
3765
+ n,
3766
+ type_typestr(a.dtype).encode(),
3767
+ ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
3768
+ ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
3769
+ ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
3770
+ ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
3771
+ alpha,
3772
+ 1.0,
3773
+ not b.is_transposed,
3774
+ False,
3775
+ allow_tf32x3_arith,
3776
+ max_batch_count,
3777
+ )
3778
+ if not ret:
3779
+ raise RuntimeError("adj_matmul failed.")
3780
+
3781
+ # adj_b
3782
+ if not b.is_transposed:
3783
+ ret = runtime.core.cutlass_gemm(
3784
+ cc,
3785
+ k,
3786
+ n,
3787
+ m,
3788
+ type_typestr(a.dtype).encode(),
3789
+ ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
3790
+ ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
3791
+ ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
3792
+ ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
3793
+ alpha,
3794
+ 1.0,
3795
+ a.is_transposed,
3796
+ True,
3797
+ allow_tf32x3_arith,
3798
+ max_batch_count,
3799
+ )
3800
+ if not ret:
3801
+ raise RuntimeError("adj_matmul failed.")
3802
+ else:
3803
+ ret = runtime.core.cutlass_gemm(
3804
+ cc,
3805
+ n,
3806
+ k,
3807
+ m,
3808
+ type_typestr(a.dtype).encode(),
3809
+ ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
3810
+ ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
3811
+ ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
3812
+ ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
3813
+ alpha,
3814
+ 1.0,
3815
+ False,
3816
+ not a.is_transposed,
3817
+ allow_tf32x3_arith,
3818
+ max_batch_count,
3819
+ )
3820
+ if not ret:
3821
+ raise RuntimeError("adj_matmul failed.")
3822
+
3823
+ idx_start = iters * max_batch_count
3824
+
2873
3825
  # adj_a
2874
- ret = runtime.core.cutlass_gemm(
2875
- cc,
2876
- m,
2877
- k,
2878
- n,
2879
- type_typestr(a.dtype).encode(),
2880
- ctypes.c_void_p(adj_d.ptr),
2881
- ctypes.c_void_p(b.ptr),
2882
- ctypes.c_void_p(a.ptr),
2883
- ctypes.c_void_p(adj_a.ptr),
2884
- alpha,
2885
- 0.0,
2886
- True,
2887
- False,
2888
- allow_tf32x3_arith,
2889
- batch_count,
2890
- )
2891
- if not ret:
2892
- raise RuntimeError("adj_matmul failed.")
3826
+ if not a.is_transposed:
3827
+ ret = runtime.core.cutlass_gemm(
3828
+ cc,
3829
+ m,
3830
+ k,
3831
+ n,
3832
+ type_typestr(a.dtype).encode(),
3833
+ ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
3834
+ ctypes.c_void_p(b[idx_start:,:,:].ptr),
3835
+ ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
3836
+ ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
3837
+ alpha,
3838
+ 1.0,
3839
+ True,
3840
+ b.is_transposed,
3841
+ allow_tf32x3_arith,
3842
+ remainder,
3843
+ )
3844
+ if not ret:
3845
+ raise RuntimeError("adj_matmul failed.")
3846
+ else:
3847
+ ret = runtime.core.cutlass_gemm(
3848
+ cc,
3849
+ k,
3850
+ m,
3851
+ n,
3852
+ type_typestr(a.dtype).encode(),
3853
+ ctypes.c_void_p(b[idx_start:,:,:].ptr),
3854
+ ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
3855
+ ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
3856
+ ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
3857
+ alpha,
3858
+ 1.0,
3859
+ not b.is_transposed,
3860
+ False,
3861
+ allow_tf32x3_arith,
3862
+ remainder,
3863
+ )
3864
+ if not ret:
3865
+ raise RuntimeError("adj_matmul failed.")
2893
3866
 
2894
3867
  # adj_b
2895
- ret = runtime.core.cutlass_gemm(
2896
- cc,
2897
- k,
2898
- n,
2899
- m,
2900
- type_typestr(a.dtype).encode(),
2901
- ctypes.c_void_p(a.ptr),
2902
- ctypes.c_void_p(adj_d.ptr),
2903
- ctypes.c_void_p(b.ptr),
2904
- ctypes.c_void_p(adj_b.ptr),
2905
- alpha,
2906
- 0.0,
2907
- False,
2908
- True,
2909
- allow_tf32x3_arith,
2910
- batch_count,
2911
- )
2912
- if not ret:
2913
- raise RuntimeError("adj_matmul failed.")
3868
+ if not b.is_transposed:
3869
+ ret = runtime.core.cutlass_gemm(
3870
+ cc,
3871
+ k,
3872
+ n,
3873
+ m,
3874
+ type_typestr(a.dtype).encode(),
3875
+ ctypes.c_void_p(a[idx_start:,:,:].ptr),
3876
+ ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
3877
+ ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
3878
+ ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
3879
+ alpha,
3880
+ 1.0,
3881
+ a.is_transposed,
3882
+ True,
3883
+ allow_tf32x3_arith,
3884
+ remainder,
3885
+ )
3886
+ if not ret:
3887
+ raise RuntimeError("adj_matmul failed.")
3888
+ else:
3889
+ ret = runtime.core.cutlass_gemm(
3890
+ cc,
3891
+ n,
3892
+ k,
3893
+ m,
3894
+ type_typestr(a.dtype).encode(),
3895
+ ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
3896
+ ctypes.c_void_p(a[idx_start:,:,:].ptr),
3897
+ ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
3898
+ ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
3899
+ alpha,
3900
+ 1.0,
3901
+ False,
3902
+ not a.is_transposed,
3903
+ allow_tf32x3_arith,
3904
+ remainder,
3905
+ )
3906
+ if not ret:
3907
+ raise RuntimeError("adj_matmul failed.")
2914
3908
 
2915
3909
  # adj_c
2916
- ret = runtime.core.cutlass_gemm(
2917
- cc,
2918
- m,
2919
- n,
2920
- k,
2921
- type_typestr(a.dtype).encode(),
2922
- ctypes.c_void_p(a.ptr),
2923
- ctypes.c_void_p(b.ptr),
2924
- ctypes.c_void_p(adj_d.ptr),
2925
- ctypes.c_void_p(adj_c.ptr),
2926
- 0.0,
2927
- beta,
2928
- True,
2929
- True,
2930
- allow_tf32x3_arith,
2931
- batch_count,
3910
+ warp.launch(
3911
+ kernel=warp.utils.add_kernel_3d,
3912
+ dim=adj_c.shape,
3913
+ inputs=[adj_c, adj_d, adj_d.dtype(beta)],
3914
+ device=device,
3915
+ record_tape=False
2932
3916
  )
2933
- if not ret:
2934
- raise RuntimeError("adj_matmul failed.")
2935
-
2936
3917
 
2937
3918
  class HashGrid:
2938
3919
  def __init__(self, dim_x, dim_y, dim_z, device=None):
@@ -3001,7 +3982,7 @@ class HashGrid:
3001
3982
  with self.device.context_guard:
3002
3983
  runtime.core.hash_grid_destroy_device(self.id)
3003
3984
 
3004
- except:
3985
+ except Exception:
3005
3986
  pass
3006
3987
 
3007
3988
 
@@ -3075,7 +4056,7 @@ class MarchingCubes:
3075
4056
 
3076
4057
  if error:
3077
4058
  raise RuntimeError(
3078
- "Error occured buffers may not be large enough, marching cubes required at least {num_verts} vertices, and {num_tris} triangles."
4059
+ "Buffers may not be large enough, marching cubes required at least {num_verts} vertices, and {num_tris} triangles."
3079
4060
  )
3080
4061
 
3081
4062
  # resize the geometry arrays
@@ -3131,7 +4112,7 @@ def type_matches_template(arg_type, template_type):
3131
4112
  return True
3132
4113
  elif is_array(template_type):
3133
4114
  # ensure the argument type is a non-generic array with matching dtype and dimensionality
3134
- if type(arg_type) != type(template_type):
4115
+ if type(arg_type) is not type(template_type):
3135
4116
  return False
3136
4117
  if not type_matches_template(arg_type.dtype, template_type.dtype):
3137
4118
  return False
@@ -3160,9 +4141,53 @@ def type_matches_template(arg_type, template_type):
3160
4141
  return True
3161
4142
 
3162
4143
 
4144
def infer_argument_types(args, template_types, arg_names=None):
    """Work out a concrete type for every runtime argument.

    Each value in ``args`` is classified by its Python type:

    * array types listed in ``warp.types.array_types`` yield a new instance of
      the same array class built from the value's ``dtype`` and ``ndim``;
    * types listed in ``warp.types.scalar_types`` are used as-is;
    * plain ``int`` / ``float`` are mapped through ``warp.types.type_to_warp``;
    * types exposing ``_wp_scalar_type_`` (vectors/matrices) are used as-is;
    * struct instances resolve to their ``_cls`` attribute;
    * ``None`` is accepted only where the matching template type is an array,
      in which case the type is rebuilt from the template's ``dtype``/``ndim``.

    Args:
        args: The argument values to inspect.
        template_types: Template types, one per argument, in the same order.
        arg_names: Optional argument names used in error messages; when not
            given, the argument's position is reported instead.

    Returns:
        A list with one inferred type per argument, in argument order.

    Raises:
        RuntimeError: If ``args`` and ``template_types`` differ in length.
        TypeError: If the type of an argument cannot be inferred.
    """

    if len(args) != len(template_types):
        raise RuntimeError("Number of arguments must match number of template types.")

    resolved = []

    for index, value in enumerate(args):
        value_type = type(value)
        # label used in error messages: the given name, else the position
        label = arg_names[index] if arg_names else str(index)

        if value_type in warp.types.array_types:
            inferred = value_type(dtype=value.dtype, ndim=value.ndim)
        elif value_type in warp.types.scalar_types:
            inferred = value_type
        elif value_type in (int, float):
            # canonicalize Python builtins to their Warp counterparts
            inferred = warp.types.type_to_warp(value_type)
        elif hasattr(value_type, "_wp_scalar_type_"):
            # vector/matrix type
            inferred = value_type
        elif issubclass(value_type, warp.codegen.StructInstance):
            # a struct instance: report the struct it was created from
            inferred = value._cls
        elif value is None:
            # None may stand in for an array; fall back to the template type
            template = template_types[index]
            if not warp.types.is_array(template):
                raise TypeError(f"Unable to infer the type of argument '{label}', got None")
            inferred = type(template)(dtype=template.dtype, ndim=template.ndim)
        else:
            # TODO: attempt to figure out if it's a vector/matrix type given as a numpy array, list, etc.
            raise TypeError(f"Unable to infer the type of argument '{label}', got {value_type}")

        resolved.append(inferred)

    return resolved
4185
+
4186
+
3163
4187
  simple_type_codes = {
3164
4188
  int: "i4",
3165
4189
  float: "f4",
4190
+ builtins.bool: "b",
3166
4191
  bool: "b",
3167
4192
  str: "str", # accepted by print()
3168
4193
  int8: "i1",
@@ -3181,6 +4206,8 @@ simple_type_codes = {
3181
4206
  launch_bounds_t: "lb",
3182
4207
  hash_grid_query_t: "hgq",
3183
4208
  mesh_query_aabb_t: "mqa",
4209
+ mesh_query_point_t: "mqp",
4210
+ mesh_query_ray_t: "mqr",
3184
4211
  bvh_query_t: "bvhq",
3185
4212
  }
3186
4213
 
@@ -3197,14 +4224,14 @@ def get_type_code(arg_type):
3197
4224
  # check for "special" vector/matrix subtypes
3198
4225
  if hasattr(arg_type, "_wp_generic_type_str_"):
3199
4226
  type_str = arg_type._wp_generic_type_str_
3200
- if type_str == "quaternion":
4227
+ if type_str == "quat_t":
3201
4228
  return f"q{dtype_code}"
3202
4229
  elif type_str == "transform_t":
3203
4230
  return f"t{dtype_code}"
3204
- elif type_str == "spatial_vector_t":
3205
- return f"sv{dtype_code}"
3206
- elif type_str == "spatial_matrix_t":
3207
- return f"sm{dtype_code}"
4231
+ # elif type_str == "spatial_vector_t":
4232
+ # return f"sv{dtype_code}"
4233
+ # elif type_str == "spatial_matrix_t":
4234
+ # return f"sm{dtype_code}"
3208
4235
  # generic vector/matrix
3209
4236
  ndim = len(arg_type._shape_)
3210
4237
  if ndim == 1:
@@ -3227,6 +4254,10 @@ def get_type_code(arg_type):
3227
4254
  return f"a{arg_type.ndim}{get_type_code(arg_type.dtype)}"
3228
4255
  elif isinstance(arg_type, indexedarray):
3229
4256
  return f"ia{arg_type.ndim}{get_type_code(arg_type.dtype)}"
4257
+ elif isinstance(arg_type, fabricarray):
4258
+ return f"fa{arg_type.ndim}{get_type_code(arg_type.dtype)}"
4259
+ elif isinstance(arg_type, indexedfabricarray):
4260
+ return f"ifa{arg_type.ndim}{get_type_code(arg_type.dtype)}"
3230
4261
  elif isinstance(arg_type, warp.codegen.Struct):
3231
4262
  return warp.codegen.make_full_qualified_name(arg_type.cls)
3232
4263
  elif arg_type == Scalar: