warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (271)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/types.py CHANGED
@@ -5,9 +5,12 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+from __future__ import annotations
+
 import builtins
 import ctypes
 import hashlib
+import inspect
 import struct
 import zlib
 from typing import Any, Callable, Generic, List, Tuple, TypeVar, Union
@@ -49,12 +52,14 @@ def constant(x):
     global _constant_hash
 
     # hash the constant value
-    if isinstance(x, int):
+    if isinstance(x, builtins.bool):
+        # This needs to come before the check for `int` since all boolean
+        # values are also instances of `int`.
+        _constant_hash.update(struct.pack("?", x))
+    elif isinstance(x, int):
         _constant_hash.update(struct.pack("<q", x))
     elif isinstance(x, float):
         _constant_hash.update(struct.pack("<d", x))
-    elif isinstance(x, builtins.bool):
-        _constant_hash.update(struct.pack("?", x))
     elif isinstance(x, float16):
         # float16 is a special case
        p = ctypes.pointer(ctypes.c_float(x.value))
@@ -155,17 +160,31 @@ def vector(length, dtype):
             else:
                 raise KeyError(f"Invalid key {key}, expected int or slice")
 
+        def __getattr__(self, name):
+            idx = "xyzw".find(name)
+            if idx != -1:
+                return self.__getitem__(idx)
+
+            return self.__getattribute__(name)
+
+        def __setattr__(self, name, value):
+            idx = "xyzw".find(name)
+            if idx != -1:
+                return self.__setitem__(idx, value)
+
+            return super().__setattr__(name, value)
+
         def __add__(self, y):
             return warp.add(self, y)
 
         def __radd__(self, y):
-            return warp.add(self, y)
+            return warp.add(y, self)
 
         def __sub__(self, y):
             return warp.sub(self, y)
 
-        def __rsub__(self, x):
-            return warp.sub(x, self)
+        def __rsub__(self, y):
+            return warp.sub(y, self)
 
         def __mul__(self, y):
             return warp.mul(self, y)
@@ -173,17 +192,17 @@ def vector(length, dtype):
         def __rmul__(self, x):
             return warp.mul(x, self)
 
-        def __div__(self, y):
+        def __truediv__(self, y):
             return warp.div(self, y)
 
-        def __rdiv__(self, x):
+        def __rtruediv__(self, x):
             return warp.div(x, self)
 
-        def __pos__(self, y):
-            return warp.pos(self, y)
+        def __pos__(self):
+            return warp.pos(self)
 
-        def __neg__(self, y):
-            return warp.neg(self, y)
+        def __neg__(self):
+            return warp.neg(self)
 
         def __str__(self):
             return f"[{', '.join(map(str, self))}]"
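Note: the hunk above adds attribute-style component access and fixes several operator dunders (the reflected `__radd__`/`__rsub__` argument order, the Python 3 `__truediv__`/`__rtruediv__` names, and the arity of the unary `__pos__`/`__neg__`). A minimal sketch of the resulting behavior, assuming the usual `wp.vec3` constructor:

    import warp as wp

    wp.init()

    v = wp.vec3(1.0, 2.0, 3.0)
    print(v.x, v.y, v.z)  # attribute access now routed through the "xyzw" lookup
    v.y = 5.0             # __setattr__ forwards component writes to __setitem__
    half = v / 2.0        # __truediv__ replaces the Python 2-era __div__
    flipped = -v          # __neg__ is now a proper unary operator
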
@@ -275,13 +294,13 @@ def matrix(shape, dtype):
             return warp.add(self, y)
 
         def __radd__(self, y):
-            return warp.add(self, y)
+            return warp.add(y, self)
 
         def __sub__(self, y):
             return warp.sub(self, y)
 
-        def __rsub__(self, x):
-            return warp.sub(x, self)
+        def __rsub__(self, y):
+            return warp.sub(y, self)
 
         def __mul__(self, y):
             return warp.mul(self, y)
@@ -295,17 +314,17 @@ def matrix(shape, dtype):
         def __rmatmul__(self, x):
             return warp.mul(x, self)
 
-        def __div__(self, y):
+        def __truediv__(self, y):
             return warp.div(self, y)
 
-        def __rdiv__(self, x):
+        def __rtruediv__(self, x):
             return warp.div(x, self)
 
-        def __pos__(self, y):
-            return warp.pos(self, y)
+        def __pos__(self):
+            return warp.pos(self)
 
-        def __neg__(self, y):
-            return warp.neg(self, y)
+        def __neg__(self):
+            return warp.neg(self)
 
         def __str__(self):
             row_str = []
@@ -511,23 +530,63 @@ class quatd(quaternion(dtype=float64)):
 
 def transformation(dtype=Any):
     class transform_t(vector(length=7, dtype=dtype)):
+        _wp_init_from_components_sig_ = inspect.Signature(
+            (
+                inspect.Parameter(
+                    "p",
+                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                    default=(0.0, 0.0, 0.0),
+                ),
+                inspect.Parameter(
+                    "q",
+                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                    default=(0.0, 0.0, 0.0, 1.0),
+                ),
+            ),
+        )
         _wp_type_params_ = [dtype]
         _wp_generic_type_str_ = "transform_t"
         _wp_constructor_ = "transformation"
 
-        def __init__(self, p=(0.0, 0.0, 0.0), q=(0.0, 0.0, 0.0, 1.0)):
-            super().__init__()
+        def __init__(self, *args, **kwargs):
+            if len(args) == 1 and len(kwargs) == 0:
+                if getattr(args[0], "_wp_generic_type_str_") == self._wp_generic_type_str_:
+                    # Copy constructor.
+                    super().__init__(*args[0])
+                    return
 
-            self[0:3] = vector(length=3, dtype=dtype)(*p)
-            self[3:7] = quaternion(dtype=dtype)(*q)
+            try:
+                # For backward compatibility, try to check if the arguments
+                # match the original signature that'd allow initializing
+                # the `p` and `q` components separately.
+                bound_args = self._wp_init_from_components_sig_.bind(*args, **kwargs)
+                bound_args.apply_defaults()
+                p, q = bound_args.args
+            except (TypeError, ValueError):
+                # Fallback to the vector's constructor.
+                super().__init__(*args)
+                return
+
+            # Even if the arguments match the original “from components”
+            # signature, we still need to make sure that they represent
+            # sequences that can be unpacked.
+            if hasattr(p, "__len__") and hasattr(q, "__len__"):
+                # Initialize from the `p` and `q` components.
+                super().__init__()
+                self[0:3] = vector(length=3, dtype=dtype)(*p)
+                self[3:7] = quaternion(dtype=dtype)(*q)
+                return
+
+            # Fallback to the vector's constructor.
+            super().__init__(*args)
 
         @property
         def p(self):
-            return self[0:3]
+            return vec3(self[0:3])
 
         @property
         def q(self):
-            return self[3:7]
+            return quat(self[3:7])
 
     return transform_t
 
592
 
@@ -851,18 +910,21 @@ class range_t:
851
910
 
852
911
  # definition just for kernel type (cannot be a parameter), see bvh.h
853
912
  class bvh_query_t:
913
+ """Object used to track state during BVH traversal."""
854
914
  def __init__(self):
855
915
  pass
856
916
 
857
917
 
858
918
  # definition just for kernel type (cannot be a parameter), see mesh.h
859
919
  class mesh_query_aabb_t:
920
+ """Object used to track state during mesh traversal."""
860
921
  def __init__(self):
861
922
  pass
862
923
 
863
924
 
864
925
  # definition just for kernel type (cannot be a parameter), see hash_grid.h
865
926
  class hash_grid_query_t:
927
+ """Object used to track state during neighbor traversal."""
866
928
  def __init__(self):
867
929
  pass
868
930
 
@@ -999,7 +1061,7 @@ def type_scalar_type(dtype):
999
1061
  def type_size_in_bytes(dtype):
1000
1062
  if dtype.__module__ == "ctypes":
1001
1063
  return ctypes.sizeof(dtype)
1002
- elif type_is_struct(dtype):
1064
+ elif isinstance(dtype, warp.codegen.Struct):
1003
1065
  return ctypes.sizeof(dtype.ctype)
1004
1066
  elif dtype == float or dtype == int:
1005
1067
  return 4
@@ -1020,8 +1082,6 @@ def type_to_warp(dtype):
1020
1082
 
1021
1083
 
1022
1084
  def type_typestr(dtype):
1023
- from warp.codegen import Struct
1024
-
1025
1085
  if dtype == bool:
1026
1086
  return "?"
1027
1087
  elif dtype == float16:
@@ -1046,7 +1106,7 @@ def type_typestr(dtype):
1046
1106
  return "<i8"
1047
1107
  elif dtype == uint64:
1048
1108
  return "<u8"
1049
- elif isinstance(dtype, Struct):
1109
+ elif isinstance(dtype, warp.codegen.Struct):
1050
1110
  return f"|V{ctypes.sizeof(dtype.ctype)}"
1051
1111
  elif issubclass(dtype, ctypes.Array):
1052
1112
  return type_typestr(dtype._wp_scalar_type_)
@@ -1060,9 +1120,16 @@ def type_repr(t):
1060
1120
  return str(f"array(ndim={t.ndim}, dtype={t.dtype})")
1061
1121
  if type_is_vector(t):
1062
1122
  return str(f"vector(length={t._shape_[0]}, dtype={t._wp_scalar_type_})")
1063
- elif type_is_matrix(t):
1123
+ if type_is_matrix(t):
1064
1124
  return str(f"matrix(shape=({t._shape_[0]}, {t._shape_[1]}), dtype={t._wp_scalar_type_})")
1065
- else:
1125
+ if isinstance(t, warp.codegen.Struct):
1126
+ return type_repr(t.cls)
1127
+ if t in scalar_types:
1128
+ return t.__name__
1129
+
1130
+ try:
1131
+ return t.__module__ + "." + t.__qualname__
1132
+ except AttributeError:
1066
1133
  return str(t)
1067
1134
 
1068
1135
 
@@ -1080,15 +1147,6 @@ def type_is_float(t):
1080
1147
  return t in float_types
1081
1148
 
1082
1149
 
1083
- def type_is_struct(dtype):
1084
- from warp.codegen import Struct
1085
-
1086
- if isinstance(dtype, Struct):
1087
- return True
1088
- else:
1089
- return False
1090
-
1091
-
1092
1150
  # returns True if the passed *type* is a vector
1093
1151
  def type_is_vector(t):
1094
1152
  if hasattr(t, "_wp_generic_type_str_") and t._wp_generic_type_str_ == "vec_t":
@@ -1162,6 +1220,17 @@ def types_equal(a, b, match_generic=False):
1162
1220
  if p1 == Float and p2 == Float:
1163
1221
  return True
1164
1222
 
1223
+ # convert to canonical types
1224
+ if p1 == float:
1225
+ p1 = float32
1226
+ elif p1 == int:
1227
+ p1 = int32
1228
+
1229
+ if p2 == float:
1230
+ p2 = float32
1231
+ elif b == int:
1232
+ p2 = int32
1233
+
1165
1234
  if p1 in compatible_bool_types and p2 in compatible_bool_types:
1166
1235
  return True
1167
1236
  else:
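Note: the canonicalization above lets Python's `float` and `int` compare equal to Warp's 32-bit scalars when matching type parameters. As written, the second branch tests `b == int` (the whole right-hand type) rather than `p2 == int`, so the `int` canonicalization of the right-hand parameter only fires when the outer argument is itself `int`. A sketch of the intended effect:

    from warp.types import types_equal, float32, int32

    # Python scalars now match their canonical 32-bit Warp counterparts
    assert types_equal(float, float32)
    assert types_equal(int, int32)
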
@@ -1173,7 +1242,7 @@ def types_equal(a, b, match_generic=False):
         and a._wp_generic_type_str_ == b._wp_generic_type_str_
     ):
         return all([are_equal(p1, p2) for p1, p2 in zip(a._wp_type_params_, b._wp_type_params_)])
-    if is_array(a) and type(a) == type(b):
+    if is_array(a) and type(a) is type(b):
         return True
     else:
         return are_equal(a, b)
@@ -1257,6 +1326,7 @@ class array(Array):
         self._grad = None
         # __array_interface__ or __cuda_array_interface__, evaluated lazily and cached
         self._array_interface = None
+        self.is_transposed = False
 
         # canonicalize dtype
         if dtype == int:
@@ -1801,6 +1871,7 @@ class array(Array):
         return array._vars
 
     def zero_(self):
+        """Zeroes-out the array entires."""
         if self.is_contiguous:
             # simple memset is usually faster than generic fill
             self.device.memset(self.ptr, 0, self.size * type_size_in_bytes(self.dtype))
@@ -1808,6 +1879,32 @@ class array(Array):
             self.fill_(0)
 
     def fill_(self, value):
+        """Set all array entries to `value`
+
+        args:
+            value: The value to set every array entry to. Must be convertible to the array's ``dtype``.
+
+        Raises:
+            ValueError: If `value` cannot be converted to the array's ``dtype``.
+
+        Examples:
+            ``fill_()`` can take lists or other sequences when filling arrays of vectors or matrices.
+
+            >>> arr = wp.zeros(2, dtype=wp.mat22)
+            >>> arr.numpy()
+            array([[[0., 0.],
+                    [0., 0.]],
+            <BLANKLINE>
+                   [[0., 0.],
+                    [0., 0.]]], dtype=float32)
+            >>> arr.fill_([[1, 2], [3, 4]])
+            >>> arr.numpy()
+            array([[[1., 2.],
+                    [3., 4.]],
+            <BLANKLINE>
+                   [[1., 2.],
+                    [3., 4.]]], dtype=float32)
+        """
         if self.size == 0:
             return
 
@@ -1854,15 +1951,18 @@ class array(Array):
         else:
             warp.context.runtime.core.array_fill_host(carr_ptr, ARRAY_TYPE_REGULAR, cvalue_ptr, cvalue_size)
 
-    # equivalent to wrapping src data in an array and copying to self
     def assign(self, src):
+        """Wraps ``src`` in an :class:`warp.array` if it is not already one and copies the contents to ``self``."""
         if is_array(src):
             warp.copy(self, src)
         else:
             warp.copy(self, array(data=src, dtype=self.dtype, copy=False, device="cpu"))
 
-    # convert array to ndarray (alias memory through array interface)
     def numpy(self):
+        """Converts the array to a :class:`numpy.ndarray` (aliasing memory through the array interface protocol)
+        If the array is on the GPU, a synchronous device-to-host copy (on the CUDA default stream) will be
+        automatically performed to ensure that any outstanding work is completed.
+        """
         if self.ptr:
             # use the CUDA default stream for synchronous behaviour with other streams
             with warp.ScopedStream(self.device.null_stream):
@@ -1883,12 +1983,16 @@ class array(Array):
             npshape = self.shape
         return np.empty(npshape, dtype=npdtype)
 
-    # return a ctypes cast of the array address
-    # note #1: only CPU arrays support this method
-    # note #2: the array must be contiguous
-    # note #3: accesses to this object are *not* bounds checked
-    # note #4: for float16 types, a pointer to the internal uint16 representation is returned
     def cptr(self):
+        """Return a ctypes cast of the array address.
+
+        Notes:
+
+        #. Only CPU arrays support this method.
+        #. The array must be contiguous.
+        #. Accesses to this object are **not** bounds checked.
+        #. For ``float16`` types, a pointer to the internal ``uint16`` representation is returned.
+        """
         if not self.ptr:
             return None
 
@@ -1907,8 +2011,8 @@ class array(Array):
 
         return p
 
-    # returns a flattened list of items in the array as a Python list
     def list(self):
+        """Returns a flattened list of items in the array as a Python list."""
         a = self.numpy()
 
         if isinstance(self.dtype, warp.codegen.Struct):
@@ -1927,8 +2031,8 @@ class array(Array):
             # scalar
             return list(a.flatten())
 
-    # convert data from one device to another, nop if already on device
     def to(self, device):
+        """Returns a Warp array with this array's data moved to the specified device, no-op if already on device."""
         device = warp.get_device(device)
         if self.device == device:
             return self
@@ -1936,6 +2040,7 @@ class array(Array):
             return warp.clone(self, device=device)
 
     def flatten(self):
+        """Returns a zero-copy view of the array collapsed to 1-D. Only supported for contiguous arrays."""
         if self.ndim == 1:
             return self
 
@@ -1958,6 +2063,11 @@ class array(Array):
         return a
 
     def reshape(self, shape):
+        """Returns a reshaped array. Only supported for contiguous arrays.
+
+        Args:
+            shape : An int or tuple of ints specifying the shape of the returned array.
+        """
         if not self.is_contiguous:
             raise RuntimeError("Reshaping non-contiguous arrays is unsupported.")
 
@@ -2015,6 +2125,9 @@ class array(Array):
         return a
 
     def view(self, dtype):
+        """Returns a zero-copy view of this array's memory with a different data type.
+        ``dtype`` must have the same byte size of the array's native ``dtype``.
+        """
         if type_size_in_bytes(dtype) != type_size_in_bytes(self.dtype):
             raise RuntimeError("Cannot cast dtypes of unequal byte size")
 
@@ -2035,6 +2148,7 @@ class array(Array):
         return a
 
     def contiguous(self):
+        """Returns a contiguous array with this array's data. No-op if array is already contiguous."""
         if self.is_contiguous:
             return self
 
@@ -2042,8 +2156,14 @@ class array(Array):
         warp.copy(a, self)
         return a
 
-    # note: transpose operation will return an array with a non-contiguous access pattern
     def transpose(self, axes=None):
+        """Returns an zero-copy view of the array with axes transposed.
+
+        Note: The transpose operation will return an array with a non-contiguous access pattern.
+
+        Args:
+            axes (optional): Specifies the how the axes are permuted. If not specified, the axes order will be reversed.
+        """
         # noop if 1d array
         if self.ndim == 1:
             return self
@@ -2076,6 +2196,8 @@ class array(Array):
             grad=None if self.grad is None else self.grad.transpose(axes=axes),
         )
 
+        a.is_transposed = not self.is_transposed
+
         a._ref = self
         return a
 
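Note: with `is_transposed` now tracked on the view returned by `transpose()`, a transposed matrix can feed `wp.matmul` without materializing a contiguous copy (see the matmul changes further down). A sketch with illustrative shapes:

    import warp as wp
    import numpy as np

    wp.init()

    a = wp.array(np.ones((4, 3), dtype=np.float32))   # 4x3
    at = a.transpose()                                # 3x4 zero-copy view, is_transposed == True
    b = wp.zeros((4, 2), dtype=wp.float32)
    c = wp.zeros((3, 2), dtype=wp.float32)
    d = wp.zeros((3, 2), dtype=wp.float32)
    wp.matmul(at, b, c, d)                            # D = A^T @ B without copying A
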
@@ -2516,16 +2638,14 @@ class Mesh:
 
 
 class Volume:
+    #: Enum value to specify nearest-neighbor interpolation during sampling
     CLOSEST = constant(0)
+    #: Enum value to specify trilinear interpolation during sampling
     LINEAR = constant(1)
 
     def __init__(self, data: array):
         """Class representing a sparse grid.
 
-        Attributes:
-            CLOSEST (int): Enum value to specify nearest-neighbor interpolation during sampling
-            LINEAR (int): Enum value to specify trilinear interpolation during sampling
-
         Args:
             data (:class:`warp.array`): Array of bytes representing the volume in NanoVDB format
         """
@@ -2570,7 +2690,8 @@ class Volume:
         except Exception:
             pass
 
-    def array(self):
+    def array(self) -> array:
+        """Returns the raw memory buffer of the Volume as an array"""
         buf = ctypes.c_void_p(0)
         size = ctypes.c_uint64(0)
         if self.device.is_cpu:
@@ -2579,7 +2700,7 @@ class Volume:
             self.context.core.volume_get_buffer_info_device(self.id, ctypes.byref(buf), ctypes.byref(size))
         return array(ptr=buf.value, dtype=uint8, shape=size.value, device=self.device, owner=False)
 
-    def get_tiles(self):
+    def get_tiles(self) -> array:
         if self.id == 0:
             raise RuntimeError("Invalid Volume")
 
@@ -2592,7 +2713,7 @@ class Volume:
         num_tiles = size.value // (3 * 4)
         return array(ptr=buf.value, dtype=int32, shape=(num_tiles, 3), device=self.device, owner=True)
 
-    def get_voxel_size(self):
+    def get_voxel_size(self) -> Tuple[float, float, float]:
         if self.id == 0:
             raise RuntimeError("Invalid Volume")
 
@@ -2601,7 +2722,7 @@ class Volume:
         return (dx.value, dy.value, dz.value)
 
     @classmethod
-    def load_from_nvdb(cls, file_or_buffer, device=None):
+    def load_from_nvdb(cls, file_or_buffer, device=None) -> Volume:
         """Creates a Volume object from a NanoVDB file or in-memory buffer.
 
         Returns:
@@ -2637,14 +2758,18 @@ class Volume:
         return cls(data_array)
 
     @classmethod
-    def load_from_numpy(cls, ndarray: np.array, min_world=(0.0, 0.0, 0.0), voxel_size=1.0, bg_value=0.0, device=None):
+    def load_from_numpy(
+        cls, ndarray: np.array, min_world=(0.0, 0.0, 0.0), voxel_size=1.0, bg_value=0.0, device=None
+    ) -> Volume:
         """Creates a Volume object from a dense 3D NumPy array.
 
+        This function is only supported for CUDA devices.
+
         Args:
-            min_world: The 3D coordinate of the lower corner of the volume
-            voxel_size: The size of each voxel in spatial coordinates
+            min_world: The 3D coordinate of the lower corner of the volume.
+            voxel_size: The size of each voxel in spatial coordinates.
             bg_value: Background value
-            device: The device to create the volume on, e.g.: "cpu", or "cuda:0"
+            device: The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
 
         Returns:
 
@@ -2699,7 +2824,7 @@ class Volume:
                 inputs=[volume.id, warp.array(padded_array, dtype=warp.vec3, device=device)],
                 device=device,
             )
-        elif type(bg_value) == int:
+        elif isinstance(bg_value, int):
             warp.launch(
                 warp.utils.copy_dense_volume_to_nano_vdb_i,
                 dim=shape,
@@ -2726,9 +2851,11 @@ class Volume:
         translation=(0.0, 0.0, 0.0),
         points_in_world_space=False,
         device=None,
-    ):
+    ) -> Volume:
         """Allocate a new Volume based on the bounding box defined by min and max.
 
+        This function is only supported for CUDA devices.
+
         Allocate a volume that is large enough to contain voxels [min[0], min[1], min[2]] - [max[0], max[1], max[2]], inclusive.
         If points_in_world_space is true, then min and max are first converted to index space with the given voxel size and
         translation, and the volume is allocated with those.
@@ -2737,12 +2864,12 @@ class Volume:
         the resulting tiles will be available in the new volume.
 
         Args:
-            min (array-like): Lower 3D-coordinates of the bounding box in index space or world space, inclusive
-            max (array-like): Upper 3D-coordinates of the bounding box in index space or world space, inclusive
-            voxel_size (float): Voxel size of the new volume
+            min (array-like): Lower 3D coordinates of the bounding box in index space or world space, inclusive.
+            max (array-like): Upper 3D coordinates of the bounding box in index space or world space, inclusive.
+            voxel_size (float): Voxel size of the new volume.
             bg_value (float or array-like): Value of unallocated voxels of the volume, also defines the volume's type, a :class:`warp.vec3` volume is created if this is `array-like`, otherwise a float volume is created
-            translation (array-like): translation between the index and world spaces
-            device (Devicelike): Device the array lives on
+            translation (array-like): translation between the index and world spaces.
+            device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
 
         """
         if points_in_world_space:
@@ -2767,9 +2894,11 @@ class Volume:
     @classmethod
     def allocate_by_tiles(
         cls, tile_points: array, voxel_size: float, bg_value=0.0, translation=(0.0, 0.0, 0.0), device=None
-    ):
+    ) -> Volume:
         """Allocate a new Volume with active tiles for each point tile_points.
 
+        This function is only supported for CUDA devices.
+
         The smallest unit of allocation is a dense tile of 8x8x8 voxels.
         This is the primary method for allocating sparse volumes. It uses an array of points indicating the tiles that must be allocated.
 
@@ -2779,13 +2908,13 @@ class Volume:
         Args:
             tile_points (:class:`warp.array`): Array of positions that define the tiles to be allocated.
-                The array can be a 2d, N-by-3 array of :class:`warp.int32` values, indicating index space positions,
+                The array can be a 2D, N-by-3 array of :class:`warp.int32` values, indicating index space positions,
                 or can be a 1D array of :class:`warp.vec3` values, indicating world space positions.
                 Repeated points per tile are allowed and will be efficiently deduplicated.
-            voxel_size (float): Voxel size of the new volume
+            voxel_size (float): Voxel size of the new volume.
             bg_value (float or array-like): Value of unallocated voxels of the volume, also defines the volume's type, a :class:`warp.vec3` volume is created if this is `array-like`, otherwise a float volume is created
-            translation (array-like): translation between the index and world spaces
-            device (Devicelike): Device the array lives on
+            translation (array-like): Translation between the index and world spaces.
+            device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
 
         """
         from warp.context import runtime
@@ -2822,7 +2951,7 @@ class Volume:
             translation[2],
             in_world_space,
         )
-        elif type(bg_value) == int:
+        elif isinstance(bg_value, int):
             volume.id = volume.context.core.volume_i_from_tiles_device(
                 volume.device.context,
                 ctypes.c_void_p(tile_points.ptr),
@@ -2853,6 +2982,67 @@ class Volume:
         return volume
 
 
+# definition just for kernel type (cannot be a parameter), see mesh.h
+# NOTE: its layout must match the corresponding struct defined in C.
+# NOTE: it needs to be defined after `indexedarray` to workaround a circular import issue.
+class mesh_query_point_t:
+    """Output for the mesh query point functions.
+
+    Attributes:
+        result (bool): Whether a point is found within the given constraints.
+        sign (float32): A value < 0 if query point is inside the mesh, >=0 otherwise.
+                        Note that mesh must be watertight for this to be robust
+        face (int32): Index of the closest face.
+        u (float32): Barycentric u coordinate of the closest point.
+        v (float32): Barycentric v coordinate of the closest point.
+
+    See Also:
+        :func:`mesh_query_point`, :func:`mesh_query_point_no_sign`,
+        :func:`mesh_query_furthest_point_no_sign`,
+        :func:`mesh_query_point_sign_normal`,
+        and :func:`mesh_query_point_sign_winding_number`.
+    """
+    from warp.codegen import Var
+
+    vars = {
+        "result": Var("result", bool),
+        "sign": Var("sign", float32),
+        "face": Var("face", int32),
+        "u": Var("u", float32),
+        "v": Var("v", float32),
+    }
+
+
+# definition just for kernel type (cannot be a parameter), see mesh.h
+# NOTE: its layout must match the corresponding struct defined in C.
+class mesh_query_ray_t:
+    """Output for the mesh query ray functions.
+
+    Attributes:
+        result (bool): Whether a hit is found within the given constraints.
+        sign (float32): A value > 0 if the ray hit in front of the face, returns < 0 otherwise.
+        face (int32): Index of the closest face.
+        t (float32): Distance of the closest hit along the ray.
+        u (float32): Barycentric u coordinate of the closest hit.
+        v (float32): Barycentric v coordinate of the closest hit.
+        normal (vec3f): Face normal.
+
+    See Also:
+        :func:`mesh_query_ray`.
+    """
+    from warp.codegen import Var
+
+    vars = {
+        "result": Var("result", bool),
+        "sign": Var("sign", float32),
+        "face": Var("face", int32),
+        "t": Var("t", float32),
+        "u": Var("u", float32),
+        "v": Var("v", float32),
+        "normal": Var("normal", vec3),
+    }
+
+
 def matmul(
     a: array2d,
     b: array2d,
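Note: these two classes mirror, on the Python side, the C structs returned by the new mesh query builtins (registered as "mqp"/"mqr" in simple_type_codes at the bottom of this diff). A hedged kernel sketch, assuming the struct-returning `wp.mesh_query_point` overload shipped alongside them:

    import warp as wp

    @wp.kernel
    def project_onto_mesh(mesh: wp.uint64, points: wp.array(dtype=wp.vec3), out: wp.array(dtype=wp.vec3)):
        tid = wp.tid()
        q = wp.mesh_query_point(mesh, points[tid], 1.0e6)  # yields a mesh_query_point_t
        if q.result:
            out[tid] = wp.mesh_eval_position(mesh, q.face, q.u, q.v)
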
@@ -2889,6 +3079,11 @@ def matmul(
             "wp.matmul currently only supports operation between {A, B, C, D} matrices of the same type."
         )
 
+    if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
+        )
+
     m = a.shape[0]
     n = b.shape[1]
     k = a.shape[1]
@@ -2923,13 +3118,13 @@ def matmul(
         ctypes.c_void_p(d.ptr),
         alpha,
         beta,
-        True,
-        True,
+        not a.is_transposed,
+        not b.is_transposed,
         allow_tf32x3_arith,
         1,
     )
     if not ret:
-        raise RuntimeError("Matmul failed.")
+        raise RuntimeError("matmul failed.")
 
 
 def adj_matmul(
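Note: the new guard surfaces unsupported layouts eagerly instead of passing bad strides to CUTLASS — transposed views of A/B pass (handled via the flipped row-major flags above), while any other non-contiguous input is rejected. A sketch of the rejected case:

    import warp as wp
    import numpy as np

    wp.init()

    a = wp.array(np.zeros((4, 6), dtype=np.float32))
    sub = a[:, 0:3]                  # non-contiguous column slice, not a transpose
    b = wp.zeros((3, 2), dtype=wp.float32)
    c = wp.zeros((4, 2), dtype=wp.float32)
    d = wp.zeros((4, 2), dtype=wp.float32)
    try:
        wp.matmul(sub, b, c, d)
    except RuntimeError as err:
        print(err)                   # contiguity error raised before any GEMM launch
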
@@ -2993,6 +3188,19 @@ def adj_matmul(
             "wp.adj_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
         )
 
+    if (
+        (not a.is_contiguous and not a.is_transposed)
+        or (not b.is_contiguous and not b.is_transposed)
+        or (not c.is_contiguous)
+        or (not adj_a.is_contiguous and not adj_a.is_transposed)
+        or (not adj_b.is_contiguous and not adj_b.is_transposed)
+        or (not adj_c.is_contiguous)
+        or (not adj_d.is_contiguous)
+    ):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
+        )
+
     m = a.shape[0]
     n = b.shape[1]
     k = a.shape[1]
@@ -3013,75 +3221,105 @@ def adj_matmul(
 
     # cpu fallback if no cuda devices found
     if device == "cpu":
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()))
-        adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()))
-        adj_c.assign(beta * adj_d.numpy())
+        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()) + adj_a.numpy())
+        adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()) + adj_b.numpy())
+        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
         return
 
     cc = device.arch
 
     # adj_a
-    ret = runtime.core.cutlass_gemm(
-        cc,
-        m,
-        k,
-        n,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(adj_a.ptr),
-        alpha,
-        0.0,
-        True,
-        False,
-        allow_tf32x3_arith,
-        1,
-    )
-    if not ret:
-        raise RuntimeError("adj_matmul failed.")
+    if not a.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            m,
+            k,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(b.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            alpha,
+            1.0,
+            True,
+            b.is_transposed,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            m,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(b.ptr),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            alpha,
+            1.0,
+            not b.is_transposed,
+            False,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
 
     # adj_b
-    ret = runtime.core.cutlass_gemm(
-        cc,
-        k,
-        n,
-        m,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(adj_b.ptr),
-        alpha,
-        0.0,
-        False,
-        True,
-        allow_tf32x3_arith,
-        1,
-    )
-    if not ret:
-        raise RuntimeError("adj_matmul failed.")
+    if not b.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            n,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(a.ptr),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            alpha,
+            1.0,
+            a.is_transposed,
+            True,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            n,
+            k,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(a.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            alpha,
+            1.0,
+            False,
+            not a.is_transposed,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
 
     # adj_c
-    ret = runtime.core.cutlass_gemm(
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(adj_c.ptr),
-        0.0,
-        beta,
-        True,
-        True,
-        allow_tf32x3_arith,
-        1,
+    warp.launch(
+        kernel=warp.utils.add_kernel_2d,
+        dim=adj_c.shape,
+        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
+        device=device,
+        record_tape=False
     )
-    if not ret:
-        raise RuntimeError("adj_matmul failed.")
 
 
 def batched_matmul(
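Note: two behavioral fixes are visible in the hunk above — the CPU fallback now accumulates into the existing adjoints instead of overwriting them (matching the `beta=1.0` passed to the CUTLASS calls, which also now write into `adj_a`/`adj_b` rather than into `a`/`b`), and `adj_c` is updated with a simple add kernel instead of a full GEMM. In NumPy terms, the backward pass for D = alpha*A@B + beta*C is:

    import numpy as np

    def adj_matmul_reference(a, b, adj_d, adj_a, adj_b, adj_c, alpha=1.0, beta=0.0):
        # accumulate-into semantics, as in the updated code paths above
        adj_a += alpha * (adj_d @ b.T)
        adj_b += alpha * (a.T @ adj_d)
        adj_c += beta * adj_d
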
@@ -3120,6 +3358,11 @@ def batched_matmul(
             "wp.batched_matmul currently only supports operation between {A, B, C, D} matrices of the same type."
         )
 
+    if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
+        )
+
     m = a.shape[1]
     n = b.shape[2]
     k = a.shape[2]
@@ -3131,7 +3374,7 @@ def batched_matmul(
 
     if runtime.tape:
         runtime.tape.record_func(
-            backward=lambda: adj_matmul(
+            backward=lambda: adj_batched_matmul(
                 a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith, device
             ),
             arrays=[a, b, c, d],
@@ -3142,26 +3385,55 @@ def batched_matmul(
         d.assign(alpha * np.matmul(a.numpy(), b.numpy()) + beta * c.numpy())
         return
 
+    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
+    max_batch_count = 65535
+    iters = int(batch_count / max_batch_count)
+    remainder = batch_count % max_batch_count
+
     cc = device.arch
+    for i in range(iters):
+        idx_start = i * max_batch_count
+        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            m,
+            n,
+            k,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
+            ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
+            ctypes.c_void_p(c[idx_start:idx_end,:,:].ptr),
+            ctypes.c_void_p(d[idx_start:idx_end,:,:].ptr),
+            alpha,
+            beta,
+            not a.is_transposed,
+            not b.is_transposed,
+            allow_tf32x3_arith,
+            max_batch_count,
+        )
+        if not ret:
+            raise RuntimeError("Batched matmul failed.")
+
+    idx_start = iters * max_batch_count
     ret = runtime.core.cutlass_gemm(
         cc,
         m,
         n,
         k,
         type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(c.ptr),
-        ctypes.c_void_p(d.ptr),
+        ctypes.c_void_p(a[idx_start:,:,:].ptr),
+        ctypes.c_void_p(b[idx_start:,:,:].ptr),
+        ctypes.c_void_p(c[idx_start:,:,:].ptr),
+        ctypes.c_void_p(d[idx_start:,:,:].ptr),
         alpha,
         beta,
-        True,
-        True,
+        not a.is_transposed,
+        not b.is_transposed,
        allow_tf32x3_arith,
-        batch_count,
+        remainder,
     )
     if not ret:
-        raise RuntimeError("Batched matmul failed.")
+        raise RuntimeError("Batched matmul failed.")
 
 
 def adj_batched_matmul(
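Note: per the in-code comment, the CUTLASS batched GEMM entry point caps the batch dimension at 65535, so `batched_matmul` (and `adj_batched_matmul` below) now split oversized batches into full chunks plus a remainder launch. The index arithmetic, isolated as a sketch:

    batch_count = 200000
    max_batch_count = 65535

    iters = batch_count // max_batch_count     # 3 full chunks
    remainder = batch_count % max_batch_count  # 3395 trailing matrices

    for i in range(iters):
        idx_start = i * max_batch_count
        # GEMM over the slice [idx_start : idx_start + max_batch_count]
    # final GEMM over [iters * max_batch_count :] with batch size `remainder`
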
@@ -3241,78 +3513,215 @@ def adj_batched_matmul(
         )
     )
 
+    if (
+        (not a.is_contiguous and not a.is_transposed)
+        or (not b.is_contiguous and not b.is_transposed)
+        or (not c.is_contiguous)
+        or (not adj_a.is_contiguous and not adj_a.is_transposed)
+        or (not adj_b.is_contiguous and not adj_b.is_transposed)
+        or (not adj_c.is_contiguous)
+        or (not adj_d.is_contiguous)
+    ):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
+        )
+
     # cpu fallback if no cuda devices found
     if device == "cpu":
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))))
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()))
-        adj_c.assign(beta * adj_d.numpy())
+        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))) + adj_a.numpy())
+        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()) + adj_b.numpy())
+        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
         return
 
+    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
+    max_batch_count = 65535
+    iters = int(batch_count / max_batch_count)
+    remainder = batch_count % max_batch_count
+
     cc = device.arch
 
+    for i in range(iters):
+        idx_start = i * max_batch_count
+        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
+
+        # adj_a
+        if not a.is_transposed:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                m,
+                k,
+                n,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                True,
+                b.is_transposed,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+        else:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                k,
+                m,
+                n,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                not b.is_transposed,
+                False,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+
+        # adj_b
+        if not b.is_transposed:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                k,
+                n,
+                m,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                a.is_transposed,
+                True,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+        else:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                n,
+                k,
+                m,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                False,
+                not a.is_transposed,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+
+    idx_start = iters * max_batch_count
+
     # adj_a
-    ret = runtime.core.cutlass_gemm(
-        cc,
-        m,
-        k,
-        n,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(adj_a.ptr),
-        alpha,
-        0.0,
-        True,
-        False,
-        allow_tf32x3_arith,
-        batch_count,
-    )
-    if not ret:
-        raise RuntimeError("adj_matmul failed.")
+    if not a.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            m,
+            k,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            True,
+            b.is_transposed,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            m,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            not b.is_transposed,
+            False,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
 
     # adj_b
-    ret = runtime.core.cutlass_gemm(
-        cc,
-        k,
-        n,
-        m,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(adj_b.ptr),
-        alpha,
-        0.0,
-        False,
-        True,
-        allow_tf32x3_arith,
-        batch_count,
-    )
-    if not ret:
-        raise RuntimeError("adj_matmul failed.")
+    if not b.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            n,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            a.is_transposed,
+            True,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            n,
+            k,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            False,
+            not a.is_transposed,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
 
     # adj_c
-    ret = runtime.core.cutlass_gemm(
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(adj_c.ptr),
-        0.0,
-        beta,
-        True,
-        True,
-        allow_tf32x3_arith,
-        batch_count,
+    warp.launch(
+        kernel=warp.utils.add_kernel_3d,
+        dim=adj_c.shape,
+        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
+        device=device,
+        record_tape=False
     )
-    if not ret:
-        raise RuntimeError("adj_matmul failed.")
-
 
 
 class HashGrid:
     def __init__(self, dim_x, dim_y, dim_z, device=None):
@@ -3511,7 +3920,7 @@ def type_matches_template(arg_type, template_type):
         return True
     elif is_array(template_type):
         # ensure the argument type is a non-generic array with matching dtype and dimensionality
-        if type(arg_type) != type(template_type):
+        if type(arg_type) is not type(template_type):
             return False
         if not type_matches_template(arg_type.dtype, template_type.dtype):
             return False
@@ -3567,7 +3976,7 @@ def infer_argument_types(args, template_types, arg_names=None):
             arg_types.append(arg._cls)
         # elif arg_type in [warp.types.launch_bounds_t, warp.types.shape_t, warp.types.range_t]:
        #     arg_types.append(arg_type)
-        # elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.bvh_query_t]:
+        # elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.mesh_query_point_t, warp.mesh_query_ray_t, warp.bvh_query_t]:
         #     arg_types.append(arg_type)
         elif arg is None:
             # allow passing None for arrays
@@ -3605,6 +4014,8 @@ simple_type_codes = {
     launch_bounds_t: "lb",
     hash_grid_query_t: "hgq",
     mesh_query_aabb_t: "mqa",
+    mesh_query_point_t: "mqp",
+    mesh_query_ray_t: "mqr",
     bvh_query_t: "bvhq",
 }