warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang has been flagged as possibly problematic.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/types.py CHANGED
@@ -5,19 +5,17 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+from __future__ import annotations
+
+import builtins
 import ctypes
 import hashlib
+import inspect
 import struct
 import zlib
-import numpy as np
+from typing import Any, Callable, Generic, List, Tuple, TypeVar, Union
 
-from typing import Any
-from typing import Tuple
-from typing import TypeVar
-from typing import Generic
-from typing import List
-from typing import Callable
-from typing import Union
+import numpy as np
 
 import warp
 
@@ -54,12 +52,14 @@ def constant(x):
     global _constant_hash
 
     # hash the constant value
-    if isinstance(x, int):
+    if isinstance(x, builtins.bool):
+        # This needs to come before the check for `int` since all boolean
+        # values are also instances of `int`.
+        _constant_hash.update(struct.pack("?", x))
+    elif isinstance(x, int):
         _constant_hash.update(struct.pack("<q", x))
     elif isinstance(x, float):
         _constant_hash.update(struct.pack("<d", x))
-    elif isinstance(x, bool):
-        _constant_hash.update(struct.pack("?", x))
     elif isinstance(x, float16):
         # float16 is a special case
         p = ctypes.pointer(ctypes.c_float(x.value))
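The boolean branch now runs before the `int` branch because `True` and `False` are themselves `int` instances in Python, so a boolean constant hashes distinctly from an integer one. A minimal usage sketch (not part of the upstream diff; assumes Warp 0.11.0 installed as `wp`):

    import warp as wp

    # A bool constant now hashes via struct.pack("?", ...) rather than as a 64-bit int.
    ENABLE_CLAMP = wp.constant(True)

    @wp.kernel
    def clamp_kernel(values: wp.array(dtype=float)):
        tid = wp.tid()
        if ENABLE_CLAMP:
            values[tid] = wp.min(values[tid], 1.0)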
@@ -149,28 +149,74 @@ def vector(length, dtype):
 
         def __setitem__(self, key, value):
             if isinstance(key, int):
-                super().__setitem__(key, vec_t.scalar_import(value))
-                return value
+                try:
+                    return super().__setitem__(key, vec_t.scalar_import(value))
+                except (TypeError, ctypes.ArgumentError):
+                    raise TypeError(
+                        f"Expected to assign a `{self._wp_scalar_type_.__name__}` value "
+                        f"but got `{type(value).__name__}` instead"
+                    ) from None
             elif isinstance(key, slice):
+                try:
+                    iter(value)
+                except TypeError:
+                    raise TypeError(
+                        f"Expected to assign a slice from a sequence of values "
+                        f"but got `{type(value).__name__}` instead"
+                    ) from None
+
                 if self._wp_scalar_type_ == float16:
-                    super().__setitem__(key, [vec_t.scalar_import(x) for x in value])
-                    return value
-                else:
+                    converted = []
+                    try:
+                        for x in value:
+                            converted.append(vec_t.scalar_import(x))
+                    except ctypes.ArgumentError:
+                        raise TypeError(
+                            f"Expected to assign a slice from a sequence of `float16` values "
+                            f"but got `{type(x).__name__}` instead"
+                        ) from None
+
+                    value = converted
+
+                try:
                     return super().__setitem__(key, value)
+                except TypeError:
+                    for x in value:
+                        try:
+                            self._type_(x)
+                        except TypeError:
+                            raise TypeError(
+                                f"Expected to assign a slice from a sequence of `{self._wp_scalar_type_.__name__}` values "
+                                f"but got `{type(x).__name__}` instead"
+                            ) from None
             else:
                 raise KeyError(f"Invalid key {key}, expected int or slice")
 
+        def __getattr__(self, name):
+            idx = "xyzw".find(name)
+            if idx != -1:
+                return self.__getitem__(idx)
+
+            return self.__getattribute__(name)
+
+        def __setattr__(self, name, value):
+            idx = "xyzw".find(name)
+            if idx != -1:
+                return self.__setitem__(idx, value)
+
+            return super().__setattr__(name, value)
+
         def __add__(self, y):
             return warp.add(self, y)
 
         def __radd__(self, y):
-            return warp.add(self, y)
+            return warp.add(y, self)
 
         def __sub__(self, y):
             return warp.sub(self, y)
 
-        def __rsub__(self, x):
-            return warp.sub(x, self)
+        def __rsub__(self, y):
+            return warp.sub(y, self)
 
         def __mul__(self, y):
             return warp.mul(self, y)
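The new `__getattr__`/`__setattr__` hooks expose `x`/`y`/`z`/`w` component access on vector instances from Python scope, and the reflected operators now pass operands in the correct order. A short usage sketch (not part of the diff; assumes Warp 0.11.0 installed as `wp`):

    import warp as wp

    wp.init()

    v = wp.vec3(1.0, 2.0, 3.0)
    v.y = 5.0               # routed through __setattr__ -> __setitem__
    print(v.x, v.y, v.z)    # 1.0 5.0 3.0

    w = 2.0 * v             # reflected multiply, evaluated as mul(2.0, v)
    print(w)                # [2.0, 10.0, 6.0]

    try:
        v[0] = "oops"       # non-scalar assignment now raises a descriptive TypeError
    except TypeError as exc:
        print(exc)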
@@ -178,17 +224,17 @@ def vector(length, dtype):
         def __rmul__(self, x):
             return warp.mul(x, self)
 
-        def __div__(self, y):
+        def __truediv__(self, y):
             return warp.div(self, y)
 
-        def __rdiv__(self, x):
+        def __rtruediv__(self, x):
             return warp.div(x, self)
 
-        def __pos__(self, y):
-            return warp.pos(self, y)
+        def __pos__(self):
+            return warp.pos(self)
 
-        def __neg__(self, y):
-            return warp.neg(self, y)
+        def __neg__(self):
+            return warp.neg(self)
 
         def __str__(self):
             return f"[{', '.join(map(str, self))}]"
@@ -280,13 +326,13 @@ def matrix(shape, dtype):
             return warp.add(self, y)
 
         def __radd__(self, y):
-            return warp.add(self, y)
+            return warp.add(y, self)
 
         def __sub__(self, y):
             return warp.sub(self, y)
 
-        def __rsub__(self, x):
-            return warp.sub(x, self)
+        def __rsub__(self, y):
+            return warp.sub(y, self)
 
         def __mul__(self, y):
             return warp.mul(self, y)
@@ -300,17 +346,17 @@ def matrix(shape, dtype):
         def __rmatmul__(self, x):
             return warp.mul(x, self)
 
-        def __div__(self, y):
+        def __truediv__(self, y):
             return warp.div(self, y)
 
-        def __rdiv__(self, x):
+        def __rtruediv__(self, x):
             return warp.div(x, self)
 
-        def __pos__(self, y):
-            return warp.pos(self, y)
+        def __pos__(self):
+            return warp.pos(self)
 
-        def __neg__(self, y):
-            return warp.neg(self, y)
+        def __neg__(self):
+            return warp.neg(self)
 
         def __str__(self):
             row_str = []
@@ -341,10 +387,28 @@ def matrix(shape, dtype):
         def set_row(self, r, v):
             if r < 0 or r >= self._shape_[0]:
                 raise IndexError("Invalid row index")
+            try:
+                iter(v)
+            except TypeError:
+                raise TypeError(
+                    f"Expected to assign a slice from a sequence of values "
+                    f"but got `{type(v).__name__}` instead"
+                ) from None
+
             row_start = r * self._shape_[1]
             row_end = row_start + self._shape_[1]
             if self._wp_scalar_type_ == float16:
-                v = [mat_t.scalar_import(x) for x in v]
+                converted = []
+                try:
+                    for x in v:
+                        converted.append(mat_t.scalar_import(x))
+                except ctypes.ArgumentError:
+                    raise TypeError(
+                        f"Expected to assign a slice from a sequence of `float16` values "
+                        f"but got `{type(x).__name__}` instead"
+                    ) from None
+
+                v = converted
             super().__setitem__(slice(row_start, row_end), v)
 
         def __getitem__(self, key):
@@ -352,6 +416,8 @@ def matrix(shape, dtype):
                 # element indexing m[i,j]
                 if len(key) != 2:
                     raise KeyError(f"Invalid key, expected one or two indices, got {len(key)}")
+                if any(isinstance(x, slice) for x in key):
+                    raise KeyError(f"Slices are not supported when indexing matrices using the `m[i, j]` notation")
                 return mat_t.scalar_export(super().__getitem__(key[0] * self._shape_[1] + key[1]))
             elif isinstance(key, int):
                 # row vector indexing m[r]
@@ -364,12 +430,20 @@ def matrix(shape, dtype):
                 # element indexing m[i,j] = x
                 if len(key) != 2:
                     raise KeyError(f"Invalid key, expected one or two indices, got {len(key)}")
-                super().__setitem__(key[0] * self._shape_[1] + key[1], mat_t.scalar_import(value))
-                return value
+                if any(isinstance(x, slice) for x in key):
+                    raise KeyError(f"Slices are not supported when indexing matrices using the `m[i, j]` notation")
+                try:
+                    return super().__setitem__(key[0] * self._shape_[1] + key[1], mat_t.scalar_import(value))
+                except (TypeError, ctypes.ArgumentError):
+                    raise TypeError(
+                        f"Expected to assign a `{self._wp_scalar_type_.__name__}` value "
+                        f"but got `{type(value).__name__}` instead"
+                    ) from None
             elif isinstance(key, int):
                 # row vector indexing m[r] = v
-                self.set_row(key, value)
-                return value
+                return self.set_row(key, value)
+            elif isinstance(key, slice):
+                raise KeyError(f"Slices are not supported when indexing matrices using the `m[start:end]` notation")
             else:
                 raise KeyError(f"Invalid key {key}, expected int or pair of ints")
 
@@ -392,6 +466,23 @@ class void:
     pass
 
 
+class bool:
+    _length_ = 1
+    _type_ = ctypes.c_bool
+
+    def __init__(self, x=False):
+        self.value = x
+
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value != 0)
+
+    def __int__(self) -> int:
+        return int(self.value != 0)
+
+
 class float16:
     _length_ = 1
     _type_ = ctypes.c_uint16
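The new `bool` scalar type wraps `ctypes.c_bool` and, further down in this diff, is mapped to `np.bool_`, so boolean data can round-trip between Warp arrays and NumPy. A small sketch (not part of the diff; assumes the new type is re-exported as `wp.bool` in Warp 0.11.0):

    import warp as wp

    wp.init()

    flags = wp.array([True, False, True], dtype=wp.bool, device="cpu")
    print(flags.numpy())  # array([ True, False,  True])

    # wp.bool values coerce cleanly to Python types
    print(bool(wp.bool(True)), int(wp.bool(True)), float(wp.bool(False)))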
@@ -399,6 +490,15 @@ class float16:
     def __init__(self, x=0.0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0.0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
 
 class float32:
     _length_ = 1
@@ -407,6 +507,15 @@ class float32:
     def __init__(self, x=0.0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0.0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
 
 class float64:
     _length_ = 1
@@ -415,6 +524,15 @@ class float64:
     def __init__(self, x=0.0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0.0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
 
 class int8:
     _length_ = 1
@@ -423,6 +541,18 @@ class int8:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class uint8:
     _length_ = 1
@@ -431,6 +561,18 @@ class uint8:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class int16:
     _length_ = 1
@@ -439,6 +581,18 @@ class int16:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class uint16:
     _length_ = 1
@@ -447,6 +601,18 @@ class uint16:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class int32:
     _length_ = 1
@@ -455,6 +621,18 @@ class int32:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class uint32:
     _length_ = 1
@@ -463,6 +641,18 @@ class uint32:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class int64:
     _length_ = 1
@@ -471,6 +661,18 @@ class int64:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 class uint64:
     _length_ = 1
@@ -479,6 +681,18 @@ class uint64:
     def __init__(self, x=0):
         self.value = x
 
+    def __bool__(self) -> bool:
+        return self.value != 0
+
+    def __float__(self) -> float:
+        return float(self.value)
+
+    def __int__(self) -> int:
+        return int(self.value)
+
+    def __index__(self) -> int:
+        return int(self.value)
+
 
 def quaternion(dtype=Any):
     class quat_t(vector(length=4, dtype=dtype)):
@@ -508,23 +722,63 @@ class quatd(quaternion(dtype=float64)):
 
 def transformation(dtype=Any):
     class transform_t(vector(length=7, dtype=dtype)):
+        _wp_init_from_components_sig_ = inspect.Signature(
+            (
+                inspect.Parameter(
+                    "p",
+                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                    default=(0.0, 0.0, 0.0),
+                ),
+                inspect.Parameter(
+                    "q",
+                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                    default=(0.0, 0.0, 0.0, 1.0),
+                ),
+            ),
+        )
         _wp_type_params_ = [dtype]
         _wp_generic_type_str_ = "transform_t"
         _wp_constructor_ = "transformation"
 
-        def __init__(self, p=(0.0, 0.0, 0.0), q=(0.0, 0.0, 0.0, 1.0)):
-            super().__init__()
+        def __init__(self, *args, **kwargs):
+            if len(args) == 1 and len(kwargs) == 0:
+                if getattr(args[0], "_wp_generic_type_str_") == self._wp_generic_type_str_:
+                    # Copy constructor.
+                    super().__init__(*args[0])
+                    return
+
+            try:
+                # For backward compatibility, try to check if the arguments
+                # match the original signature that'd allow initializing
+                # the `p` and `q` components separately.
+                bound_args = self._wp_init_from_components_sig_.bind(*args, **kwargs)
+                bound_args.apply_defaults()
+                p, q = bound_args.args
+            except (TypeError, ValueError):
+                # Fallback to the vector's constructor.
+                super().__init__(*args)
+                return
+
+            # Even if the arguments match the original "from components"
+            # signature, we still need to make sure that they represent
+            # sequences that can be unpacked.
+            if hasattr(p, "__len__") and hasattr(q, "__len__"):
+                # Initialize from the `p` and `q` components.
+                super().__init__()
+                self[0:3] = vector(length=3, dtype=dtype)(*p)
+                self[3:7] = quaternion(dtype=dtype)(*q)
+                return
 
-            self[0:3] = vector(length=3, dtype=dtype)(*p)
-            self[3:7] = quaternion(dtype=dtype)(*q)
+            # Fallback to the vector's constructor.
+            super().__init__(*args)
 
         @property
         def p(self):
-            return self[0:3]
+            return vec3(self[0:3])
 
         @property
         def q(self):
-            return self[3:7]
+            return quat(self[3:7])
 
     return transform_t
 
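The reworked `transform_t.__init__` keeps the original `(p, q)` signature (including keyword form), adds a copy constructor, and otherwise falls back to the 7-component vector constructor, while `p`/`q` now return typed values. A usage sketch (not part of the diff; assumes Warp 0.11.0 with the usual `wp.transform` alias):

    import warp as wp

    t1 = wp.transform(p=(1.0, 2.0, 3.0), q=(0.0, 0.0, 0.0, 1.0))  # original component form
    t2 = wp.transform(t1)                                          # copy constructor
    t3 = wp.transform(1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0)           # flat 7-value form

    print(t1.p)  # returned as wp.vec3 instead of a raw slice
    print(t1.q)  # returned as wp.quat instead of a raw slice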
@@ -808,6 +1062,7 @@ vector_types = [
 ]
 
 np_dtype_to_warp_type = {
+    np.dtype(np.bool_): bool,
     np.dtype(np.int8): int8,
     np.dtype(np.uint8): uint8,
     np.dtype(np.int16): int16,
@@ -824,6 +1079,7 @@ np_dtype_to_warp_type = {
 }
 
 warp_type_to_np_dtype = {
+    bool: np.bool_,
     int8: np.int8,
     int16: np.int16,
     int32: np.int32,
@@ -846,18 +1102,21 @@ class range_t:
 
 # definition just for kernel type (cannot be a parameter), see bvh.h
 class bvh_query_t:
+    """Object used to track state during BVH traversal."""
     def __init__(self):
         pass
 
 
 # definition just for kernel type (cannot be a parameter), see mesh.h
 class mesh_query_aabb_t:
+    """Object used to track state during mesh traversal."""
    def __init__(self):
        pass
 
 
 # definition just for kernel type (cannot be a parameter), see hash_grid.h
 class hash_grid_query_t:
+    """Object used to track state during neighbor traversal."""
    def __init__(self):
        pass
 
@@ -869,6 +1128,8 @@ LAUNCH_MAX_DIMS = 4
 # must match array.h
 ARRAY_TYPE_REGULAR = 0
 ARRAY_TYPE_INDEXED = 1
+ARRAY_TYPE_FABRIC = 2
+ARRAY_TYPE_FABRIC_INDEXED = 3
 
 
 # represents bounds for kernel launch (number of threads across multiple dimensions)
@@ -992,7 +1253,7 @@ def type_scalar_type(dtype):
 def type_size_in_bytes(dtype):
     if dtype.__module__ == "ctypes":
         return ctypes.sizeof(dtype)
-    elif type_is_struct(dtype):
+    elif isinstance(dtype, warp.codegen.Struct):
         return ctypes.sizeof(dtype.ctype)
     elif dtype == float or dtype == int:
         return 4
@@ -1013,9 +1274,9 @@ def type_to_warp(dtype):
 
 
 def type_typestr(dtype):
-    from warp.codegen import Struct
-
-    if dtype == float16:
+    if dtype == bool:
+        return "?"
+    elif dtype == float16:
         return "<f2"
     elif dtype == float32:
         return "<f4"
@@ -1037,7 +1298,7 @@ def type_typestr(dtype):
         return "<i8"
     elif dtype == uint64:
         return "<u8"
-    elif isinstance(dtype, Struct):
+    elif isinstance(dtype, warp.codegen.Struct):
         return f"|V{ctypes.sizeof(dtype.ctype)}"
     elif issubclass(dtype, ctypes.Array):
         return type_typestr(dtype._wp_scalar_type_)
@@ -1051,9 +1312,16 @@ def type_repr(t):
         return str(f"array(ndim={t.ndim}, dtype={t.dtype})")
     if type_is_vector(t):
         return str(f"vector(length={t._shape_[0]}, dtype={t._wp_scalar_type_})")
-    elif type_is_matrix(t):
+    if type_is_matrix(t):
         return str(f"matrix(shape=({t._shape_[0]}, {t._shape_[1]}), dtype={t._wp_scalar_type_})")
-    else:
+    if isinstance(t, warp.codegen.Struct):
+        return type_repr(t.cls)
+    if t in scalar_types:
+        return t.__name__
+
+    try:
+        return t.__module__ + "." + t.__qualname__
+    except AttributeError:
         return str(t)
 
 
@@ -1071,15 +1339,6 @@ def type_is_float(t):
     return t in float_types
 
 
-def type_is_struct(dtype):
-    from warp.codegen import Struct
-
-    if isinstance(dtype, Struct):
-        return True
-    else:
-        return False
-
-
 # returns True if the passed *type* is a vector
 def type_is_vector(t):
     if hasattr(t, "_wp_generic_type_str_") and t._wp_generic_type_str_ == "vec_t":
@@ -1098,7 +1357,7 @@ def type_is_matrix(t):
 
 # returns true for all value types (int, float, bool, scalars, vectors, matrices)
 def type_is_value(x):
-    if (x == int) or (x == float) or (x == bool) or (x in scalar_types) or issubclass(x, ctypes.Array):
+    if (x == int) or (x == float) or (x == builtins.bool) or (x in scalar_types) or issubclass(x, ctypes.Array):
         return True
     else:
         return False
@@ -1126,14 +1385,16 @@ def types_equal(a, b, match_generic=False):
     # convert to canonical types
     if a == float:
         a = float32
-    if a == int:
+    elif a == int:
         a = int32
 
     if b == float:
         b = float32
-    if b == int:
+    elif b == int:
         b = int32
 
+    compatible_bool_types = [builtins.bool, bool]
+
     def are_equal(p1, p2):
         if match_generic:
             if p1 == Any or p2 == Any:
@@ -1150,7 +1411,22 @@ def types_equal(a, b, match_generic=False):
                 return True
             if p1 == Float and p2 == Float:
                 return True
-        return p1 == p2
+
+        # convert to canonical types
+        if p1 == float:
+            p1 = float32
+        elif p1 == int:
+            p1 = int32
+
+        if p2 == float:
+            p2 = float32
+        elif b == int:
+            p2 = int32
+
+        if p1 in compatible_bool_types and p2 in compatible_bool_types:
+            return True
+        else:
+            return p1 == p2
 
     if (
         hasattr(a, "_wp_generic_type_str_")
@@ -1158,9 +1434,7 @@ def types_equal(a, b, match_generic=False):
         and a._wp_generic_type_str_ == b._wp_generic_type_str_
     ):
         return all([are_equal(p1, p2) for p1, p2 in zip(a._wp_type_params_, b._wp_type_params_)])
-    if isinstance(a, array) and isinstance(b, array):
-        return True
-    if isinstance(a, indexedarray) and isinstance(b, indexedarray):
+    if is_array(a) and type(a) is type(b):
         return True
     else:
         return are_equal(a, b)
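With `compatible_bool_types`, Python's built-in `bool` and the new `warp.bool` are treated as interchangeable during type matching, while array types now only match like-for-like. A small illustration (not part of the diff; assumes these helpers remain importable from `warp.types`):

    import warp as wp
    from warp.types import types_equal

    print(types_equal(wp.bool, bool))         # True: builtin bool matches wp.bool
    print(types_equal(wp.float32, float))     # True: float canonicalizes to float32
    print(types_equal(wp.int32, wp.float32))  # False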
@@ -1244,6 +1518,7 @@ class array(Array):
         self._grad = None
         # __array_interface__ or __cuda_array_interface__, evaluated lazily and cached
         self._array_interface = None
+        self.is_transposed = False
 
         # canonicalize dtype
         if dtype == int:
@@ -1317,7 +1592,9 @@ class array(Array):
             if isinstance(data, np.ndarray):
                 # construct from numpy structured array
                 if data.dtype != dtype.numpy_dtype():
-                    raise RuntimeError(f"Invalid source data type for array of structs, expected {dtype.numpy_dtype()}, got {data.dtype}")
+                    raise RuntimeError(
+                        f"Invalid source data type for array of structs, expected {dtype.numpy_dtype()}, got {data.dtype}"
+                    )
                 arr = data
             elif isinstance(data, (list, tuple)):
                 # construct from a sequence of structs
@@ -1329,9 +1606,13 @@ class array(Array):
                     # convert to numpy
                     arr = np.frombuffer(ctype_arr, dtype=dtype.ctype)
                 except Exception as e:
-                    raise RuntimeError(f"Error while trying to construct Warp array from a sequence of Warp structs: {e}")
+                    raise RuntimeError(
+                        f"Error while trying to construct Warp array from a sequence of Warp structs: {e}"
+                    )
             else:
-                raise RuntimeError(f"Invalid data argument for array of structs, expected a sequence of structs or a NumPy structured array")
+                raise RuntimeError(
+                    "Invalid data argument for array of structs, expected a sequence of structs or a NumPy structured array"
+                )
         else:
             # convert input data to the given dtype
             npdtype = warp_type_to_np_dtype.get(scalar_dtype)
@@ -1416,7 +1697,7 @@ class array(Array):
 
     def _init_from_ptr(self, ptr, dtype, shape, strides, capacity, device, owner, pinned):
         if dtype == Any:
-            raise RuntimeError(f"A concrete data type is required to create the array")
+            raise RuntimeError("A concrete data type is required to create the array")
 
         device = warp.get_device(device)
 
@@ -1450,7 +1731,7 @@ class array(Array):
 
     def _init_new(self, dtype, shape, strides, device, pinned):
         if dtype == Any:
-            raise RuntimeError(f"A concrete data type is required to create the array")
+            raise RuntimeError("A concrete data type is required to create the array")
 
         device = warp.get_device(device)
 
@@ -1753,7 +2034,7 @@ class array(Array):
         return self._requires_grad
 
     @requires_grad.setter
-    def requires_grad(self, value: bool):
+    def requires_grad(self, value: builtins.bool):
         if value and self._grad is None:
             self._alloc_grad()
         elif not value:
@@ -1778,12 +2059,11 @@ class array(Array):
         # member attributes available during code-gen (e.g.: d = array.shape[0])
         # Note: we use a shared dict for all array instances
        if array._vars is None:
-            from warp.codegen import Var
-
-            array._vars = {"shape": Var("shape", shape_t)}
+            array._vars = {"shape": warp.codegen.Var("shape", shape_t)}
        return array._vars
 
    def zero_(self):
+        """Zeroes-out the array entries."""
        if self.is_contiguous:
            # simple memset is usually faster than generic fill
            self.device.memset(self.ptr, 0, self.size * type_size_in_bytes(self.dtype))
@@ -1791,6 +2071,32 @@ class array(Array):
            self.fill_(0)
 
    def fill_(self, value):
+        """Set all array entries to `value`
+
+        args:
+            value: The value to set every array entry to. Must be convertible to the array's ``dtype``.
+
+        Raises:
+            ValueError: If `value` cannot be converted to the array's ``dtype``.
+
+        Examples:
+            ``fill_()`` can take lists or other sequences when filling arrays of vectors or matrices.
+
+            >>> arr = wp.zeros(2, dtype=wp.mat22)
+            >>> arr.numpy()
+            array([[[0., 0.],
+                    [0., 0.]],
+            <BLANKLINE>
+                   [[0., 0.],
+                    [0., 0.]]], dtype=float32)
+            >>> arr.fill_([[1, 2], [3, 4]])
+            >>> arr.numpy()
+            array([[[1., 2.],
+                    [3., 4.]],
+            <BLANKLINE>
+                   [[1., 2.],
+                    [3., 4.]]], dtype=float32)
+        """
        if self.size == 0:
            return
 
@@ -1837,19 +2143,22 @@ class array(Array):
        else:
            warp.context.runtime.core.array_fill_host(carr_ptr, ARRAY_TYPE_REGULAR, cvalue_ptr, cvalue_size)
 
-    # equivalent to wrapping src data in an array and copying to self
    def assign(self, src):
+        """Wraps ``src`` in an :class:`warp.array` if it is not already one and copies the contents to ``self``."""
        if is_array(src):
            warp.copy(self, src)
        else:
            warp.copy(self, array(data=src, dtype=self.dtype, copy=False, device="cpu"))
 
-    # convert array to ndarray (alias memory through array interface)
    def numpy(self):
+        """Converts the array to a :class:`numpy.ndarray` (aliasing memory through the array interface protocol)
+        If the array is on the GPU, a synchronous device-to-host copy (on the CUDA default stream) will be
+        automatically performed to ensure that any outstanding work is completed.
+        """
        if self.ptr:
            # use the CUDA default stream for synchronous behaviour with other streams
            with warp.ScopedStream(self.device.null_stream):
-                a = self.to("cpu")
+                a = self.to("cpu", requires_grad=False)
                # convert through __array_interface__
                # Note: this handles arrays of structs using `descr`, so the result will be a structured NumPy array
                return np.array(a, copy=False)
@@ -1866,12 +2175,16 @@ class array(Array):
            npshape = self.shape
            return np.empty(npshape, dtype=npdtype)
 
-    # return a ctypes cast of the array address
-    # note #1: only CPU arrays support this method
-    # note #2: the array must be contiguous
-    # note #3: accesses to this object are *not* bounds checked
-    # note #4: for float16 types, a pointer to the internal uint16 representation is returned
    def cptr(self):
+        """Return a ctypes cast of the array address.
+
+        Notes:
+
+        #. Only CPU arrays support this method.
+        #. The array must be contiguous.
+        #. Accesses to this object are **not** bounds checked.
+        #. For ``float16`` types, a pointer to the internal ``uint16`` representation is returned.
+        """
        if not self.ptr:
            return None
 
@@ -1890,8 +2203,8 @@ class array(Array):
 
        return p
 
-    # returns a flattened list of items in the array as a Python list
    def list(self):
+        """Returns a flattened list of items in the array as a Python list."""
        a = self.numpy()
 
        if isinstance(self.dtype, warp.codegen.Struct):
@@ -1910,15 +2223,16 @@ class array(Array):
                # scalar
                return list(a.flatten())
 
-    # convert data from one device to another, nop if already on device
-    def to(self, device):
+    def to(self, device, requires_grad=None):
+        """Returns a Warp array with this array's data moved to the specified device, no-op if already on device."""
        device = warp.get_device(device)
        if self.device == device:
            return self
        else:
-            return warp.clone(self, device=device)
+            return warp.clone(self, device=device, requires_grad=requires_grad)
 
    def flatten(self):
+        """Returns a zero-copy view of the array collapsed to 1-D. Only supported for contiguous arrays."""
        if self.ndim == 1:
            return self
 
@@ -1941,6 +2255,11 @@ class array(Array):
         return a
 
     def reshape(self, shape):
+        """Returns a reshaped array. Only supported for contiguous arrays.
+
+        Args:
+            shape : An int or tuple of ints specifying the shape of the returned array.
+        """
         if not self.is_contiguous:
             raise RuntimeError("Reshaping non-contiguous arrays is unsupported.")
 
@@ -1998,6 +2317,9 @@ class array(Array):
         return a
 
     def view(self, dtype):
+        """Returns a zero-copy view of this array's memory with a different data type.
+        ``dtype`` must have the same byte size of the array's native ``dtype``.
+        """
         if type_size_in_bytes(dtype) != type_size_in_bytes(self.dtype):
             raise RuntimeError("Cannot cast dtypes of unequal byte size")
 
@@ -2018,6 +2340,7 @@ class array(Array):
         return a
 
     def contiguous(self):
+        """Returns a contiguous array with this array's data. No-op if array is already contiguous."""
         if self.is_contiguous:
             return self
 
@@ -2025,8 +2348,14 @@ class array(Array):
         warp.copy(a, self)
         return a
 
-    # note: transpose operation will return an array with a non-contiguous access pattern
     def transpose(self, axes=None):
+        """Returns an zero-copy view of the array with axes transposed.
+
+        Note: The transpose operation will return an array with a non-contiguous access pattern.
+
+        Args:
+            axes (optional): Specifies the how the axes are permuted. If not specified, the axes order will be reversed.
+        """
         # noop if 1d array
         if self.ndim == 1:
             return self
@@ -2059,6 +2388,8 @@ class array(Array):
             grad=None if self.grad is None else self.grad.transpose(axes=axes),
         )
 
+        a.is_transposed = not self.is_transposed
+
         a._ref = self
         return a
 
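Because `transpose()` returns a strided (non-contiguous) view, the new `is_transposed` flag travels with the view and a dense copy still goes through `contiguous()`. A short sketch (not part of the diff; assumes a CPU device):

    import warp as wp

    wp.init()

    a = wp.zeros((2, 3), dtype=float, device="cpu")
    at = a.transpose()

    print(at.shape)             # (3, 2)
    print(at.is_contiguous)     # False: strided view over the same memory
    print(at.is_transposed)     # True, new in 0.11.0

    dense = at.contiguous()     # materialize a contiguous copy
    print(dense.is_contiguous)  # True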
@@ -2093,7 +2424,7 @@ def from_ptr(ptr, length, dtype=None, shape=None, device=None):
         dtype=dtype,
         length=length,
         capacity=length * type_size_in_bytes(dtype),
-        ptr=ctypes.cast(ptr, ctypes.POINTER(ctypes.c_size_t)).contents.value,
+        ptr=0 if ptr == 0 else ctypes.cast(ptr, ctypes.POINTER(ctypes.c_size_t)).contents.value,
         shape=shape,
         device=device,
         owner=False,
@@ -2101,12 +2432,113 @@ def from_ptr(ptr, length, dtype=None, shape=None, device=None):
     )
 
 
-class indexedarray(Generic[T]):
+# A base class for non-contiguous arrays, providing the implementation of common methods like
+# contiguous(), to(), numpy(), list(), assign(), zero_(), and fill_().
+class noncontiguous_array_base(Generic[T]):
+    def __init__(self, array_type_id):
+        self.type_id = array_type_id
+        self.is_contiguous = False
+
+    # return a contiguous copy
+    def contiguous(self):
+        a = warp.empty_like(self)
+        warp.copy(a, self)
+        return a
+
+    # copy data from one device to another, nop if already on device
+    def to(self, device):
+        device = warp.get_device(device)
+        if self.device == device:
+            return self
+        else:
+            return warp.clone(self, device=device)
+
+    # return a contiguous numpy copy
+    def numpy(self):
+        # use the CUDA default stream for synchronous behaviour with other streams
+        with warp.ScopedStream(self.device.null_stream):
+            return self.contiguous().numpy()
+
+    # returns a flattened list of items in the array as a Python list
+    def list(self):
+        # use the CUDA default stream for synchronous behaviour with other streams
+        with warp.ScopedStream(self.device.null_stream):
+            return self.contiguous().list()
+
+    # equivalent to wrapping src data in an array and copying to self
+    def assign(self, src):
+        if is_array(src):
+            warp.copy(self, src)
+        else:
+            warp.copy(self, array(data=src, dtype=self.dtype, copy=False, device="cpu"))
+
+    def zero_(self):
+        self.fill_(0)
+
+    def fill_(self, value):
+        if self.size == 0:
+            return
+
+        # try to convert the given value to the array dtype
+        try:
+            if isinstance(self.dtype, warp.codegen.Struct):
+                if isinstance(value, self.dtype.cls):
+                    cvalue = value.__ctype__()
+                elif value == 0:
+                    # allow zero-initializing structs using default constructor
+                    cvalue = self.dtype().__ctype__()
+                else:
+                    raise ValueError(
+                        f"Invalid initializer value for struct {self.dtype.cls.__name__}, expected struct instance or 0"
+                    )
+            elif issubclass(self.dtype, ctypes.Array):
+                # vector/matrix
+                cvalue = self.dtype(value)
+            else:
+                # scalar
+                if type(value) in warp.types.scalar_types:
+                    value = value.value
+                if self.dtype == float16:
+                    cvalue = self.dtype._type_(float_to_half_bits(value))
+                else:
+                    cvalue = self.dtype._type_(value)
+        except Exception as e:
+            raise ValueError(f"Failed to convert the value to the array data type: {e}")
+
+        cvalue_ptr = ctypes.pointer(cvalue)
+        cvalue_size = ctypes.sizeof(cvalue)
+
+        ctype = self.__ctype__()
+        ctype_ptr = ctypes.pointer(ctype)
+
+        if self.device.is_cuda:
+            warp.context.runtime.core.array_fill_device(
+                self.device.context, ctype_ptr, self.type_id, cvalue_ptr, cvalue_size
+            )
+        else:
+            warp.context.runtime.core.array_fill_host(ctype_ptr, self.type_id, cvalue_ptr, cvalue_size)
+
+
+# helper to check index array properties
+def check_index_array(indices, expected_device):
+    if not isinstance(indices, array):
+        raise ValueError(f"Indices must be a Warp array, got {type(indices)}")
+    if indices.ndim != 1:
+        raise ValueError(f"Index array must be one-dimensional, got {indices.ndim}")
+    if indices.dtype != int32:
+        raise ValueError(f"Index array must use int32, got dtype {indices.dtype}")
+    if indices.device != expected_device:
+        raise ValueError(f"Index array device ({indices.device} does not match data array device ({expected_device}))")
+
+
+class indexedarray(noncontiguous_array_base[T]):
     # member attributes available during code-gen (e.g.: d = arr.shape[0])
     # (initialized when needed)
     _vars = None
 
     def __init__(self, data: array = None, indices: Union[array, List[array]] = None, dtype=None, ndim=None):
+        super().__init__(ARRAY_TYPE_INDEXED)
+
         # canonicalize types
         if dtype is not None:
             if dtype == int:
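The shared `noncontiguous_array_base` gives `indexedarray` (and the new fabric array types) one implementation of `contiguous()`, `numpy()`, `assign()`, `zero_()`, and `fill_()`. A minimal `indexedarray` sketch (not part of the diff; assumes a CPU device):

    import warp as wp

    wp.init()

    data = wp.array([10.0, 20.0, 30.0, 40.0], dtype=float, device="cpu")
    idx = wp.array([0, 2], dtype=wp.int32, device="cpu")

    view = wp.indexedarray(data, [idx])  # gathers elements 0 and 2
    print(view.numpy())                  # [10. 30.]

    view.fill_(-1.0)                     # writes through to the underlying storage
    print(data.numpy())                  # [-1. 20. -1. 40.]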
@@ -2136,17 +2568,6 @@ class indexedarray(Generic[T]):
             shape = list(data.shape)
 
         if indices is not None:
-            # helper to check index array properties
-            def check_index_array(inds, data):
-                if inds.ndim != 1:
-                    raise ValueError(f"Index array must be one-dimensional, got {inds.ndim}")
-                if inds.dtype != int32:
-                    raise ValueError(f"Index array must use int32, got dtype {inds.dtype}")
-                if inds.device != data.device:
-                    raise ValueError(
-                        f"Index array device ({inds.device} does not match data array device ({data.device}))"
-                    )
-
             if isinstance(indices, (list, tuple)):
                 if len(indices) > self.ndim:
                     raise ValueError(
@@ -2154,16 +2575,14 @@ class indexedarray(Generic[T]):
                     )
 
                 for i in range(len(indices)):
-                    if isinstance(indices[i], array):
-                        check_index_array(indices[i], data)
+                    if indices[i] is not None:
+                        check_index_array(indices[i], data.device)
                         self.indices[i] = indices[i]
                         shape[i] = len(indices[i])
-                    elif indices[i] is not None:
-                        raise TypeError(f"Invalid index array type: {type(indices[i])}")
 
             elif isinstance(indices, array):
                 # only a single index array was provided
-                check_index_array(indices, data)
+                check_index_array(indices, data.device)
                 self.indices[0] = indices
                 shape[0] = len(indices)
 
@@ -2185,8 +2604,6 @@ class indexedarray(Generic[T]):
         for d in self.shape:
             self.size *= d
 
-        self.is_contiguous = False
-
     def __len__(self):
         return self.shape[0]
 
@@ -2206,89 +2623,9 @@ class indexedarray(Generic[T]):
         # member attributes available during code-gen (e.g.: d = arr.shape[0])
         # Note: we use a shared dict for all indexedarray instances
         if indexedarray._vars is None:
-            from warp.codegen import Var
-
-            indexedarray._vars = {"shape": Var("shape", shape_t)}
+            indexedarray._vars = {"shape": warp.codegen.Var("shape", shape_t)}
         return indexedarray._vars
 
-    def contiguous(self):
-        a = warp.empty_like(self)
-        warp.copy(a, self)
-        return a
-
-    # convert data from one device to another, nop if already on device
-    def to(self, device):
-        device = warp.get_device(device)
-        if self.device == device:
-            return self
-        else:
-            return warp.clone(self, device=device)
-
-    # return a contiguous numpy copy
-    def numpy(self):
-        # use the CUDA default stream for synchronous behaviour with other streams
-        with warp.ScopedStream(self.device.null_stream):
-            return self.contiguous().numpy()
-
-    # returns a flattened list of items in the array as a Python list
-    def list(self):
-        # use the CUDA default stream for synchronous behaviour with other streams
-        with warp.ScopedStream(self.device.null_stream):
-            return self.contiguous().list()
-
-    def zero_(self):
-        self.fill_(0)
-
-    def fill_(self, value):
-        if self.size == 0:
-            return
-
-        # try to convert the given value to the array dtype
-        try:
-            if isinstance(self.dtype, warp.codegen.Struct):
-                if isinstance(value, self.dtype.cls):
-                    cvalue = value.__ctype__()
-                elif value == 0:
-                    # allow zero-initializing structs using default constructor
-                    cvalue = self.dtype().__ctype__()
-                else:
-                    raise ValueError(
-                        f"Invalid initializer value for struct {self.dtype.cls.__name__}, expected struct instance or 0"
-                    )
-            elif issubclass(self.dtype, ctypes.Array):
-                # vector/matrix
-                cvalue = self.dtype(value)
-            else:
-                # scalar
-                if type(value) in warp.types.scalar_types:
-                    value = value.value
-                if self.dtype == float16:
-                    cvalue = self.dtype._type_(float_to_half_bits(value))
-                else:
-                    cvalue = self.dtype._type_(value)
-        except Exception as e:
-            raise ValueError(f"Failed to convert the value to the array data type: {e}")
-
-        cvalue_ptr = ctypes.pointer(cvalue)
-        cvalue_size = ctypes.sizeof(cvalue)
-
-        ctype = self.__ctype__()
-        ctype_ptr = ctypes.pointer(ctype)
-
-        if self.device.is_cuda:
-            warp.context.runtime.core.array_fill_device(
-                self.device.context, ctype_ptr, ARRAY_TYPE_INDEXED, cvalue_ptr, cvalue_size
-            )
-        else:
-            warp.context.runtime.core.array_fill_host(ctype_ptr, ARRAY_TYPE_INDEXED, cvalue_ptr, cvalue_size)
-
-    # equivalent to wrapping src data in an array and copying to self
-    def assign(self, src):
-        if is_array(src):
-            warp.copy(self, src)
-        else:
-            warp.copy(self, array(data=src, dtype=self.dtype, copy=False, device="cpu"))
-
 
 # aliases for indexedarrays with small dimensions
 def indexedarray1d(*args, **kwargs):
@@ -2314,16 +2651,22 @@ def indexedarray4d(*args, **kwargs):
2314
2651
  return indexedarray(*args, **kwargs)
2315
2652
 
2316
2653
 
2317
- array_types = (array, indexedarray)
2654
+ from warp.fabric import fabricarray, indexedfabricarray # noqa: E402
2655
+
2656
+ array_types = (array, indexedarray, fabricarray, indexedfabricarray)
2318
2657
 
2319
2658
 
2320
2659
  def array_type_id(a):
2321
- if isinstance(a, warp.array):
2322
- return warp.types.ARRAY_TYPE_REGULAR
2323
- elif isinstance(a, warp.indexedarray):
2324
- return warp.types.ARRAY_TYPE_INDEXED
2660
+ if isinstance(a, array):
2661
+ return ARRAY_TYPE_REGULAR
2662
+ elif isinstance(a, indexedarray):
2663
+ return ARRAY_TYPE_INDEXED
2664
+ elif isinstance(a, fabricarray):
2665
+ return ARRAY_TYPE_FABRIC
2666
+ elif isinstance(a, indexedfabricarray):
2667
+ return ARRAY_TYPE_FABRIC_INDEXED
2325
2668
  else:
2326
- raise ValueError(f"Invalid array")
2669
+ raise ValueError("Invalid array type")
2327
2670
 
2328
2671
 
2329
2672
  class Bvh:
@@ -2381,11 +2724,11 @@ class Bvh:
2381
2724
  with self.device.context_guard:
2382
2725
  runtime.core.bvh_destroy_device(self.id)
2383
2726
 
2384
- except:
2727
+ except Exception:
2385
2728
  pass
2386
2729
 
2387
2730
  def refit(self):
2388
- """Refit the Bvh. This should be called after users modify the `lowers` and `uppers` arrays."""
2731
+ """Refit the BVH. This should be called after users modify the `lowers` and `uppers` arrays."""
2389
2732
 
2390
2733
  from warp.context import runtime
2391
2734
 
@@ -2471,7 +2814,7 @@ class Mesh:
2471
2814
  # use CUDA context guard to avoid side effects during garbage collection
2472
2815
  with self.device.context_guard:
2473
2816
  runtime.core.mesh_destroy_device(self.id)
2474
- except:
2817
+ except Exception:
2475
2818
  pass
2476
2819
 
2477
2820
  def refit(self):
@@ -2487,16 +2830,14 @@ class Mesh:
2487
2830
 
2488
2831
 
2489
2832
  class Volume:
2833
+ #: Enum value to specify nearest-neighbor interpolation during sampling
2490
2834
  CLOSEST = constant(0)
2835
+ #: Enum value to specify trilinear interpolation during sampling
2491
2836
  LINEAR = constant(1)
2492
2837
 
2493
2838
  def __init__(self, data: array):
2494
2839
  """Class representing a sparse grid.
2495
2840
 
2496
- Attributes:
2497
- CLOSEST (int): Enum value to specify nearest-neighbor interpolation during sampling
2498
- LINEAR (int): Enum value to specify trilinear interpolation during sampling
2499
-
2500
2841
  Args:
2501
2842
  data (:class:`warp.array`): Array of bytes representing the volume in NanoVDB format
2502
2843
  """
@@ -2538,10 +2879,11 @@ class Volume:
2538
2879
  with self.device.context_guard:
2539
2880
  runtime.core.volume_destroy_device(self.id)
2540
2881
 
2541
- except:
2882
+ except Exception:
2542
2883
  pass
2543
2884
 
2544
- def array(self):
2885
+ def array(self) -> array:
2886
+ """Returns the raw memory buffer of the Volume as an array"""
2545
2887
  buf = ctypes.c_void_p(0)
2546
2888
  size = ctypes.c_uint64(0)
2547
2889
  if self.device.is_cpu:
@@ -2550,7 +2892,7 @@ class Volume:
2550
2892
  self.context.core.volume_get_buffer_info_device(self.id, ctypes.byref(buf), ctypes.byref(size))
2551
2893
  return array(ptr=buf.value, dtype=uint8, shape=size.value, device=self.device, owner=False)
2552
2894
 
2553
- def get_tiles(self):
2895
+ def get_tiles(self) -> array:
2554
2896
  if self.id == 0:
2555
2897
  raise RuntimeError("Invalid Volume")
2556
2898
 
@@ -2563,7 +2905,7 @@ class Volume:
2563
2905
  num_tiles = size.value // (3 * 4)
2564
2906
  return array(ptr=buf.value, dtype=int32, shape=(num_tiles, 3), device=self.device, owner=True)
2565
2907
 
2566
- def get_voxel_size(self):
2908
+ def get_voxel_size(self) -> Tuple[float, float, float]:
2567
2909
  if self.id == 0:
2568
2910
  raise RuntimeError("Invalid Volume")
2569
2911
 
@@ -2572,7 +2914,13 @@ class Volume:
2572
2914
  return (dx.value, dy.value, dz.value)
2573
2915
 
2574
2916
  @classmethod
2575
- def load_from_nvdb(cls, file_or_buffer, device=None):
2917
+ def load_from_nvdb(cls, file_or_buffer, device=None) -> Volume:
2918
+ """Creates a Volume object from a NanoVDB file or in-memory buffer.
2919
+
2920
+ Returns:
2921
+
2922
+ A ``warp.Volume`` object.
2923
+ """
2576
2924
  try:
2577
2925
  data = file_or_buffer.read()
2578
2926
  except AttributeError:
@@ -2601,6 +2949,90 @@ class Volume:
2601
2949
  data_array = array(np.frombuffer(grid_data, dtype=np.byte), device=device)
2602
2950
  return cls(data_array)
2603
2951
 
2952
+ @classmethod
2953
+ def load_from_numpy(
2954
+ cls, ndarray: np.array, min_world=(0.0, 0.0, 0.0), voxel_size=1.0, bg_value=0.0, device=None
2955
+ ) -> Volume:
2956
+ """Creates a Volume object from a dense 3D NumPy array.
2957
+
2958
+ This function is only supported for CUDA devices.
2959
+
2960
+ Args:
2961
+ min_world: The 3D coordinate of the lower corner of the volume.
2962
+ voxel_size: The size of each voxel in spatial coordinates.
2963
+ bg_value: Background value
2964
+ device: The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
2965
+
2966
+ Returns:
2967
+
2968
+ A ``warp.Volume`` object.
2969
+ """
2970
+
2971
+ import math
2972
+
2973
+ target_shape = (
2974
+ math.ceil(ndarray.shape[0] / 8) * 8,
2975
+ math.ceil(ndarray.shape[1] / 8) * 8,
2976
+ math.ceil(ndarray.shape[2] / 8) * 8,
2977
+ )
2978
+ if hasattr(bg_value, "__len__"):
2979
+ # vec3, assuming the numpy array is 4D
2980
+ padded_array = np.empty((target_shape[0], target_shape[1], target_shape[2], 3), dtype=np.single)
2981
+ padded_array[:, :, :, :] = np.array(bg_value)
2982
+ padded_array[0 : ndarray.shape[0], 0 : ndarray.shape[1], 0 : ndarray.shape[2], :] = ndarray
2983
+ else:
2984
+ padded_amount = (
2985
+ math.ceil(ndarray.shape[0] / 8) * 8 - ndarray.shape[0],
2986
+ math.ceil(ndarray.shape[1] / 8) * 8 - ndarray.shape[1],
2987
+ math.ceil(ndarray.shape[2] / 8) * 8 - ndarray.shape[2],
2988
+ )
2989
+ padded_array = np.pad(
2990
+ ndarray,
2991
+ ((0, padded_amount[0]), (0, padded_amount[1]), (0, padded_amount[2])),
2992
+ mode="constant",
2993
+ constant_values=bg_value,
2994
+ )
2995
+
2996
+ shape = padded_array.shape
2997
+ volume = warp.Volume.allocate(
2998
+ min_world,
2999
+ [
3000
+ min_world[0] + (shape[0] - 1) * voxel_size,
3001
+ min_world[1] + (shape[1] - 1) * voxel_size,
3002
+ min_world[2] + (shape[2] - 1) * voxel_size,
3003
+ ],
3004
+ voxel_size,
3005
+ bg_value=bg_value,
3006
+ points_in_world_space=True,
3007
+ translation=min_world,
3008
+ device=device,
3009
+ )
3010
+
3011
+ # Populate volume
3012
+ if hasattr(bg_value, "__len__"):
3013
+ warp.launch(
3014
+ warp.utils.copy_dense_volume_to_nano_vdb_v,
3015
+ dim=(shape[0], shape[1], shape[2]),
3016
+ inputs=[volume.id, warp.array(padded_array, dtype=warp.vec3, device=device)],
3017
+ device=device,
3018
+ )
3019
+ elif isinstance(bg_value, int):
3020
+ warp.launch(
3021
+ warp.utils.copy_dense_volume_to_nano_vdb_i,
3022
+ dim=shape,
3023
+ inputs=[volume.id, warp.array(padded_array, dtype=warp.int32, device=device)],
3024
+ device=device,
3025
+ )
3026
+ else:
3027
+ warp.launch(
3028
+ warp.utils.copy_dense_volume_to_nano_vdb_f,
3029
+ dim=shape,
3030
+ inputs=[volume.id, warp.array(padded_array, dtype=warp.float32, device=device)],
3031
+ device=device,
3032
+ )
3033
+
3034
+ return volume
3035
+
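A minimal sketch of the new load_from_numpy entry point, which pads the dense grid up to multiples of the 8x8x8 tile size before copying it into the sparse volume (grid size, voxel size, and values are illustrative; a CUDA device is required):

    import numpy as np
    import warp as wp

    wp.init()

    # dense 32^3 float field covering the volume's bounding box
    dense = np.linspace(0.0, 1.0, 32 ** 3, dtype=np.float32).reshape(32, 32, 32)

    volume = wp.Volume.load_from_numpy(
        dense, min_world=(0.0, 0.0, 0.0), voxel_size=0.1, bg_value=0.0, device="cuda"
    )
    print(volume.get_voxel_size())  # (0.1, 0.1, 0.1)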
2604
3036
  @classmethod
2605
3037
  def allocate(
2606
3038
  cls,
@@ -2611,9 +3043,11 @@ class Volume:
2611
3043
  translation=(0.0, 0.0, 0.0),
2612
3044
  points_in_world_space=False,
2613
3045
  device=None,
2614
- ):
3046
+ ) -> Volume:
2615
3047
  """Allocate a new Volume based on the bounding box defined by min and max.
2616
3048
 
3049
+ This function is only supported for CUDA devices.
3050
+
2617
3051
  Allocate a volume that is large enough to contain voxels [min[0], min[1], min[2]] - [max[0], max[1], max[2]], inclusive.
2618
3052
  If points_in_world_space is true, then min and max are first converted to index space with the given voxel size and
2619
3053
  translation, and the volume is allocated with those.
@@ -2622,12 +3056,12 @@ class Volume:
2622
3056
  the resulting tiles will be available in the new volume.
2623
3057
 
2624
3058
  Args:
2625
- min (array-like): Lower 3D-coordinates of the bounding box in index space or world space, inclusive
2626
- max (array-like): Upper 3D-coordinates of the bounding box in index space or world space, inclusive
2627
- voxel_size (float): Voxel size of the new volume
3059
+ min (array-like): Lower 3D coordinates of the bounding box in index space or world space, inclusive.
3060
+ max (array-like): Upper 3D coordinates of the bounding box in index space or world space, inclusive.
3061
+ voxel_size (float): Voxel size of the new volume.
2628
3062
  bg_value (float or array-like): Value of unallocated voxels of the volume, also defines the volume's type, a :class:`warp.vec3` volume is created if this is `array-like`, otherwise a float volume is created
2629
- translation (array-like): translation between the index and world spaces
2630
- device (Devicelike): Device the array lives on
3063
+ translation (array-like): translation between the index and world spaces.
3064
+ device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
2631
3065
 
2632
3066
  """
2633
3067
  if points_in_world_space:
@@ -2652,9 +3086,11 @@ class Volume:
2652
3086
  @classmethod
2653
3087
  def allocate_by_tiles(
2654
3088
  cls, tile_points: array, voxel_size: float, bg_value=0.0, translation=(0.0, 0.0, 0.0), device=None
2655
- ):
3089
+ ) -> Volume:
2656
3090
  """Allocate a new Volume with active tiles for each point tile_points.
2657
3091
 
3092
+ This function is only supported for CUDA devices.
3093
+
2658
3094
  The smallest unit of allocation is a dense tile of 8x8x8 voxels.
2659
3095
  This is the primary method for allocating sparse volumes. It uses an array of points indicating the tiles that must be allocated.
2660
3096
 
@@ -2664,13 +3100,13 @@ class Volume:
2664
3100
 
2665
3101
  Args:
2666
3102
  tile_points (:class:`warp.array`): Array of positions that define the tiles to be allocated.
2667
- The array can be a 2d, N-by-3 array of :class:`warp.int32` values, indicating index space positions,
3103
+ The array can be a 2D, N-by-3 array of :class:`warp.int32` values, indicating index space positions,
2668
3104
  or can be a 1D array of :class:`warp.vec3` values, indicating world space positions.
2669
3105
  Repeated points per tile are allowed and will be efficiently deduplicated.
2670
- voxel_size (float): Voxel size of the new volume
3106
+ voxel_size (float): Voxel size of the new volume.
2671
3107
  bg_value (float or array-like): Value of unallocated voxels of the volume, also defines the volume's type, a :class:`warp.vec3` volume is created if this is `array-like`, otherwise a float volume is created
2672
- translation (array-like): translation between the index and world spaces
2673
- device (Devicelike): Device the array lives on
3108
+ translation (array-like): Translation between the index and world spaces.
3109
+ device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
2674
3110
 
2675
3111
  """
2676
3112
  from warp.context import runtime
@@ -2707,7 +3143,7 @@ class Volume:
2707
3143
  translation[2],
2708
3144
  in_world_space,
2709
3145
  )
2710
- elif type(bg_value) == int:
3146
+ elif isinstance(bg_value, int):
2711
3147
  volume.id = volume.context.core.volume_i_from_tiles_device(
2712
3148
  volume.device.context,
2713
3149
  ctypes.c_void_p(tile_points.ptr),
@@ -2738,6 +3174,67 @@ class Volume:
2738
3174
  return volume
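A minimal sketch of wp.Volume.allocate_by_tiles with world-space tile points (coordinates and sizes are illustrative; a CUDA device is required):

    import warp as wp

    wp.init()

    # one 8x8x8 tile is allocated per point; duplicate points are deduplicated
    tile_points = wp.array(
        [wp.vec3(0.0, 0.0, 0.0), wp.vec3(0.8, 0.0, 0.0)], dtype=wp.vec3, device="cuda"
    )
    volume = wp.Volume.allocate_by_tiles(
        tile_points, voxel_size=0.1, bg_value=0.0, translation=(0.0, 0.0, 0.0), device="cuda"
    )
    print(len(volume.get_tiles()))  # number of allocated tiles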
2739
3175
 
2740
3176
 
3177
+ # definition just for kernel type (cannot be a parameter), see mesh.h
3178
+ # NOTE: its layout must match the corresponding struct defined in C.
3179
+ # NOTE: it needs to be defined after `indexedarray` to work around a circular import issue.
3180
+ class mesh_query_point_t:
3181
+ """Output for the mesh query point functions.
3182
+
3183
+ Attributes:
3184
+ result (bool): Whether a point is found within the given constraints.
3185
+ sign (float32): A value < 0 if query point is inside the mesh, >=0 otherwise.
3186
+ Note that mesh must be watertight for this to be robust
3187
+ face (int32): Index of the closest face.
3188
+ u (float32): Barycentric u coordinate of the closest point.
3189
+ v (float32): Barycentric v coordinate of the closest point.
3190
+
3191
+ See Also:
3192
+ :func:`mesh_query_point`, :func:`mesh_query_point_no_sign`,
3193
+ :func:`mesh_query_furthest_point_no_sign`,
3194
+ :func:`mesh_query_point_sign_normal`,
3195
+ and :func:`mesh_query_point_sign_winding_number`.
3196
+ """
3197
+ from warp.codegen import Var
3198
+
3199
+ vars = {
3200
+ "result": Var("result", bool),
3201
+ "sign": Var("sign", float32),
3202
+ "face": Var("face", int32),
3203
+ "u": Var("u", float32),
3204
+ "v": Var("v", float32),
3205
+ }
3206
+
3207
+
3208
+ # definition just for kernel type (cannot be a parameter), see mesh.h
3209
+ # NOTE: its layout must match the corresponding struct defined in C.
3210
+ class mesh_query_ray_t:
3211
+ """Output for the mesh query ray functions.
3212
+
3213
+ Attributes:
3214
+ result (bool): Whether a hit is found within the given constraints.
3215
+ sign (float32): A value > 0 if the ray hit in front of the face, returns < 0 otherwise.
3216
+ face (int32): Index of the closest face.
3217
+ t (float32): Distance of the closest hit along the ray.
3218
+ u (float32): Barycentric u coordinate of the closest hit.
3219
+ v (float32): Barycentric v coordinate of the closest hit.
3220
+ normal (vec3f): Face normal.
3221
+
3222
+ See Also:
3223
+ :func:`mesh_query_ray`.
3224
+ """
3225
+ from warp.codegen import Var
3226
+
3227
+ vars = {
3228
+ "result": Var("result", bool),
3229
+ "sign": Var("sign", float32),
3230
+ "face": Var("face", int32),
3231
+ "t": Var("t", float32),
3232
+ "u": Var("u", float32),
3233
+ "v": Var("v", float32),
3234
+ "normal": Var("normal", vec3),
3235
+ }
3236
+
3237
+
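These structs are what the mesh query built-ins hand back inside kernels. A minimal sketch, assuming an existing wp.Mesh and the struct-returning form of wp.mesh_query_point referenced by these docstrings (the max_dist value is illustrative):

    import warp as wp

    @wp.kernel
    def closest_points(mesh_id: wp.uint64,
                       queries: wp.array(dtype=wp.vec3),
                       closest: wp.array(dtype=wp.vec3)):
        tid = wp.tid()
        q = wp.mesh_query_point(mesh_id, queries[tid], 1.0e6)  # returns a mesh_query_point_t
        if q.result:
            # q.face, q.u, q.v identify the closest point on the mesh surface
            closest[tid] = wp.mesh_eval_position(mesh_id, q.face, q.u, q.v)

    # wp.launch(closest_points, dim=len(queries), inputs=[mesh.id, queries, closest], device="cuda")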
2741
3238
  def matmul(
2742
3239
  a: array2d,
2743
3240
  b: array2d,
@@ -2745,7 +3242,7 @@ def matmul(
2745
3242
  d: array2d,
2746
3243
  alpha: float = 1.0,
2747
3244
  beta: float = 0.0,
2748
- allow_tf32x3_arith: bool = False,
3245
+ allow_tf32x3_arith: builtins.bool = False,
2749
3246
  device=None,
2750
3247
  ):
2751
3248
  """Computes a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -2774,6 +3271,11 @@ def matmul(
2774
3271
  "wp.matmul currently only supports operation between {A, B, C, D} matrices of the same type."
2775
3272
  )
2776
3273
 
3274
+ if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
3275
+ raise RuntimeError(
3276
+ "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
3277
+ )
3278
+
2777
3279
  m = a.shape[0]
2778
3280
  n = b.shape[1]
2779
3281
  k = a.shape[1]
@@ -2808,13 +3310,13 @@ def matmul(
2808
3310
  ctypes.c_void_p(d.ptr),
2809
3311
  alpha,
2810
3312
  beta,
2811
- True,
2812
- True,
3313
+ not a.is_transposed,
3314
+ not b.is_transposed,
2813
3315
  allow_tf32x3_arith,
2814
3316
  1,
2815
3317
  )
2816
3318
  if not ret:
2817
- raise RuntimeError("Matmul failed.")
3319
+ raise RuntimeError("matmul failed.")
2818
3320
 
2819
3321
 
2820
3322
  def adj_matmul(
@@ -2827,7 +3329,7 @@ def adj_matmul(
2827
3329
  adj_d: array2d,
2828
3330
  alpha: float = 1.0,
2829
3331
  beta: float = 0.0,
2830
- allow_tf32x3_arith: bool = False,
3332
+ allow_tf32x3_arith: builtins.bool = False,
2831
3333
  device=None,
2832
3334
  ):
2833
3335
  """Computes the adjoint of a generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -2878,6 +3380,19 @@ def adj_matmul(
2878
3380
  "wp.adj_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
2879
3381
  )
2880
3382
 
3383
+ if (
3384
+ (not a.is_contiguous and not a.is_transposed)
3385
+ or (not b.is_contiguous and not b.is_transposed)
3386
+ or (not c.is_contiguous)
3387
+ or (not adj_a.is_contiguous and not adj_a.is_transposed)
3388
+ or (not adj_b.is_contiguous and not adj_b.is_transposed)
3389
+ or (not adj_c.is_contiguous)
3390
+ or (not adj_d.is_contiguous)
3391
+ ):
3392
+ raise RuntimeError(
3393
+ "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
3394
+ )
3395
+
2881
3396
  m = a.shape[0]
2882
3397
  n = b.shape[1]
2883
3398
  k = a.shape[1]
@@ -2898,75 +3413,105 @@ def adj_matmul(
2898
3413
 
2899
3414
  # cpu fallback if no cuda devices found
2900
3415
  if device == "cpu":
2901
- adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()))
2902
- adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()))
2903
- adj_c.assign(beta * adj_d.numpy())
3416
+ adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()) + adj_a.numpy())
3417
+ adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()) + adj_b.numpy())
3418
+ adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
2904
3419
  return
2905
3420
 
2906
3421
  cc = device.arch
2907
3422
 
2908
3423
  # adj_a
2909
- ret = runtime.core.cutlass_gemm(
2910
- cc,
2911
- m,
2912
- k,
2913
- n,
2914
- type_typestr(a.dtype).encode(),
2915
- ctypes.c_void_p(adj_d.ptr),
2916
- ctypes.c_void_p(b.ptr),
2917
- ctypes.c_void_p(a.ptr),
2918
- ctypes.c_void_p(adj_a.ptr),
2919
- alpha,
2920
- 0.0,
2921
- True,
2922
- False,
2923
- allow_tf32x3_arith,
2924
- 1,
2925
- )
2926
- if not ret:
2927
- raise RuntimeError("adj_matmul failed.")
3424
+ if not a.is_transposed:
3425
+ ret = runtime.core.cutlass_gemm(
3426
+ cc,
3427
+ m,
3428
+ k,
3429
+ n,
3430
+ type_typestr(a.dtype).encode(),
3431
+ ctypes.c_void_p(adj_d.ptr),
3432
+ ctypes.c_void_p(b.ptr),
3433
+ ctypes.c_void_p(adj_a.ptr),
3434
+ ctypes.c_void_p(adj_a.ptr),
3435
+ alpha,
3436
+ 1.0,
3437
+ True,
3438
+ b.is_transposed,
3439
+ allow_tf32x3_arith,
3440
+ 1,
3441
+ )
3442
+ if not ret:
3443
+ raise RuntimeError("adj_matmul failed.")
3444
+ else:
3445
+ ret = runtime.core.cutlass_gemm(
3446
+ cc,
3447
+ k,
3448
+ m,
3449
+ n,
3450
+ type_typestr(a.dtype).encode(),
3451
+ ctypes.c_void_p(b.ptr),
3452
+ ctypes.c_void_p(adj_d.ptr),
3453
+ ctypes.c_void_p(adj_a.ptr),
3454
+ ctypes.c_void_p(adj_a.ptr),
3455
+ alpha,
3456
+ 1.0,
3457
+ not b.is_transposed,
3458
+ False,
3459
+ allow_tf32x3_arith,
3460
+ 1,
3461
+ )
3462
+ if not ret:
3463
+ raise RuntimeError("adj_matmul failed.")
2928
3464
 
2929
3465
  # adj_b
2930
- ret = runtime.core.cutlass_gemm(
2931
- cc,
2932
- k,
2933
- n,
2934
- m,
2935
- type_typestr(a.dtype).encode(),
2936
- ctypes.c_void_p(a.ptr),
2937
- ctypes.c_void_p(adj_d.ptr),
2938
- ctypes.c_void_p(b.ptr),
2939
- ctypes.c_void_p(adj_b.ptr),
2940
- alpha,
2941
- 0.0,
2942
- False,
2943
- True,
2944
- allow_tf32x3_arith,
2945
- 1,
2946
- )
2947
- if not ret:
2948
- raise RuntimeError("adj_matmul failed.")
3466
+ if not b.is_transposed:
3467
+ ret = runtime.core.cutlass_gemm(
3468
+ cc,
3469
+ k,
3470
+ n,
3471
+ m,
3472
+ type_typestr(a.dtype).encode(),
3473
+ ctypes.c_void_p(a.ptr),
3474
+ ctypes.c_void_p(adj_d.ptr),
3475
+ ctypes.c_void_p(adj_b.ptr),
3476
+ ctypes.c_void_p(adj_b.ptr),
3477
+ alpha,
3478
+ 1.0,
3479
+ a.is_transposed,
3480
+ True,
3481
+ allow_tf32x3_arith,
3482
+ 1,
3483
+ )
3484
+ if not ret:
3485
+ raise RuntimeError("adj_matmul failed.")
3486
+ else:
3487
+ ret = runtime.core.cutlass_gemm(
3488
+ cc,
3489
+ n,
3490
+ k,
3491
+ m,
3492
+ type_typestr(a.dtype).encode(),
3493
+ ctypes.c_void_p(adj_d.ptr),
3494
+ ctypes.c_void_p(a.ptr),
3495
+ ctypes.c_void_p(adj_b.ptr),
3496
+ ctypes.c_void_p(adj_b.ptr),
3497
+ alpha,
3498
+ 1.0,
3499
+ False,
3500
+ not a.is_transposed,
3501
+ allow_tf32x3_arith,
3502
+ 1,
3503
+ )
3504
+ if not ret:
3505
+ raise RuntimeError("adj_matmul failed.")
2949
3506
 
2950
3507
  # adj_c
2951
- ret = runtime.core.cutlass_gemm(
2952
- cc,
2953
- m,
2954
- n,
2955
- k,
2956
- type_typestr(a.dtype).encode(),
2957
- ctypes.c_void_p(a.ptr),
2958
- ctypes.c_void_p(b.ptr),
2959
- ctypes.c_void_p(adj_d.ptr),
2960
- ctypes.c_void_p(adj_c.ptr),
2961
- 0.0,
2962
- beta,
2963
- True,
2964
- True,
2965
- allow_tf32x3_arith,
2966
- 1,
3508
+ warp.launch(
3509
+ kernel=warp.utils.add_kernel_2d,
3510
+ dim=adj_c.shape,
3511
+ inputs=[adj_c, adj_d, adj_d.dtype(beta)],
3512
+ device=device,
3513
+ record_tape=False
2967
3514
  )
2968
- if not ret:
2969
- raise RuntimeError("adj_matmul failed.")
2970
3515
 
2971
3516
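The CPU fallback now accumulates into the existing adjoints instead of overwriting them, matching the CUTLASS path, which runs with beta = 1.0 on the adjoint outputs. A minimal sketch of how these adjoints are typically produced through the tape (shapes and the seed gradient are illustrative):

    import numpy as np
    import warp as wp

    wp.init()

    m, k, n = 8, 4, 2
    a = wp.array(np.random.rand(m, k), dtype=wp.float32, requires_grad=True)
    b = wp.array(np.random.rand(k, n), dtype=wp.float32, requires_grad=True)
    c = wp.zeros((m, n), dtype=wp.float32, requires_grad=True)
    d = wp.zeros((m, n), dtype=wp.float32, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.matmul(a, b, c, d, alpha=1.0, beta=1.0)  # recorded via runtime.tape.record_func

    # seed the output adjoint; adj_matmul accumulates into a.grad, b.grad, c.grad
    tape.backward(grads={d: wp.array(np.ones((m, n)), dtype=wp.float32)})
    print(a.grad.numpy().shape)  # (8, 4)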
 
2972
3517
  def batched_matmul(
@@ -2976,7 +3521,7 @@ def batched_matmul(
2976
3521
  d: array3d,
2977
3522
  alpha: float = 1.0,
2978
3523
  beta: float = 0.0,
2979
- allow_tf32x3_arith: bool = False,
3524
+ allow_tf32x3_arith: builtins.bool = False,
2980
3525
  device=None,
2981
3526
  ):
2982
3527
  """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -3005,6 +3550,11 @@ def batched_matmul(
3005
3550
  "wp.batched_matmul currently only supports operation between {A, B, C, D} matrices of the same type."
3006
3551
  )
3007
3552
 
3553
+ if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
3554
+ raise RuntimeError(
3555
+ "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
3556
+ )
3557
+
3008
3558
  m = a.shape[1]
3009
3559
  n = b.shape[2]
3010
3560
  k = a.shape[2]
@@ -3016,7 +3566,7 @@ def batched_matmul(
3016
3566
 
3017
3567
  if runtime.tape:
3018
3568
  runtime.tape.record_func(
3019
- backward=lambda: adj_matmul(
3569
+ backward=lambda: adj_batched_matmul(
3020
3570
  a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith, device
3021
3571
  ),
3022
3572
  arrays=[a, b, c, d],
@@ -3027,26 +3577,55 @@ def batched_matmul(
3027
3577
  d.assign(alpha * np.matmul(a.numpy(), b.numpy()) + beta * c.numpy())
3028
3578
  return
3029
3579
 
3580
+ # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
3581
+ max_batch_count = 65535
3582
+ iters = int(batch_count / max_batch_count)
3583
+ remainder = batch_count % max_batch_count
3584
+
3030
3585
  cc = device.arch
3586
+ for i in range(iters):
3587
+ idx_start = i * max_batch_count
3588
+ idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
3589
+ ret = runtime.core.cutlass_gemm(
3590
+ cc,
3591
+ m,
3592
+ n,
3593
+ k,
3594
+ type_typestr(a.dtype).encode(),
3595
+ ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
3596
+ ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
3597
+ ctypes.c_void_p(c[idx_start:idx_end,:,:].ptr),
3598
+ ctypes.c_void_p(d[idx_start:idx_end,:,:].ptr),
3599
+ alpha,
3600
+ beta,
3601
+ not a.is_transposed,
3602
+ not b.is_transposed,
3603
+ allow_tf32x3_arith,
3604
+ max_batch_count,
3605
+ )
3606
+ if not ret:
3607
+ raise RuntimeError("Batched matmul failed.")
3608
+
3609
+ idx_start = iters * max_batch_count
3031
3610
  ret = runtime.core.cutlass_gemm(
3032
3611
  cc,
3033
3612
  m,
3034
3613
  n,
3035
3614
  k,
3036
3615
  type_typestr(a.dtype).encode(),
3037
- ctypes.c_void_p(a.ptr),
3038
- ctypes.c_void_p(b.ptr),
3039
- ctypes.c_void_p(c.ptr),
3040
- ctypes.c_void_p(d.ptr),
3616
+ ctypes.c_void_p(a[idx_start:,:,:].ptr),
3617
+ ctypes.c_void_p(b[idx_start:,:,:].ptr),
3618
+ ctypes.c_void_p(c[idx_start:,:,:].ptr),
3619
+ ctypes.c_void_p(d[idx_start:,:,:].ptr),
3041
3620
  alpha,
3042
3621
  beta,
3043
- True,
3044
- True,
3622
+ not a.is_transposed,
3623
+ not b.is_transposed,
3045
3624
  allow_tf32x3_arith,
3046
- batch_count,
3625
+ remainder,
3047
3626
  )
3048
3627
  if not ret:
3049
- raise RuntimeError("Batched matmul failed.")
3628
+ raise RuntimeError("Batched matmul failed.")
3050
3629
 
3051
3630
 
3052
3631
  def adj_batched_matmul(
@@ -3059,7 +3638,7 @@ def adj_batched_matmul(
3059
3638
  adj_d: array3d,
3060
3639
  alpha: float = 1.0,
3061
3640
  beta: float = 0.0,
3062
- allow_tf32x3_arith: bool = False,
3641
+ allow_tf32x3_arith: builtins.bool = False,
3063
3642
  device=None,
3064
3643
  ):
3065
3644
  """Computes a batched generic matrix-matrix multiplication (GEMM) of the form: `d = alpha * (a @ b) + beta * c`.
@@ -3126,78 +3705,215 @@ def adj_batched_matmul(
3126
3705
  )
3127
3706
  )
3128
3707
 
3708
+ if (
3709
+ (not a.is_contiguous and not a.is_transposed)
3710
+ or (not b.is_contiguous and not b.is_transposed)
3711
+ or (not c.is_contiguous)
3712
+ or (not adj_a.is_contiguous and not adj_a.is_transposed)
3713
+ or (not adj_b.is_contiguous and not adj_b.is_transposed)
3714
+ or (not adj_c.is_contiguous)
3715
+ or (not adj_d.is_contiguous)
3716
+ ):
3717
+ raise RuntimeError(
3718
+ "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
3719
+ )
3720
+
3129
3721
  # cpu fallback if no cuda devices found
3130
3722
  if device == "cpu":
3131
- adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))))
3132
- adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()))
3133
- adj_c.assign(beta * adj_d.numpy())
3723
+ adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))) + adj_a.numpy())
3724
+ adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()) + adj_b.numpy())
3725
+ adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
3134
3726
  return
3135
3727
 
3728
+ # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
3729
+ max_batch_count = 65535
3730
+ iters = int(batch_count / max_batch_count)
3731
+ remainder = batch_count % max_batch_count
3732
+
3136
3733
  cc = device.arch
3137
3734
 
3735
+ for i in range(iters):
3736
+ idx_start = i * max_batch_count
3737
+ idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
3738
+
3739
+ # adj_a
3740
+ if not a.is_transposed:
3741
+ ret = runtime.core.cutlass_gemm(
3742
+ cc,
3743
+ m,
3744
+ k,
3745
+ n,
3746
+ type_typestr(a.dtype).encode(),
3747
+ ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
3748
+ ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
3749
+ ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
3750
+ ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
3751
+ alpha,
3752
+ 1.0,
3753
+ True,
3754
+ b.is_transposed,
3755
+ allow_tf32x3_arith,
3756
+ max_batch_count,
3757
+ )
3758
+ if not ret:
3759
+ raise RuntimeError("adj_matmul failed.")
3760
+ else:
3761
+ ret = runtime.core.cutlass_gemm(
3762
+ cc,
3763
+ k,
3764
+ m,
3765
+ n,
3766
+ type_typestr(a.dtype).encode(),
3767
+ ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
3768
+ ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
3769
+ ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
3770
+ ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
3771
+ alpha,
3772
+ 1.0,
3773
+ not b.is_transposed,
3774
+ False,
3775
+ allow_tf32x3_arith,
3776
+ max_batch_count,
3777
+ )
3778
+ if not ret:
3779
+ raise RuntimeError("adj_matmul failed.")
3780
+
3781
+ # adj_b
3782
+ if not b.is_transposed:
3783
+ ret = runtime.core.cutlass_gemm(
3784
+ cc,
3785
+ k,
3786
+ n,
3787
+ m,
3788
+ type_typestr(a.dtype).encode(),
3789
+ ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
3790
+ ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
3791
+ ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
3792
+ ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
3793
+ alpha,
3794
+ 1.0,
3795
+ a.is_transposed,
3796
+ True,
3797
+ allow_tf32x3_arith,
3798
+ max_batch_count,
3799
+ )
3800
+ if not ret:
3801
+ raise RuntimeError("adj_matmul failed.")
3802
+ else:
3803
+ ret = runtime.core.cutlass_gemm(
3804
+ cc,
3805
+ n,
3806
+ k,
3807
+ m,
3808
+ type_typestr(a.dtype).encode(),
3809
+ ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
3810
+ ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
3811
+ ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
3812
+ ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
3813
+ alpha,
3814
+ 1.0,
3815
+ False,
3816
+ not a.is_transposed,
3817
+ allow_tf32x3_arith,
3818
+ max_batch_count,
3819
+ )
3820
+ if not ret:
3821
+ raise RuntimeError("adj_matmul failed.")
3822
+
3823
+ idx_start = iters * max_batch_count
3824
+
3138
3825
  # adj_a
3139
- ret = runtime.core.cutlass_gemm(
3140
- cc,
3141
- m,
3142
- k,
3143
- n,
3144
- type_typestr(a.dtype).encode(),
3145
- ctypes.c_void_p(adj_d.ptr),
3146
- ctypes.c_void_p(b.ptr),
3147
- ctypes.c_void_p(a.ptr),
3148
- ctypes.c_void_p(adj_a.ptr),
3149
- alpha,
3150
- 0.0,
3151
- True,
3152
- False,
3153
- allow_tf32x3_arith,
3154
- batch_count,
3155
- )
3156
- if not ret:
3157
- raise RuntimeError("adj_matmul failed.")
3826
+ if not a.is_transposed:
3827
+ ret = runtime.core.cutlass_gemm(
3828
+ cc,
3829
+ m,
3830
+ k,
3831
+ n,
3832
+ type_typestr(a.dtype).encode(),
3833
+ ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
3834
+ ctypes.c_void_p(b[idx_start:,:,:].ptr),
3835
+ ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
3836
+ ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
3837
+ alpha,
3838
+ 1.0,
3839
+ True,
3840
+ b.is_transposed,
3841
+ allow_tf32x3_arith,
3842
+ remainder,
3843
+ )
3844
+ if not ret:
3845
+ raise RuntimeError("adj_matmul failed.")
3846
+ else:
3847
+ ret = runtime.core.cutlass_gemm(
3848
+ cc,
3849
+ k,
3850
+ m,
3851
+ n,
3852
+ type_typestr(a.dtype).encode(),
3853
+ ctypes.c_void_p(b[idx_start:,:,:].ptr),
3854
+ ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
3855
+ ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
3856
+ ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
3857
+ alpha,
3858
+ 1.0,
3859
+ not b.is_transposed,
3860
+ False,
3861
+ allow_tf32x3_arith,
3862
+ remainder,
3863
+ )
3864
+ if not ret:
3865
+ raise RuntimeError("adj_matmul failed.")
3158
3866
 
3159
3867
  # adj_b
3160
- ret = runtime.core.cutlass_gemm(
3161
- cc,
3162
- k,
3163
- n,
3164
- m,
3165
- type_typestr(a.dtype).encode(),
3166
- ctypes.c_void_p(a.ptr),
3167
- ctypes.c_void_p(adj_d.ptr),
3168
- ctypes.c_void_p(b.ptr),
3169
- ctypes.c_void_p(adj_b.ptr),
3170
- alpha,
3171
- 0.0,
3172
- False,
3173
- True,
3174
- allow_tf32x3_arith,
3175
- batch_count,
3176
- )
3177
- if not ret:
3178
- raise RuntimeError("adj_matmul failed.")
3868
+ if not b.is_transposed:
3869
+ ret = runtime.core.cutlass_gemm(
3870
+ cc,
3871
+ k,
3872
+ n,
3873
+ m,
3874
+ type_typestr(a.dtype).encode(),
3875
+ ctypes.c_void_p(a[idx_start:,:,:].ptr),
3876
+ ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
3877
+ ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
3878
+ ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
3879
+ alpha,
3880
+ 1.0,
3881
+ a.is_transposed,
3882
+ True,
3883
+ allow_tf32x3_arith,
3884
+ remainder,
3885
+ )
3886
+ if not ret:
3887
+ raise RuntimeError("adj_matmul failed.")
3888
+ else:
3889
+ ret = runtime.core.cutlass_gemm(
3890
+ cc,
3891
+ n,
3892
+ k,
3893
+ m,
3894
+ type_typestr(a.dtype).encode(),
3895
+ ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
3896
+ ctypes.c_void_p(a[idx_start:,:,:].ptr),
3897
+ ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
3898
+ ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
3899
+ alpha,
3900
+ 1.0,
3901
+ False,
3902
+ not a.is_transposed,
3903
+ allow_tf32x3_arith,
3904
+ remainder,
3905
+ )
3906
+ if not ret:
3907
+ raise RuntimeError("adj_matmul failed.")
3179
3908
 
3180
3909
  # adj_c
3181
- ret = runtime.core.cutlass_gemm(
3182
- cc,
3183
- m,
3184
- n,
3185
- k,
3186
- type_typestr(a.dtype).encode(),
3187
- ctypes.c_void_p(a.ptr),
3188
- ctypes.c_void_p(b.ptr),
3189
- ctypes.c_void_p(adj_d.ptr),
3190
- ctypes.c_void_p(adj_c.ptr),
3191
- 0.0,
3192
- beta,
3193
- True,
3194
- True,
3195
- allow_tf32x3_arith,
3196
- batch_count,
3910
+ warp.launch(
3911
+ kernel=warp.utils.add_kernel_3d,
3912
+ dim=adj_c.shape,
3913
+ inputs=[adj_c, adj_d, adj_d.dtype(beta)],
3914
+ device=device,
3915
+ record_tape=False
3197
3916
  )
3198
- if not ret:
3199
- raise RuntimeError("adj_matmul failed.")
3200
-
3201
3917
 
3202
3918
  class HashGrid:
3203
3919
  def __init__(self, dim_x, dim_y, dim_z, device=None):
@@ -3266,7 +3982,7 @@ class HashGrid:
3266
3982
  with self.device.context_guard:
3267
3983
  runtime.core.hash_grid_destroy_device(self.id)
3268
3984
 
3269
- except:
3985
+ except Exception:
3270
3986
  pass
3271
3987
 
3272
3988
 
@@ -3340,7 +4056,7 @@ class MarchingCubes:
3340
4056
 
3341
4057
  if error:
3342
4058
  raise RuntimeError(
3343
- "Error occured buffers may not be large enough, marching cubes required at least {num_verts} vertices, and {num_tris} triangles."
4059
+ "Buffers may not be large enough, marching cubes required at least {num_verts} vertices, and {num_tris} triangles."
3344
4060
  )
3345
4061
 
3346
4062
  # resize the geometry arrays
@@ -3396,7 +4112,7 @@ def type_matches_template(arg_type, template_type):
3396
4112
  return True
3397
4113
  elif is_array(template_type):
3398
4114
  # ensure the argument type is a non-generic array with matching dtype and dimensionality
3399
- if type(arg_type) != type(template_type):
4115
+ if type(arg_type) is not type(template_type):
3400
4116
  return False
3401
4117
  if not type_matches_template(arg_type.dtype, template_type.dtype):
3402
4118
  return False
@@ -3429,7 +4145,7 @@ def infer_argument_types(args, template_types, arg_names=None):
3429
4145
  """Resolve argument types with the given list of template types."""
3430
4146
 
3431
4147
  if len(args) != len(template_types):
3432
- raise RuntimeError(f"Number of arguments must match number of template types.")
4148
+ raise RuntimeError("Number of arguments must match number of template types.")
3433
4149
 
3434
4150
  arg_types = []
3435
4151
 
@@ -3452,7 +4168,7 @@ def infer_argument_types(args, template_types, arg_names=None):
3452
4168
  arg_types.append(arg._cls)
3453
4169
  # elif arg_type in [warp.types.launch_bounds_t, warp.types.shape_t, warp.types.range_t]:
3454
4170
  # arg_types.append(arg_type)
3455
- # elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.bvh_query_t]:
4171
+ # elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.mesh_query_point_t, warp.mesh_query_ray_t, warp.bvh_query_t]:
3456
4172
  # arg_types.append(arg_type)
3457
4173
  elif arg is None:
3458
4174
  # allow passing None for arrays
@@ -3471,6 +4187,7 @@ def infer_argument_types(args, template_types, arg_names=None):
3471
4187
  simple_type_codes = {
3472
4188
  int: "i4",
3473
4189
  float: "f4",
4190
+ builtins.bool: "b",
3474
4191
  bool: "b",
3475
4192
  str: "str", # accepted by print()
3476
4193
  int8: "i1",
@@ -3489,6 +4206,8 @@ simple_type_codes = {
3489
4206
  launch_bounds_t: "lb",
3490
4207
  hash_grid_query_t: "hgq",
3491
4208
  mesh_query_aabb_t: "mqa",
4209
+ mesh_query_point_t: "mqp",
4210
+ mesh_query_ray_t: "mqr",
3492
4211
  bvh_query_t: "bvhq",
3493
4212
  }
3494
4213
 
@@ -3505,14 +4224,14 @@ def get_type_code(arg_type):
3505
4224
  # check for "special" vector/matrix subtypes
3506
4225
  if hasattr(arg_type, "_wp_generic_type_str_"):
3507
4226
  type_str = arg_type._wp_generic_type_str_
3508
- if type_str == "quaternion":
4227
+ if type_str == "quat_t":
3509
4228
  return f"q{dtype_code}"
3510
4229
  elif type_str == "transform_t":
3511
4230
  return f"t{dtype_code}"
3512
- elif type_str == "spatial_vector_t":
3513
- return f"sv{dtype_code}"
3514
- elif type_str == "spatial_matrix_t":
3515
- return f"sm{dtype_code}"
4231
+ # elif type_str == "spatial_vector_t":
4232
+ # return f"sv{dtype_code}"
4233
+ # elif type_str == "spatial_matrix_t":
4234
+ # return f"sm{dtype_code}"
3516
4235
  # generic vector/matrix
3517
4236
  ndim = len(arg_type._shape_)
3518
4237
  if ndim == 1:
@@ -3535,6 +4254,10 @@ def get_type_code(arg_type):
3535
4254
  return f"a{arg_type.ndim}{get_type_code(arg_type.dtype)}"
3536
4255
  elif isinstance(arg_type, indexedarray):
3537
4256
  return f"ia{arg_type.ndim}{get_type_code(arg_type.dtype)}"
4257
+ elif isinstance(arg_type, fabricarray):
4258
+ return f"fa{arg_type.ndim}{get_type_code(arg_type.dtype)}"
4259
+ elif isinstance(arg_type, indexedfabricarray):
4260
+ return f"ifa{arg_type.ndim}{get_type_code(arg_type.dtype)}"
3538
4261
  elif isinstance(arg_type, warp.codegen.Struct):
3539
4262
  return warp.codegen.make_full_qualified_name(arg_type.cls)
3540
4263
  elif arg_type == Scalar: