warp-lang 1.0.0b5__py3-none-manylinux2014_x86_64.whl → 1.0.0b6__py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/conf.py +3 -4
- examples/env/env_ant.py +1 -1
- examples/env/env_cartpole.py +1 -1
- examples/env/env_humanoid.py +1 -1
- examples/example_dem.py +28 -26
- examples/example_diffray.py +37 -30
- examples/example_fluid.py +7 -3
- examples/example_jacobian_ik.py +1 -1
- examples/example_mesh_intersect.py +10 -7
- examples/example_nvdb.py +3 -3
- examples/example_render_opengl.py +19 -10
- examples/example_sim_cartpole.py +9 -5
- examples/example_sim_cloth.py +29 -25
- examples/example_sim_fk_grad.py +2 -2
- examples/example_sim_fk_grad_torch.py +3 -3
- examples/example_sim_grad_bounce.py +11 -8
- examples/example_sim_grad_cloth.py +12 -9
- examples/example_sim_granular.py +2 -2
- examples/example_sim_granular_collision_sdf.py +13 -13
- examples/example_sim_neo_hookean.py +3 -3
- examples/example_sim_particle_chain.py +2 -2
- examples/example_sim_quadruped.py +8 -5
- examples/example_sim_rigid_chain.py +8 -5
- examples/example_sim_rigid_contact.py +13 -10
- examples/example_sim_rigid_fem.py +2 -2
- examples/example_sim_rigid_gyroscopic.py +2 -2
- examples/example_sim_rigid_kinematics.py +1 -1
- examples/example_sim_trajopt.py +3 -2
- examples/fem/example_apic_fluid.py +5 -7
- examples/fem/example_diffusion_mgpu.py +18 -16
- warp/__init__.py +3 -2
- warp/bin/warp.so +0 -0
- warp/build_dll.py +29 -9
- warp/builtins.py +206 -7
- warp/codegen.py +58 -38
- warp/config.py +3 -1
- warp/context.py +234 -128
- warp/fem/__init__.py +2 -2
- warp/fem/cache.py +2 -1
- warp/fem/field/nodal_field.py +18 -17
- warp/fem/geometry/hexmesh.py +11 -6
- warp/fem/geometry/quadmesh_2d.py +16 -12
- warp/fem/geometry/tetmesh.py +19 -8
- warp/fem/geometry/trimesh_2d.py +18 -7
- warp/fem/integrate.py +341 -196
- warp/fem/quadrature/__init__.py +1 -1
- warp/fem/quadrature/pic_quadrature.py +138 -53
- warp/fem/quadrature/quadrature.py +81 -9
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_space.py +169 -51
- warp/fem/space/grid_2d_function_space.py +2 -2
- warp/fem/space/grid_3d_function_space.py +2 -2
- warp/fem/space/hexmesh_function_space.py +2 -2
- warp/fem/space/partition.py +9 -6
- warp/fem/space/quadmesh_2d_function_space.py +2 -2
- warp/fem/space/shape/cube_shape_function.py +27 -15
- warp/fem/space/shape/square_shape_function.py +29 -18
- warp/fem/space/tetmesh_function_space.py +2 -2
- warp/fem/space/topology.py +10 -0
- warp/fem/space/trimesh_2d_function_space.py +2 -2
- warp/fem/utils.py +10 -5
- warp/native/array.h +49 -8
- warp/native/builtin.h +31 -14
- warp/native/cuda_util.cpp +8 -3
- warp/native/cuda_util.h +1 -0
- warp/native/exports.h +1177 -1108
- warp/native/intersect.h +4 -4
- warp/native/intersect_adj.h +8 -8
- warp/native/mat.h +65 -6
- warp/native/mesh.h +126 -5
- warp/native/quat.h +28 -4
- warp/native/vec.h +76 -14
- warp/native/warp.cu +1 -6
- warp/render/render_opengl.py +261 -109
- warp/sim/import_mjcf.py +13 -7
- warp/sim/import_urdf.py +14 -14
- warp/sim/inertia.py +17 -18
- warp/sim/model.py +67 -67
- warp/sim/render.py +1 -1
- warp/sparse.py +6 -6
- warp/stubs.py +19 -81
- warp/tape.py +1 -1
- warp/tests/__main__.py +3 -6
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +102 -106
- warp/tests/test_arithmetic.py +39 -40
- warp/tests/test_array.py +46 -48
- warp/tests/test_array_reduce.py +25 -19
- warp/tests/test_atomic.py +62 -26
- warp/tests/test_bool.py +16 -11
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +9 -12
- warp/tests/test_closest_point_edge_edge.py +53 -57
- warp/tests/test_codegen.py +164 -134
- warp/tests/test_compile_consts.py +13 -19
- warp/tests/test_conditional.py +30 -32
- warp/tests/test_copy.py +9 -12
- warp/tests/test_ctypes.py +90 -98
- warp/tests/test_dense.py +20 -14
- warp/tests/test_devices.py +34 -35
- warp/tests/test_dlpack.py +74 -75
- warp/tests/test_examples.py +215 -97
- warp/tests/test_fabricarray.py +15 -21
- warp/tests/test_fast_math.py +14 -11
- warp/tests/test_fem.py +280 -97
- warp/tests/test_fp16.py +19 -15
- warp/tests/test_func.py +177 -194
- warp/tests/test_generics.py +71 -77
- warp/tests/test_grad.py +83 -32
- warp/tests/test_grad_customs.py +7 -9
- warp/tests/test_hash_grid.py +6 -10
- warp/tests/test_import.py +9 -23
- warp/tests/test_indexedarray.py +19 -21
- warp/tests/test_intersect.py +15 -9
- warp/tests/test_large.py +17 -19
- warp/tests/test_launch.py +14 -17
- warp/tests/test_lerp.py +63 -63
- warp/tests/test_lvalue.py +84 -35
- warp/tests/test_marching_cubes.py +9 -13
- warp/tests/test_mat.py +388 -3004
- warp/tests/test_mat_lite.py +9 -12
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +10 -11
- warp/tests/test_matmul.py +104 -100
- warp/tests/test_matmul_lite.py +72 -98
- warp/tests/test_mesh.py +35 -32
- warp/tests/test_mesh_query_aabb.py +18 -25
- warp/tests/test_mesh_query_point.py +39 -23
- warp/tests/test_mesh_query_ray.py +9 -21
- warp/tests/test_mlp.py +8 -9
- warp/tests/test_model.py +89 -93
- warp/tests/test_modules_lite.py +15 -25
- warp/tests/test_multigpu.py +87 -114
- warp/tests/test_noise.py +10 -12
- warp/tests/test_operators.py +14 -21
- warp/tests/test_options.py +10 -11
- warp/tests/test_pinned.py +16 -18
- warp/tests/test_print.py +16 -20
- warp/tests/test_quat.py +121 -88
- warp/tests/test_rand.py +12 -13
- warp/tests/test_reload.py +27 -32
- warp/tests/test_rounding.py +7 -10
- warp/tests/test_runlength_encode.py +105 -106
- warp/tests/test_smoothstep.py +8 -9
- warp/tests/test_snippet.py +13 -22
- warp/tests/test_sparse.py +30 -29
- warp/tests/test_spatial.py +179 -174
- warp/tests/test_streams.py +100 -107
- warp/tests/test_struct.py +98 -67
- warp/tests/test_tape.py +11 -17
- warp/tests/test_torch.py +89 -86
- warp/tests/test_transient_module.py +9 -12
- warp/tests/test_types.py +328 -50
- warp/tests/test_utils.py +217 -218
- warp/tests/test_vec.py +133 -2133
- warp/tests/test_vec_lite.py +8 -11
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +391 -382
- warp/tests/test_volume_write.py +122 -135
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +291 -0
- warp/tests/{test_base.py → unittest_utils.py} +138 -25
- warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
- warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
- warp/thirdparty/unittest_parallel.py +257 -54
- warp/types.py +119 -98
- warp/utils.py +14 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -239
- warp/tests/test_conditional_unequal_types_kernels.py +0 -14
- warp/tests/test_coverage.py +0 -38
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/types.py
CHANGED
|
@@ -178,13 +178,13 @@ def vector(length, dtype):
|
|
|
178
178
|
return warp.add(self, y)
|
|
179
179
|
|
|
180
180
|
def __radd__(self, y):
|
|
181
|
-
return warp.add(
|
|
181
|
+
return warp.add(y, self)
|
|
182
182
|
|
|
183
183
|
def __sub__(self, y):
|
|
184
184
|
return warp.sub(self, y)
|
|
185
185
|
|
|
186
|
-
def __rsub__(self,
|
|
187
|
-
return warp.sub(
|
|
186
|
+
def __rsub__(self, y):
|
|
187
|
+
return warp.sub(y, self)
|
|
188
188
|
|
|
189
189
|
def __mul__(self, y):
|
|
190
190
|
return warp.mul(self, y)
|
|
@@ -195,7 +195,7 @@ def vector(length, dtype):
|
|
|
195
195
|
def __truediv__(self, y):
|
|
196
196
|
return warp.div(self, y)
|
|
197
197
|
|
|
198
|
-
def
|
|
198
|
+
def __rtruediv__(self, x):
|
|
199
199
|
return warp.div(x, self)
|
|
200
200
|
|
|
201
201
|
def __pos__(self):
|
|
@@ -294,13 +294,13 @@ def matrix(shape, dtype):
|
|
|
294
294
|
return warp.add(self, y)
|
|
295
295
|
|
|
296
296
|
def __radd__(self, y):
|
|
297
|
-
return warp.add(
|
|
297
|
+
return warp.add(y, self)
|
|
298
298
|
|
|
299
299
|
def __sub__(self, y):
|
|
300
300
|
return warp.sub(self, y)
|
|
301
301
|
|
|
302
|
-
def __rsub__(self,
|
|
303
|
-
return warp.sub(
|
|
302
|
+
def __rsub__(self, y):
|
|
303
|
+
return warp.sub(y, self)
|
|
304
304
|
|
|
305
305
|
def __mul__(self, y):
|
|
306
306
|
return warp.mul(self, y)
|
|
@@ -317,7 +317,7 @@ def matrix(shape, dtype):
|
|
|
317
317
|
def __truediv__(self, y):
|
|
318
318
|
return warp.div(self, y)
|
|
319
319
|
|
|
320
|
-
def
|
|
320
|
+
def __rtruediv__(self, x):
|
|
321
321
|
return warp.div(x, self)
|
|
322
322
|
|
|
323
323
|
def __pos__(self):
|
|
@@ -582,11 +582,11 @@ def transformation(dtype=Any):
|
|
|
582
582
|
|
|
583
583
|
@property
|
|
584
584
|
def p(self):
|
|
585
|
-
return self[0:3]
|
|
585
|
+
return vec3(self[0:3])
|
|
586
586
|
|
|
587
587
|
@property
|
|
588
588
|
def q(self):
|
|
589
|
-
return self[3:7]
|
|
589
|
+
return quat(self[3:7])
|
|
590
590
|
|
|
591
591
|
return transform_t
|
|
592
592
|
|
|
@@ -910,18 +910,21 @@ class range_t:
|
|
|
910
910
|
|
|
911
911
|
# definition just for kernel type (cannot be a parameter), see bvh.h
|
|
912
912
|
class bvh_query_t:
|
|
913
|
+
"""Object used to track state during BVH traversal."""
|
|
913
914
|
def __init__(self):
|
|
914
915
|
pass
|
|
915
916
|
|
|
916
917
|
|
|
917
918
|
# definition just for kernel type (cannot be a parameter), see mesh.h
|
|
918
919
|
class mesh_query_aabb_t:
|
|
920
|
+
"""Object used to track state during mesh traversal."""
|
|
919
921
|
def __init__(self):
|
|
920
922
|
pass
|
|
921
923
|
|
|
922
924
|
|
|
923
925
|
# definition just for kernel type (cannot be a parameter), see hash_grid.h
|
|
924
926
|
class hash_grid_query_t:
|
|
927
|
+
"""Object used to track state during neighbor traversal."""
|
|
925
928
|
def __init__(self):
|
|
926
929
|
pass
|
|
927
930
|
|
|
@@ -2979,6 +2982,67 @@ class Volume:
|
|
|
2979
2982
|
return volume
|
|
2980
2983
|
|
|
2981
2984
|
|
|
2985
|
+
# definition just for kernel type (cannot be a parameter), see mesh.h
|
|
2986
|
+
# NOTE: its layout must match the corresponding struct defined in C.
|
|
2987
|
+
# NOTE: it needs to be defined after `indexedarray` to workaround a circular import issue.
|
|
2988
|
+
class mesh_query_point_t:
|
|
2989
|
+
"""Output for the mesh query point functions.
|
|
2990
|
+
|
|
2991
|
+
Attributes:
|
|
2992
|
+
result (bool): Whether a point is found within the given constraints.
|
|
2993
|
+
sign (float32): A value < 0 if query point is inside the mesh, >=0 otherwise.
|
|
2994
|
+
Note that mesh must be watertight for this to be robust
|
|
2995
|
+
face (int32): Index of the closest face.
|
|
2996
|
+
u (float32): Barycentric u coordinate of the closest point.
|
|
2997
|
+
v (float32): Barycentric v coordinate of the closest point.
|
|
2998
|
+
|
|
2999
|
+
See Also:
|
|
3000
|
+
:func:`mesh_query_point`, :func:`mesh_query_point_no_sign`,
|
|
3001
|
+
:func:`mesh_query_furthest_point_no_sign`,
|
|
3002
|
+
:func:`mesh_query_point_sign_normal`,
|
|
3003
|
+
and :func:`mesh_query_point_sign_winding_number`.
|
|
3004
|
+
"""
|
|
3005
|
+
from warp.codegen import Var
|
|
3006
|
+
|
|
3007
|
+
vars = {
|
|
3008
|
+
"result": Var("result", bool),
|
|
3009
|
+
"sign": Var("sign", float32),
|
|
3010
|
+
"face": Var("face", int32),
|
|
3011
|
+
"u": Var("u", float32),
|
|
3012
|
+
"v": Var("v", float32),
|
|
3013
|
+
}
|
|
3014
|
+
|
|
3015
|
+
|
|
3016
|
+
# definition just for kernel type (cannot be a parameter), see mesh.h
|
|
3017
|
+
# NOTE: its layout must match the corresponding struct defined in C.
|
|
3018
|
+
class mesh_query_ray_t:
|
|
3019
|
+
"""Output for the mesh query ray functions.
|
|
3020
|
+
|
|
3021
|
+
Attributes:
|
|
3022
|
+
result (bool): Whether a hit is found within the given constraints.
|
|
3023
|
+
sign (float32): A value > 0 if the ray hit in front of the face, returns < 0 otherwise.
|
|
3024
|
+
face (int32): Index of the closest face.
|
|
3025
|
+
t (float32): Distance of the closest hit along the ray.
|
|
3026
|
+
u (float32): Barycentric u coordinate of the closest hit.
|
|
3027
|
+
v (float32): Barycentric v coordinate of the closest hit.
|
|
3028
|
+
normal (vec3f): Face normal.
|
|
3029
|
+
|
|
3030
|
+
See Also:
|
|
3031
|
+
:func:`mesh_query_ray`.
|
|
3032
|
+
"""
|
|
3033
|
+
from warp.codegen import Var
|
|
3034
|
+
|
|
3035
|
+
vars = {
|
|
3036
|
+
"result": Var("result", bool),
|
|
3037
|
+
"sign": Var("sign", float32),
|
|
3038
|
+
"face": Var("face", int32),
|
|
3039
|
+
"t": Var("t", float32),
|
|
3040
|
+
"u": Var("u", float32),
|
|
3041
|
+
"v": Var("v", float32),
|
|
3042
|
+
"normal": Var("normal", vec3),
|
|
3043
|
+
}
|
|
3044
|
+
|
|
3045
|
+
|
|
2982
3046
|
def matmul(
|
|
2983
3047
|
a: array2d,
|
|
2984
3048
|
b: array2d,
|
|
@@ -3157,9 +3221,9 @@ def adj_matmul(
|
|
|
3157
3221
|
|
|
3158
3222
|
# cpu fallback if no cuda devices found
|
|
3159
3223
|
if device == "cpu":
|
|
3160
|
-
adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()))
|
|
3161
|
-
adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()))
|
|
3162
|
-
adj_c.assign(beta * adj_d.numpy())
|
|
3224
|
+
adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()) + adj_a.numpy())
|
|
3225
|
+
adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()) + adj_b.numpy())
|
|
3226
|
+
adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
|
|
3163
3227
|
return
|
|
3164
3228
|
|
|
3165
3229
|
cc = device.arch
|
|
@@ -3174,10 +3238,10 @@ def adj_matmul(
|
|
|
3174
3238
|
type_typestr(a.dtype).encode(),
|
|
3175
3239
|
ctypes.c_void_p(adj_d.ptr),
|
|
3176
3240
|
ctypes.c_void_p(b.ptr),
|
|
3177
|
-
ctypes.c_void_p(
|
|
3241
|
+
ctypes.c_void_p(adj_a.ptr),
|
|
3178
3242
|
ctypes.c_void_p(adj_a.ptr),
|
|
3179
3243
|
alpha,
|
|
3180
|
-
|
|
3244
|
+
1.0,
|
|
3181
3245
|
True,
|
|
3182
3246
|
b.is_transposed,
|
|
3183
3247
|
allow_tf32x3_arith,
|
|
@@ -3194,10 +3258,10 @@ def adj_matmul(
|
|
|
3194
3258
|
type_typestr(a.dtype).encode(),
|
|
3195
3259
|
ctypes.c_void_p(b.ptr),
|
|
3196
3260
|
ctypes.c_void_p(adj_d.ptr),
|
|
3197
|
-
ctypes.c_void_p(
|
|
3261
|
+
ctypes.c_void_p(adj_a.ptr),
|
|
3198
3262
|
ctypes.c_void_p(adj_a.ptr),
|
|
3199
3263
|
alpha,
|
|
3200
|
-
|
|
3264
|
+
1.0,
|
|
3201
3265
|
not b.is_transposed,
|
|
3202
3266
|
False,
|
|
3203
3267
|
allow_tf32x3_arith,
|
|
@@ -3216,10 +3280,10 @@ def adj_matmul(
|
|
|
3216
3280
|
type_typestr(a.dtype).encode(),
|
|
3217
3281
|
ctypes.c_void_p(a.ptr),
|
|
3218
3282
|
ctypes.c_void_p(adj_d.ptr),
|
|
3219
|
-
ctypes.c_void_p(
|
|
3283
|
+
ctypes.c_void_p(adj_b.ptr),
|
|
3220
3284
|
ctypes.c_void_p(adj_b.ptr),
|
|
3221
3285
|
alpha,
|
|
3222
|
-
|
|
3286
|
+
1.0,
|
|
3223
3287
|
a.is_transposed,
|
|
3224
3288
|
True,
|
|
3225
3289
|
allow_tf32x3_arith,
|
|
@@ -3236,10 +3300,10 @@ def adj_matmul(
|
|
|
3236
3300
|
type_typestr(a.dtype).encode(),
|
|
3237
3301
|
ctypes.c_void_p(adj_d.ptr),
|
|
3238
3302
|
ctypes.c_void_p(a.ptr),
|
|
3239
|
-
ctypes.c_void_p(
|
|
3303
|
+
ctypes.c_void_p(adj_b.ptr),
|
|
3240
3304
|
ctypes.c_void_p(adj_b.ptr),
|
|
3241
3305
|
alpha,
|
|
3242
|
-
|
|
3306
|
+
1.0,
|
|
3243
3307
|
False,
|
|
3244
3308
|
not a.is_transposed,
|
|
3245
3309
|
allow_tf32x3_arith,
|
|
@@ -3249,25 +3313,13 @@ def adj_matmul(
|
|
|
3249
3313
|
raise RuntimeError("adj_matmul failed.")
|
|
3250
3314
|
|
|
3251
3315
|
# adj_c
|
|
3252
|
-
|
|
3253
|
-
|
|
3254
|
-
|
|
3255
|
-
|
|
3256
|
-
|
|
3257
|
-
|
|
3258
|
-
ctypes.c_void_p(a.ptr),
|
|
3259
|
-
ctypes.c_void_p(b.ptr),
|
|
3260
|
-
ctypes.c_void_p(adj_d.ptr),
|
|
3261
|
-
ctypes.c_void_p(adj_c.ptr),
|
|
3262
|
-
0.0,
|
|
3263
|
-
beta,
|
|
3264
|
-
not a.is_transposed,
|
|
3265
|
-
not b.is_transposed,
|
|
3266
|
-
allow_tf32x3_arith,
|
|
3267
|
-
1,
|
|
3316
|
+
warp.launch(
|
|
3317
|
+
kernel=warp.utils.add_kernel_2d,
|
|
3318
|
+
dim=adj_c.shape,
|
|
3319
|
+
inputs=[adj_c, adj_d, adj_d.dtype(beta)],
|
|
3320
|
+
device=device,
|
|
3321
|
+
record_tape=False
|
|
3268
3322
|
)
|
|
3269
|
-
if not ret:
|
|
3270
|
-
raise RuntimeError("adj_matmul failed.")
|
|
3271
3323
|
|
|
3272
3324
|
|
|
3273
3325
|
def batched_matmul(
|
|
@@ -3476,9 +3528,9 @@ def adj_batched_matmul(
|
|
|
3476
3528
|
|
|
3477
3529
|
# cpu fallback if no cuda devices found
|
|
3478
3530
|
if device == "cpu":
|
|
3479
|
-
adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))))
|
|
3480
|
-
adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()))
|
|
3481
|
-
adj_c.assign(beta * adj_d.numpy())
|
|
3531
|
+
adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))) + adj_a.numpy())
|
|
3532
|
+
adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()) + adj_b.numpy())
|
|
3533
|
+
adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
|
|
3482
3534
|
return
|
|
3483
3535
|
|
|
3484
3536
|
# handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
|
|
@@ -3502,10 +3554,10 @@ def adj_batched_matmul(
|
|
|
3502
3554
|
type_typestr(a.dtype).encode(),
|
|
3503
3555
|
ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
|
|
3504
3556
|
ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
|
|
3505
|
-
ctypes.c_void_p(
|
|
3557
|
+
ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
|
|
3506
3558
|
ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
|
|
3507
3559
|
alpha,
|
|
3508
|
-
|
|
3560
|
+
1.0,
|
|
3509
3561
|
True,
|
|
3510
3562
|
b.is_transposed,
|
|
3511
3563
|
allow_tf32x3_arith,
|
|
@@ -3522,10 +3574,10 @@ def adj_batched_matmul(
|
|
|
3522
3574
|
type_typestr(a.dtype).encode(),
|
|
3523
3575
|
ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
|
|
3524
3576
|
ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
|
|
3525
|
-
ctypes.c_void_p(
|
|
3577
|
+
ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
|
|
3526
3578
|
ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
|
|
3527
3579
|
alpha,
|
|
3528
|
-
|
|
3580
|
+
1.0,
|
|
3529
3581
|
not b.is_transposed,
|
|
3530
3582
|
False,
|
|
3531
3583
|
allow_tf32x3_arith,
|
|
@@ -3544,10 +3596,10 @@ def adj_batched_matmul(
|
|
|
3544
3596
|
type_typestr(a.dtype).encode(),
|
|
3545
3597
|
ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
|
|
3546
3598
|
ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
|
|
3547
|
-
ctypes.c_void_p(
|
|
3599
|
+
ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
|
|
3548
3600
|
ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
|
|
3549
3601
|
alpha,
|
|
3550
|
-
|
|
3602
|
+
1.0,
|
|
3551
3603
|
a.is_transposed,
|
|
3552
3604
|
True,
|
|
3553
3605
|
allow_tf32x3_arith,
|
|
@@ -3564,10 +3616,10 @@ def adj_batched_matmul(
|
|
|
3564
3616
|
type_typestr(a.dtype).encode(),
|
|
3565
3617
|
ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
|
|
3566
3618
|
ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
|
|
3567
|
-
ctypes.c_void_p(
|
|
3619
|
+
ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
|
|
3568
3620
|
ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
|
|
3569
3621
|
alpha,
|
|
3570
|
-
|
|
3622
|
+
1.0,
|
|
3571
3623
|
False,
|
|
3572
3624
|
not a.is_transposed,
|
|
3573
3625
|
allow_tf32x3_arith,
|
|
@@ -3575,27 +3627,6 @@ def adj_batched_matmul(
|
|
|
3575
3627
|
)
|
|
3576
3628
|
if not ret:
|
|
3577
3629
|
raise RuntimeError("adj_matmul failed.")
|
|
3578
|
-
|
|
3579
|
-
# adj_c
|
|
3580
|
-
ret = runtime.core.cutlass_gemm(
|
|
3581
|
-
cc,
|
|
3582
|
-
m,
|
|
3583
|
-
n,
|
|
3584
|
-
k,
|
|
3585
|
-
type_typestr(a.dtype).encode(),
|
|
3586
|
-
ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
|
|
3587
|
-
ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
|
|
3588
|
-
ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
|
|
3589
|
-
ctypes.c_void_p(adj_c[idx_start:idx_end,:,:].ptr),
|
|
3590
|
-
0.0,
|
|
3591
|
-
beta,
|
|
3592
|
-
not a.is_transposed,
|
|
3593
|
-
not b.is_transposed,
|
|
3594
|
-
allow_tf32x3_arith,
|
|
3595
|
-
max_batch_count,
|
|
3596
|
-
)
|
|
3597
|
-
if not ret:
|
|
3598
|
-
raise RuntimeError("adj_batched_matmul failed.")
|
|
3599
3630
|
|
|
3600
3631
|
idx_start = iters * max_batch_count
|
|
3601
3632
|
|
|
@@ -3609,10 +3640,10 @@ def adj_batched_matmul(
|
|
|
3609
3640
|
type_typestr(a.dtype).encode(),
|
|
3610
3641
|
ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
|
|
3611
3642
|
ctypes.c_void_p(b[idx_start:,:,:].ptr),
|
|
3612
|
-
ctypes.c_void_p(
|
|
3643
|
+
ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
|
|
3613
3644
|
ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
|
|
3614
3645
|
alpha,
|
|
3615
|
-
|
|
3646
|
+
1.0,
|
|
3616
3647
|
True,
|
|
3617
3648
|
b.is_transposed,
|
|
3618
3649
|
allow_tf32x3_arith,
|
|
@@ -3629,10 +3660,10 @@ def adj_batched_matmul(
|
|
|
3629
3660
|
type_typestr(a.dtype).encode(),
|
|
3630
3661
|
ctypes.c_void_p(b[idx_start:,:,:].ptr),
|
|
3631
3662
|
ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
|
|
3632
|
-
ctypes.c_void_p(
|
|
3663
|
+
ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
|
|
3633
3664
|
ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
|
|
3634
3665
|
alpha,
|
|
3635
|
-
|
|
3666
|
+
1.0,
|
|
3636
3667
|
not b.is_transposed,
|
|
3637
3668
|
False,
|
|
3638
3669
|
allow_tf32x3_arith,
|
|
@@ -3651,10 +3682,10 @@ def adj_batched_matmul(
|
|
|
3651
3682
|
type_typestr(a.dtype).encode(),
|
|
3652
3683
|
ctypes.c_void_p(a[idx_start:,:,:].ptr),
|
|
3653
3684
|
ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
|
|
3654
|
-
ctypes.c_void_p(
|
|
3685
|
+
ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
|
|
3655
3686
|
ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
|
|
3656
3687
|
alpha,
|
|
3657
|
-
|
|
3688
|
+
1.0,
|
|
3658
3689
|
a.is_transposed,
|
|
3659
3690
|
True,
|
|
3660
3691
|
allow_tf32x3_arith,
|
|
@@ -3671,10 +3702,10 @@ def adj_batched_matmul(
|
|
|
3671
3702
|
type_typestr(a.dtype).encode(),
|
|
3672
3703
|
ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
|
|
3673
3704
|
ctypes.c_void_p(a[idx_start:,:,:].ptr),
|
|
3674
|
-
ctypes.c_void_p(
|
|
3705
|
+
ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
|
|
3675
3706
|
ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
|
|
3676
3707
|
alpha,
|
|
3677
|
-
|
|
3708
|
+
1.0,
|
|
3678
3709
|
False,
|
|
3679
3710
|
not a.is_transposed,
|
|
3680
3711
|
allow_tf32x3_arith,
|
|
@@ -3684,25 +3715,13 @@ def adj_batched_matmul(
|
|
|
3684
3715
|
raise RuntimeError("adj_matmul failed.")
|
|
3685
3716
|
|
|
3686
3717
|
# adj_c
|
|
3687
|
-
|
|
3688
|
-
|
|
3689
|
-
|
|
3690
|
-
|
|
3691
|
-
|
|
3692
|
-
|
|
3693
|
-
ctypes.c_void_p(a[idx_start:,:,:].ptr),
|
|
3694
|
-
ctypes.c_void_p(b[idx_start:,:,:].ptr),
|
|
3695
|
-
ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
|
|
3696
|
-
ctypes.c_void_p(adj_c[idx_start:,:,:].ptr),
|
|
3697
|
-
0.0,
|
|
3698
|
-
beta,
|
|
3699
|
-
not a.is_transposed,
|
|
3700
|
-
not b.is_transposed,
|
|
3701
|
-
allow_tf32x3_arith,
|
|
3702
|
-
remainder,
|
|
3718
|
+
warp.launch(
|
|
3719
|
+
kernel=warp.utils.add_kernel_3d,
|
|
3720
|
+
dim=adj_c.shape,
|
|
3721
|
+
inputs=[adj_c, adj_d, adj_d.dtype(beta)],
|
|
3722
|
+
device=device,
|
|
3723
|
+
record_tape=False
|
|
3703
3724
|
)
|
|
3704
|
-
if not ret:
|
|
3705
|
-
raise RuntimeError("adj_batched_matmul failed.")
|
|
3706
3725
|
|
|
3707
3726
|
class HashGrid:
|
|
3708
3727
|
def __init__(self, dim_x, dim_y, dim_z, device=None):
|
|
@@ -3957,7 +3976,7 @@ def infer_argument_types(args, template_types, arg_names=None):
|
|
|
3957
3976
|
arg_types.append(arg._cls)
|
|
3958
3977
|
# elif arg_type in [warp.types.launch_bounds_t, warp.types.shape_t, warp.types.range_t]:
|
|
3959
3978
|
# arg_types.append(arg_type)
|
|
3960
|
-
# elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.bvh_query_t]:
|
|
3979
|
+
# elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.mesh_query_point_t, warp.mesh_query_ray_t, warp.bvh_query_t]:
|
|
3961
3980
|
# arg_types.append(arg_type)
|
|
3962
3981
|
elif arg is None:
|
|
3963
3982
|
# allow passing None for arrays
|
|
@@ -3995,6 +4014,8 @@ simple_type_codes = {
|
|
|
3995
4014
|
launch_bounds_t: "lb",
|
|
3996
4015
|
hash_grid_query_t: "hgq",
|
|
3997
4016
|
mesh_query_aabb_t: "mqa",
|
|
4017
|
+
mesh_query_point_t: "mqp",
|
|
4018
|
+
mesh_query_ray_t: "mqr",
|
|
3998
4019
|
bvh_query_t: "bvhq",
|
|
3999
4020
|
}
|
|
4000
4021
|
|
warp/utils.py
CHANGED
|
@@ -666,3 +666,17 @@ class ScopedTimer:
|
|
|
666
666
|
print("{}{} took {:.2f} ms".format(indent, self.name, self.elapsed))
|
|
667
667
|
|
|
668
668
|
ScopedTimer.indent -= 1
|
|
669
|
+
|
|
670
|
+
|
|
671
|
+
# helper kernels for adj_matmul
|
|
672
|
+
@wp.kernel
|
|
673
|
+
def add_kernel_2d(x: wp.array2d(dtype=Any), acc: wp.array2d(dtype=Any), beta: Any):
|
|
674
|
+
i, j = wp.tid()
|
|
675
|
+
|
|
676
|
+
x[i,j] = x[i,j] + beta * acc[i,j]
|
|
677
|
+
|
|
678
|
+
@wp.kernel
|
|
679
|
+
def add_kernel_3d(x: wp.array3d(dtype=Any), acc: wp.array3d(dtype=Any), beta: Any):
|
|
680
|
+
i, j, k = wp.tid()
|
|
681
|
+
|
|
682
|
+
x[i,j,k] = x[i,j,k] + beta * acc[i,j,k]
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: warp-lang
|
|
3
|
-
Version: 1.0.
|
|
3
|
+
Version: 1.0.0b6
|
|
4
4
|
Summary: A Python framework for high-performance simulation and graphics programming
|
|
5
5
|
Author-email: NVIDIA <mmacklin@nvidia.com>
|
|
6
6
|
License: NVSCL
|
|
@@ -25,6 +25,7 @@ Requires-Dist: isort ; extra == 'dev'
|
|
|
25
25
|
Requires-Dist: nvtx ; extra == 'dev'
|
|
26
26
|
Requires-Dist: furo ; extra == 'dev'
|
|
27
27
|
Requires-Dist: sphinx-copybutton ; extra == 'dev'
|
|
28
|
+
Requires-Dist: coverage[toml] ; extra == 'dev'
|
|
28
29
|
|
|
29
30
|
# NVIDIA Warp
|
|
30
31
|
|