warp-lang 1.0.0b5-py3-none-manylinux2014_x86_64.whl → 1.0.0b6-py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/types.py CHANGED
@@ -178,13 +178,13 @@ def vector(length, dtype):
             return warp.add(self, y)
 
         def __radd__(self, y):
-            return warp.add(self, y)
+            return warp.add(y, self)
 
         def __sub__(self, y):
             return warp.sub(self, y)
 
-        def __rsub__(self, x):
-            return warp.sub(x, self)
+        def __rsub__(self, y):
+            return warp.sub(y, self)
 
         def __mul__(self, y):
             return warp.mul(self, y)
@@ -195,7 +195,7 @@ def vector(length, dtype):
         def __truediv__(self, y):
             return warp.div(self, y)
 
-        def __rdiv__(self, x):
+        def __rtruediv__(self, x):
             return warp.div(x, self)
 
         def __pos__(self):
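
Note on the two hunks above: `__radd__` and `__rsub__` now pass the reflected operand in the correct order, and the Python 2 era `__rdiv__` hook is renamed to `__rtruediv__`, which is the name Python 3 actually dispatches to for `scalar / vec`. A minimal, Warp-independent sketch of why the rename matters (class names here are illustrative):

    class Old:
        def __rdiv__(self, x):          # Python 2 name; Python 3 never calls this
            return "reflected divide"

    class New:
        def __rtruediv__(self, x):      # Python 3 hook for the reflected / operator
            return "reflected divide"

    try:
        1.0 / Old()
    except TypeError as exc:
        print("Old:", exc)              # unsupported operand type(s) for /
    print("New:", 1.0 / New())          # prints: New: reflected divide
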
@@ -294,13 +294,13 @@ def matrix(shape, dtype):
             return warp.add(self, y)
 
         def __radd__(self, y):
-            return warp.add(self, y)
+            return warp.add(y, self)
 
         def __sub__(self, y):
             return warp.sub(self, y)
 
-        def __rsub__(self, x):
-            return warp.sub(x, self)
+        def __rsub__(self, y):
+            return warp.sub(y, self)
 
         def __mul__(self, y):
             return warp.mul(self, y)
@@ -317,7 +317,7 @@ def matrix(shape, dtype):
         def __truediv__(self, y):
             return warp.div(self, y)
 
-        def __rdiv__(self, x):
+        def __rtruediv__(self, x):
             return warp.div(x, self)
 
         def __pos__(self):
@@ -582,11 +582,11 @@ def transformation(dtype=Any):
 
         @property
         def p(self):
-            return self[0:3]
+            return vec3(self[0:3])
 
         @property
        def q(self):
-            return self[3:7]
+            return quat(self[3:7])
 
     return transform_t
 
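
The hunk above makes the `p` and `q` properties return proper `vec3` / `quat` values instead of raw component slices. A small sketch of the effect, assuming the usual Python-scope constructors (the exact printed type names may vary with the dtype):

    import warp as wp

    wp.init()

    t = wp.transform(wp.vec3(1.0, 2.0, 3.0), wp.quat(0.0, 0.0, 0.0, 1.0))

    # The properties now return typed values that can be passed straight into
    # other vec3/quat helpers rather than a bare slice of components.
    print(type(t.p))   # a vec3 instance
    print(type(t.q))   # a quat instance
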
@@ -910,18 +910,21 @@ class range_t:
 
 # definition just for kernel type (cannot be a parameter), see bvh.h
 class bvh_query_t:
+    """Object used to track state during BVH traversal."""
     def __init__(self):
         pass
 
 
 # definition just for kernel type (cannot be a parameter), see mesh.h
 class mesh_query_aabb_t:
+    """Object used to track state during mesh traversal."""
     def __init__(self):
         pass
 
 
 # definition just for kernel type (cannot be a parameter), see hash_grid.h
 class hash_grid_query_t:
+    """Object used to track state during neighbor traversal."""
     def __init__(self):
         pass
 
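
These handle types are never constructed from Python; they exist only so that kernel code can name the opaque query state. For reference, a minimal kernel sketch that exercises `hash_grid_query_t` through the standard query loop (the kernel and argument names are illustrative):

    import warp as wp

    @wp.kernel
    def count_neighbors(grid: wp.uint64,
                        points: wp.array(dtype=wp.vec3),
                        radius: float,
                        counts: wp.array(dtype=int)):
        tid = wp.tid()

        # wp.hash_grid_query() returns a hash_grid_query_t that lives only
        # inside the kernel and is advanced with wp.hash_grid_query_next().
        query = wp.hash_grid_query(grid, points[tid], radius)
        index = int(0)
        n = int(0)

        while wp.hash_grid_query_next(query, index):
            if wp.length(points[index] - points[tid]) < radius:
                n += 1

        counts[tid] = n
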
@@ -2979,6 +2982,67 @@ class Volume:
         return volume
 
 
+# definition just for kernel type (cannot be a parameter), see mesh.h
+# NOTE: its layout must match the corresponding struct defined in C.
+# NOTE: it needs to be defined after `indexedarray` to workaround a circular import issue.
+class mesh_query_point_t:
+    """Output for the mesh query point functions.
+
+    Attributes:
+        result (bool): Whether a point is found within the given constraints.
+        sign (float32): A value < 0 if query point is inside the mesh, >=0 otherwise.
+                        Note that mesh must be watertight for this to be robust
+        face (int32): Index of the closest face.
+        u (float32): Barycentric u coordinate of the closest point.
+        v (float32): Barycentric v coordinate of the closest point.
+
+    See Also:
+        :func:`mesh_query_point`, :func:`mesh_query_point_no_sign`,
+        :func:`mesh_query_furthest_point_no_sign`,
+        :func:`mesh_query_point_sign_normal`,
+        and :func:`mesh_query_point_sign_winding_number`.
+    """
+    from warp.codegen import Var
+
+    vars = {
+        "result": Var("result", bool),
+        "sign": Var("sign", float32),
+        "face": Var("face", int32),
+        "u": Var("u", float32),
+        "v": Var("v", float32),
+    }
+
+
+# definition just for kernel type (cannot be a parameter), see mesh.h
+# NOTE: its layout must match the corresponding struct defined in C.
+class mesh_query_ray_t:
+    """Output for the mesh query ray functions.
+
+    Attributes:
+        result (bool): Whether a hit is found within the given constraints.
+        sign (float32): A value > 0 if the ray hit in front of the face, returns < 0 otherwise.
+        face (int32): Index of the closest face.
+        t (float32): Distance of the closest hit along the ray.
+        u (float32): Barycentric u coordinate of the closest hit.
+        v (float32): Barycentric v coordinate of the closest hit.
+        normal (vec3f): Face normal.
+
+    See Also:
+        :func:`mesh_query_ray`.
+    """
+    from warp.codegen import Var
+
+    vars = {
+        "result": Var("result", bool),
+        "sign": Var("sign", float32),
+        "face": Var("face", int32),
+        "t": Var("t", float32),
+        "u": Var("u", float32),
+        "v": Var("v", float32),
+        "normal": Var("normal", vec3),
+    }
+
+
 def matmul(
     a: array2d,
     b: array2d,
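
The two structs above back the struct-returning mesh query builtins added in this release (see the `warp/builtins.py` and `warp/native/mesh.h` entries in the file list). A minimal kernel sketch under that assumption; the kernel and argument names are illustrative, and the fields read from `q` match the attribute list above:

    import warp as wp

    @wp.kernel
    def closest_points(mesh: wp.uint64,
                       points: wp.array(dtype=wp.vec3),
                       out: wp.array(dtype=wp.vec3)):
        tid = wp.tid()

        # q is a mesh_query_point_t carrying result, sign, face, u, v
        q = wp.mesh_query_point(mesh, points[tid], 1.0e6)
        if q.result:
            out[tid] = wp.mesh_eval_position(mesh, q.face, q.u, q.v)
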
@@ -3157,9 +3221,9 @@ def adj_matmul(
 
     # cpu fallback if no cuda devices found
     if device == "cpu":
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()))
-        adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()))
-        adj_c.assign(beta * adj_d.numpy())
+        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()) + adj_a.numpy())
+        adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()) + adj_b.numpy())
+        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
         return
 
     cc = device.arch
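
In the hunks below the adjoint buffers are passed as the C operand of `cutlass_gemm` with beta raised from 0.0 to 1.0 (and the CPU fallback above adds the previous adjoint values back in), so `adj_matmul` now accumulates gradients instead of overwriting them. In NumPy terms the backward pass behaves roughly like this sketch (variable names follow the diff):

    import numpy as np

    def adj_matmul_reference(a, b, adj_a, adj_b, adj_c, adj_d, alpha=1.0, beta=0.0):
        # forward: d = alpha * (a @ b) + beta * c
        # backward: accumulate into the existing adjoints (+=) rather than overwrite
        adj_a += alpha * adj_d @ b.T
        adj_b += alpha * a.T @ adj_d
        adj_c += beta * adj_d
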
@@ -3174,10 +3238,10 @@ def adj_matmul(
         type_typestr(a.dtype).encode(),
         ctypes.c_void_p(adj_d.ptr),
         ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(a.ptr),
+        ctypes.c_void_p(adj_a.ptr),
         ctypes.c_void_p(adj_a.ptr),
         alpha,
-        0.0,
+        1.0,
         True,
         b.is_transposed,
         allow_tf32x3_arith,
@@ -3194,10 +3258,10 @@ def adj_matmul(
         type_typestr(a.dtype).encode(),
         ctypes.c_void_p(b.ptr),
         ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(a.ptr),
+        ctypes.c_void_p(adj_a.ptr),
         ctypes.c_void_p(adj_a.ptr),
         alpha,
-        0.0,
+        1.0,
         not b.is_transposed,
         False,
         allow_tf32x3_arith,
@@ -3216,10 +3280,10 @@ def adj_matmul(
         type_typestr(a.dtype).encode(),
         ctypes.c_void_p(a.ptr),
         ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(b.ptr),
+        ctypes.c_void_p(adj_b.ptr),
         ctypes.c_void_p(adj_b.ptr),
         alpha,
-        0.0,
+        1.0,
         a.is_transposed,
         True,
         allow_tf32x3_arith,
@@ -3236,10 +3300,10 @@ def adj_matmul(
         type_typestr(a.dtype).encode(),
         ctypes.c_void_p(adj_d.ptr),
         ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
+        ctypes.c_void_p(adj_b.ptr),
         ctypes.c_void_p(adj_b.ptr),
         alpha,
-        0.0,
+        1.0,
         False,
         not a.is_transposed,
         allow_tf32x3_arith,
@@ -3249,25 +3313,13 @@ def adj_matmul(
         raise RuntimeError("adj_matmul failed.")
 
     # adj_c
-    ret = runtime.core.cutlass_gemm(
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(adj_c.ptr),
-        0.0,
-        beta,
-        not a.is_transposed,
-        not b.is_transposed,
-        allow_tf32x3_arith,
-        1,
+    warp.launch(
+        kernel=warp.utils.add_kernel_2d,
+        dim=adj_c.shape,
+        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
+        device=device,
+        record_tape=False
     )
-    if not ret:
-        raise RuntimeError("adj_matmul failed.")
 
 
 def batched_matmul(
@@ -3476,9 +3528,9 @@ def adj_batched_matmul(
 
     # cpu fallback if no cuda devices found
     if device == "cpu":
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))))
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()))
-        adj_c.assign(beta * adj_d.numpy())
+        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))) + adj_a.numpy())
+        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()) + adj_b.numpy())
+        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
         return
 
     # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
@@ -3502,10 +3554,10 @@ def adj_batched_matmul(
             type_typestr(a.dtype).encode(),
             ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
             ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
-            ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
             ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
             alpha,
-            0.0,
+            1.0,
             True,
             b.is_transposed,
             allow_tf32x3_arith,
@@ -3522,10 +3574,10 @@ def adj_batched_matmul(
             type_typestr(a.dtype).encode(),
             ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
             ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
-            ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
             ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
             alpha,
-            0.0,
+            1.0,
             not b.is_transposed,
             False,
             allow_tf32x3_arith,
@@ -3544,10 +3596,10 @@ def adj_batched_matmul(
             type_typestr(a.dtype).encode(),
             ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
             ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
-            ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
             ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
             alpha,
-            0.0,
+            1.0,
             a.is_transposed,
             True,
             allow_tf32x3_arith,
@@ -3564,10 +3616,10 @@ def adj_batched_matmul(
             type_typestr(a.dtype).encode(),
             ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
             ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
-            ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
             ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
             alpha,
-            0.0,
+            1.0,
             False,
             not a.is_transposed,
             allow_tf32x3_arith,
@@ -3575,27 +3627,6 @@ def adj_batched_matmul(
         )
         if not ret:
             raise RuntimeError("adj_matmul failed.")
-
-        # adj_c
-        ret = runtime.core.cutlass_gemm(
-            cc,
-            m,
-            n,
-            k,
-            type_typestr(a.dtype).encode(),
-            ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
-            ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
-            ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
-            ctypes.c_void_p(adj_c[idx_start:idx_end,:,:].ptr),
-            0.0,
-            beta,
-            not a.is_transposed,
-            not b.is_transposed,
-            allow_tf32x3_arith,
-            max_batch_count,
-        )
-        if not ret:
-            raise RuntimeError("adj_batched_matmul failed.")
 
     idx_start = iters * max_batch_count
 
@@ -3609,10 +3640,10 @@ def adj_batched_matmul(
         type_typestr(a.dtype).encode(),
         ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
         ctypes.c_void_p(b[idx_start:,:,:].ptr),
-        ctypes.c_void_p(a[idx_start:,:,:].ptr),
+        ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
         ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
         alpha,
-        0.0,
+        1.0,
         True,
         b.is_transposed,
         allow_tf32x3_arith,
@@ -3629,10 +3660,10 @@ def adj_batched_matmul(
         type_typestr(a.dtype).encode(),
         ctypes.c_void_p(b[idx_start:,:,:].ptr),
         ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
-        ctypes.c_void_p(a[idx_start:,:,:].ptr),
+        ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
         ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
         alpha,
-        0.0,
+        1.0,
         not b.is_transposed,
         False,
         allow_tf32x3_arith,
@@ -3651,10 +3682,10 @@ def adj_batched_matmul(
         type_typestr(a.dtype).encode(),
         ctypes.c_void_p(a[idx_start:,:,:].ptr),
         ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
-        ctypes.c_void_p(b[idx_start:,:,:].ptr),
+        ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
         ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
         alpha,
-        0.0,
+        1.0,
         a.is_transposed,
         True,
         allow_tf32x3_arith,
@@ -3671,10 +3702,10 @@ def adj_batched_matmul(
         type_typestr(a.dtype).encode(),
         ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
         ctypes.c_void_p(a[idx_start:,:,:].ptr),
-        ctypes.c_void_p(b[idx_start:,:,:].ptr),
+        ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
         ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
         alpha,
-        0.0,
+        1.0,
         False,
         not a.is_transposed,
         allow_tf32x3_arith,
@@ -3684,25 +3715,13 @@ def adj_batched_matmul(
         raise RuntimeError("adj_matmul failed.")
 
     # adj_c
-    ret = runtime.core.cutlass_gemm(
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a[idx_start:,:,:].ptr),
-        ctypes.c_void_p(b[idx_start:,:,:].ptr),
-        ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
-        ctypes.c_void_p(adj_c[idx_start:,:,:].ptr),
-        0.0,
-        beta,
-        not a.is_transposed,
-        not b.is_transposed,
-        allow_tf32x3_arith,
-        remainder,
+    warp.launch(
+        kernel=warp.utils.add_kernel_3d,
+        dim=adj_c.shape,
+        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
+        device=device,
+        record_tape=False
     )
-    if not ret:
-        raise RuntimeError("adj_batched_matmul failed.")
 
 class HashGrid:
     def __init__(self, dim_x, dim_y, dim_z, device=None):
@@ -3957,7 +3976,7 @@ def infer_argument_types(args, template_types, arg_names=None):
             arg_types.append(arg._cls)
         # elif arg_type in [warp.types.launch_bounds_t, warp.types.shape_t, warp.types.range_t]:
         #     arg_types.append(arg_type)
-        # elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.bvh_query_t]:
+        # elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.mesh_query_point_t, warp.mesh_query_ray_t, warp.bvh_query_t]:
         #     arg_types.append(arg_type)
         elif arg is None:
             # allow passing None for arrays
@@ -3995,6 +4014,8 @@ simple_type_codes = {
     launch_bounds_t: "lb",
     hash_grid_query_t: "hgq",
     mesh_query_aabb_t: "mqa",
+    mesh_query_point_t: "mqp",
+    mesh_query_ray_t: "mqr",
     bvh_query_t: "bvhq",
 }
 
warp/utils.py CHANGED
@@ -666,3 +666,17 @@ class ScopedTimer:
                 print("{}{} took {:.2f} ms".format(indent, self.name, self.elapsed))
 
             ScopedTimer.indent -= 1
+
+
+# helper kernels for adj_matmul
+@wp.kernel
+def add_kernel_2d(x: wp.array2d(dtype=Any), acc: wp.array2d(dtype=Any), beta: Any):
+    i, j = wp.tid()
+
+    x[i,j] = x[i,j] + beta * acc[i,j]
+
+@wp.kernel
+def add_kernel_3d(x: wp.array3d(dtype=Any), acc: wp.array3d(dtype=Any), beta: Any):
+    i, j, k = wp.tid()
+
+    x[i,j,k] = x[i,j,k] + beta * acc[i,j,k]
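
These generic kernels implement the `x += beta * acc` update that `adj_matmul` and `adj_batched_matmul` now launch for `adj_c` instead of issuing another `cutlass_gemm` call. A minimal standalone usage sketch, assuming the helper stays reachable as `wp.utils.add_kernel_2d` and that the scalar argument's type selects the generic overload:

    import numpy as np
    import warp as wp

    wp.init()

    x = wp.array(np.ones((2, 3), dtype=np.float32))
    acc = wp.array(np.full((2, 3), 2.0, dtype=np.float32))

    # x[i, j] += 0.5 * acc[i, j]  ->  every element of x becomes 2.0
    wp.launch(wp.utils.add_kernel_2d, dim=x.shape, inputs=[x, acc, wp.float32(0.5)])
    print(x.numpy())
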
{warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: warp-lang
-Version: 1.0.0b5
+Version: 1.0.0b6
 Summary: A Python framework for high-performance simulation and graphics programming
 Author-email: NVIDIA <mmacklin@nvidia.com>
 License: NVSCL
@@ -25,6 +25,7 @@ Requires-Dist: isort ; extra == 'dev'
 Requires-Dist: nvtx ; extra == 'dev'
 Requires-Dist: furo ; extra == 'dev'
 Requires-Dist: sphinx-copybutton ; extra == 'dev'
+Requires-Dist: coverage[toml] ; extra == 'dev'
 
 # NVIDIA Warp