warp-lang 1.0.0b2-py3-none-manylinux2014_x86_64.whl → 1.0.0b6-py3-none-manylinux2014_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docs/conf.py +17 -5
- examples/env/env_ant.py +1 -1
- examples/env/env_cartpole.py +1 -1
- examples/env/env_humanoid.py +1 -1
- examples/env/env_usd.py +4 -1
- examples/env/environment.py +8 -9
- examples/example_dem.py +34 -33
- examples/example_diffray.py +364 -337
- examples/example_fluid.py +32 -23
- examples/example_jacobian_ik.py +97 -93
- examples/example_marching_cubes.py +6 -16
- examples/example_mesh.py +6 -16
- examples/example_mesh_intersect.py +16 -14
- examples/example_nvdb.py +14 -16
- examples/example_raycast.py +14 -13
- examples/example_raymarch.py +16 -23
- examples/example_render_opengl.py +19 -10
- examples/example_sim_cartpole.py +82 -78
- examples/example_sim_cloth.py +45 -48
- examples/example_sim_fk_grad.py +51 -44
- examples/example_sim_fk_grad_torch.py +47 -40
- examples/example_sim_grad_bounce.py +108 -133
- examples/example_sim_grad_cloth.py +99 -113
- examples/example_sim_granular.py +5 -6
- examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
- examples/example_sim_neo_hookean.py +51 -55
- examples/example_sim_particle_chain.py +4 -4
- examples/example_sim_quadruped.py +126 -81
- examples/example_sim_rigid_chain.py +54 -61
- examples/example_sim_rigid_contact.py +66 -70
- examples/example_sim_rigid_fem.py +3 -3
- examples/example_sim_rigid_force.py +1 -1
- examples/example_sim_rigid_gyroscopic.py +3 -4
- examples/example_sim_rigid_kinematics.py +28 -39
- examples/example_sim_trajopt.py +112 -110
- examples/example_sph.py +9 -8
- examples/example_wave.py +7 -7
- examples/fem/bsr_utils.py +30 -17
- examples/fem/example_apic_fluid.py +85 -69
- examples/fem/example_convection_diffusion.py +97 -93
- examples/fem/example_convection_diffusion_dg.py +142 -149
- examples/fem/example_convection_diffusion_dg0.py +141 -136
- examples/fem/example_deformed_geometry.py +146 -0
- examples/fem/example_diffusion.py +115 -84
- examples/fem/example_diffusion_3d.py +116 -86
- examples/fem/example_diffusion_mgpu.py +102 -79
- examples/fem/example_mixed_elasticity.py +139 -100
- examples/fem/example_navier_stokes.py +175 -162
- examples/fem/example_stokes.py +143 -111
- examples/fem/example_stokes_transfer.py +186 -157
- examples/fem/mesh_utils.py +59 -97
- examples/fem/plot_utils.py +138 -17
- tools/ci/publishing/build_nodes_info.py +54 -0
- warp/__init__.py +4 -3
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +836 -492
- warp/codegen.py +864 -553
- warp/config.py +3 -1
- warp/context.py +389 -172
- warp/fem/__init__.py +24 -6
- warp/fem/cache.py +318 -25
- warp/fem/dirichlet.py +7 -3
- warp/fem/domain.py +14 -0
- warp/fem/field/__init__.py +30 -38
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +244 -138
- warp/fem/field/restriction.py +8 -6
- warp/fem/field/test.py +127 -59
- warp/fem/field/trial.py +117 -60
- warp/fem/geometry/__init__.py +5 -1
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +24 -1
- warp/fem/geometry/geometry.py +86 -14
- warp/fem/geometry/grid_2d.py +112 -54
- warp/fem/geometry/grid_3d.py +134 -65
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +85 -33
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +451 -115
- warp/fem/geometry/trimesh_2d.py +197 -92
- warp/fem/integrate.py +534 -268
- warp/fem/operator.py +58 -31
- warp/fem/polynomial.py +11 -0
- warp/fem/quadrature/__init__.py +1 -1
- warp/fem/quadrature/pic_quadrature.py +150 -58
- warp/fem/quadrature/quadrature.py +209 -57
- warp/fem/space/__init__.py +230 -53
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +49 -2
- warp/fem/space/function_space.py +90 -39
- warp/fem/space/grid_2d_function_space.py +149 -496
- warp/fem/space/grid_3d_function_space.py +173 -538
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +129 -76
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +46 -34
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +132 -1039
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +104 -742
- warp/fem/types.py +13 -11
- warp/fem/utils.py +335 -60
- warp/native/array.h +120 -34
- warp/native/builtin.h +101 -72
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +22 -40
- warp/native/clang/clang.cpp +1 -0
- warp/native/crt.h +2 -0
- warp/native/cuda_util.cpp +8 -3
- warp/native/cuda_util.h +1 -0
- warp/native/exports.h +1522 -1243
- warp/native/intersect.h +19 -4
- warp/native/intersect_adj.h +8 -8
- warp/native/mat.h +76 -17
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -18
- warp/native/mesh.h +395 -40
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +44 -34
- warp/native/reduce.cpp +1 -1
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +163 -155
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +18 -14
- warp/native/vec.h +103 -21
- warp/native/warp.cpp +2 -1
- warp/native/warp.cu +28 -3
- warp/native/warp.h +4 -3
- warp/render/render_opengl.py +261 -109
- warp/sim/__init__.py +1 -2
- warp/sim/articulation.py +385 -185
- warp/sim/import_mjcf.py +59 -48
- warp/sim/import_urdf.py +15 -15
- warp/sim/import_usd.py +174 -102
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_xpbd.py +4 -3
- warp/sim/model.py +330 -250
- warp/sim/render.py +1 -1
- warp/sparse.py +625 -152
- warp/stubs.py +341 -309
- warp/tape.py +9 -6
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +94 -74
- warp/tests/test_array.py +82 -101
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +22 -12
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +18 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +165 -134
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +237 -0
- warp/tests/test_fabricarray.py +22 -24
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1034 -124
- warp/tests/test_fp16.py +23 -16
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +123 -181
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +35 -34
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +24 -25
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +304 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +60 -22
- warp/tests/test_mesh_query_aabb.py +21 -25
- warp/tests/test_mesh_query_point.py +111 -22
- warp/tests/test_mesh_query_ray.py +12 -24
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +90 -86
- warp/tests/test_transient_module.py +10 -12
- warp/tests/test_types.py +363 -0
- warp/tests/test_utils.py +451 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +418 -376
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +291 -0
- warp/tests/unittest_utils.py +342 -0
- warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +589 -0
- warp/types.py +622 -211
- warp/utils.py +54 -393
- warp_lang-1.0.0b6.dist-info/METADATA +238 -0
- warp_lang-1.0.0b6.dist-info/RECORD +409 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
- examples/example_cache_management.py +0 -40
- examples/example_multigpu.py +0 -54
- examples/example_struct.py +0 -65
- examples/fem/example_stokes_transfer_3d.py +0 -210
- warp/fem/field/discrete_field.py +0 -80
- warp/fem/space/nodal_function_space.py +0 -233
- warp/tests/test_all.py +0 -223
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-1.0.0b2.dist-info/METADATA +0 -26
- warp_lang-1.0.0b2.dist-info/RECORD +0 -378
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/types.py
CHANGED
@@ -5,9 +5,12 @@
 # distribution of this software and related documentation without an express
 # license agreement from NVIDIA CORPORATION is strictly prohibited.
 
+from __future__ import annotations
+
 import builtins
 import ctypes
 import hashlib
+import inspect
 import struct
 import zlib
 from typing import Any, Callable, Generic, List, Tuple, TypeVar, Union
@@ -49,12 +52,14 @@ def constant(x):
     global _constant_hash
 
     # hash the constant value
-    if isinstance(x, int):
+    if isinstance(x, builtins.bool):
+        # This needs to come before the check for `int` since all boolean
+        # values are also instances of `int`.
+        _constant_hash.update(struct.pack("?", x))
+    elif isinstance(x, int):
         _constant_hash.update(struct.pack("<q", x))
     elif isinstance(x, float):
         _constant_hash.update(struct.pack("<d", x))
-    elif isinstance(x, builtins.bool):
-        _constant_hash.update(struct.pack("?", x))
     elif isinstance(x, float16):
         # float16 is a special case
         p = ctypes.pointer(ctypes.c_float(x.value))
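Why the reordering in this hunk matters: in Python, bool is a subclass of int, so isinstance(True, int) is true and a boolean constant would previously have been hashed through the 8-byte integer branch. A minimal stdlib illustration (not part of the package):

    import struct

    assert isinstance(True, int)     # bool is a subclass of int
    print(struct.pack("?", True))    # b'\x01' -- 1-byte bool encoding
    print(struct.pack("<q", True))   # b'\x01\x00\x00\x00\x00\x00\x00\x00' -- 8-byte int encoding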
@@ -155,17 +160,31 @@ def vector(length, dtype):
             else:
                 raise KeyError(f"Invalid key {key}, expected int or slice")
 
+        def __getattr__(self, name):
+            idx = "xyzw".find(name)
+            if idx != -1:
+                return self.__getitem__(idx)
+
+            return self.__getattribute__(name)
+
+        def __setattr__(self, name, value):
+            idx = "xyzw".find(name)
+            if idx != -1:
+                return self.__setitem__(idx, value)
+
+            return super().__setattr__(name, value)
+
         def __add__(self, y):
             return warp.add(self, y)
 
         def __radd__(self, y):
-            return warp.add(self, y)
+            return warp.add(y, self)
 
         def __sub__(self, y):
             return warp.sub(self, y)
 
-        def __rsub__(self, x):
-            return warp.sub(x, self)
+        def __rsub__(self, y):
+            return warp.sub(y, self)
 
         def __mul__(self, y):
             return warp.mul(self, y)
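The __getattr__/__setattr__ additions above give vector instances x/y/z/w component access in Python scope, backed by element indexing. A minimal usage sketch, assuming a standard warp install:

    import warp as wp

    wp.init()

    v = wp.vec3(1.0, 2.0, 3.0)
    print(v.x, v.y, v.z)  # reads resolve through __getattr__ -> __getitem__
    v.y = 5.0             # writes resolve through __setattr__ -> __setitem__
    print(v)              # [1.0, 5.0, 3.0]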
@@ -173,17 +192,17 @@ def vector(length, dtype):
         def __rmul__(self, x):
             return warp.mul(x, self)
 
-        def __div__(self, y):
+        def __truediv__(self, y):
             return warp.div(self, y)
 
-        def __rdiv__(self, x):
+        def __rtruediv__(self, x):
             return warp.div(x, self)
 
-        def __pos__(self, y):
-            return warp.pos(self, y)
+        def __pos__(self):
+            return warp.pos(self)
 
-        def __neg__(self, y):
-            return warp.neg(self, y)
+        def __neg__(self):
+            return warp.neg(self)
 
         def __str__(self):
             return f"[{', '.join(map(str, self))}]"
@@ -275,13 +294,13 @@ def matrix(shape, dtype):
             return warp.add(self, y)
 
         def __radd__(self, y):
-            return warp.add(self, y)
+            return warp.add(y, self)
 
         def __sub__(self, y):
             return warp.sub(self, y)
 
-        def __rsub__(self, x):
-            return warp.sub(x, self)
+        def __rsub__(self, y):
+            return warp.sub(y, self)
 
         def __mul__(self, y):
             return warp.mul(self, y)
@@ -295,17 +314,17 @@ def matrix(shape, dtype):
         def __rmatmul__(self, x):
             return warp.mul(x, self)
 
-        def __div__(self, y):
+        def __truediv__(self, y):
             return warp.div(self, y)
 
-        def __rdiv__(self, x):
+        def __rtruediv__(self, x):
             return warp.div(x, self)
 
-        def __pos__(self, y):
-            return warp.pos(self, y)
+        def __pos__(self):
+            return warp.pos(self)
 
-        def __neg__(self, y):
-            return warp.neg(self, y)
+        def __neg__(self):
+            return warp.neg(self)
 
         def __str__(self):
             row_str = []
@@ -511,23 +530,63 @@ class quatd(quaternion(dtype=float64)):
 
 def transformation(dtype=Any):
     class transform_t(vector(length=7, dtype=dtype)):
+        _wp_init_from_components_sig_ = inspect.Signature(
+            (
+                inspect.Parameter(
+                    "p",
+                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                    default=(0.0, 0.0, 0.0),
+                ),
+                inspect.Parameter(
+                    "q",
+                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                    default=(0.0, 0.0, 0.0, 1.0),
+                ),
+            ),
+        )
        _wp_type_params_ = [dtype]
        _wp_generic_type_str_ = "transform_t"
        _wp_constructor_ = "transformation"
 
-        def __init__(self, p=(0.0, 0.0, 0.0), q=(0.0, 0.0, 0.0, 1.0)):
-            super().__init__()
+        def __init__(self, *args, **kwargs):
+            if len(args) == 1 and len(kwargs) == 0:
+                if getattr(args[0], "_wp_generic_type_str_") == self._wp_generic_type_str_:
+                    # Copy constructor.
+                    super().__init__(*args[0])
+                    return
 
-            self[0:3] = vector(length=3, dtype=dtype)(*p)
-            self[3:7] = quaternion(dtype=dtype)(*q)
+            try:
+                # For backward compatibility, try to check if the arguments
+                # match the original signature that'd allow initializing
+                # the `p` and `q` components separately.
+                bound_args = self._wp_init_from_components_sig_.bind(*args, **kwargs)
+                bound_args.apply_defaults()
+                p, q = bound_args.args
+            except (TypeError, ValueError):
+                # Fallback to the vector's constructor.
+                super().__init__(*args)
+                return
+
+            # Even if the arguments match the original “from components”
+            # signature, we still need to make sure that they represent
+            # sequences that can be unpacked.
+            if hasattr(p, "__len__") and hasattr(q, "__len__"):
+                # Initialize from the `p` and `q` components.
+                super().__init__()
+                self[0:3] = vector(length=3, dtype=dtype)(*p)
+                self[3:7] = quaternion(dtype=dtype)(*q)
+                return
+
+            # Fallback to the vector's constructor.
+            super().__init__(*args)
 
         @property
         def p(self):
-            return self[0:3]
+            return vec3(self[0:3])
 
         @property
         def q(self):
-            return self[3:7]
+            return quat(self[3:7])
 
     return transform_t
 
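The reworked transform_t.__init__ accepts three call patterns: the legacy (p, q) component form, a copy-constructor form, and the flat 7-element vector form. A usage sketch, assuming a standard warp install (wp.transform is the float32 instantiation):

    import warp as wp

    wp.init()

    t0 = wp.transform()                                            # identity (signature defaults)
    t1 = wp.transform(p=(1.0, 2.0, 3.0), q=(0.0, 0.0, 0.0, 1.0))   # from components
    t2 = wp.transform(t1)                                          # copy constructor
    t3 = wp.transform(1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0)           # falls back to the vector constructor

    print(t1.p, t1.q)  # p now returns a vec3 and q a quat, per the updated properties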
@@ -851,18 +910,21 @@ class range_t:
 
 # definition just for kernel type (cannot be a parameter), see bvh.h
 class bvh_query_t:
+    """Object used to track state during BVH traversal."""
     def __init__(self):
         pass
 
 
 # definition just for kernel type (cannot be a parameter), see mesh.h
 class mesh_query_aabb_t:
+    """Object used to track state during mesh traversal."""
     def __init__(self):
         pass
 
 
 # definition just for kernel type (cannot be a parameter), see hash_grid.h
 class hash_grid_query_t:
+    """Object used to track state during neighbor traversal."""
     def __init__(self):
         pass
 
@@ -999,7 +1061,7 @@ def type_scalar_type(dtype):
 def type_size_in_bytes(dtype):
     if dtype.__module__ == "ctypes":
         return ctypes.sizeof(dtype)
-    elif type_is_struct(dtype):
+    elif isinstance(dtype, warp.codegen.Struct):
         return ctypes.sizeof(dtype.ctype)
     elif dtype == float or dtype == int:
         return 4
@@ -1020,8 +1082,6 @@ def type_to_warp(dtype):
 
 
 def type_typestr(dtype):
-    from warp.codegen import Struct
-
     if dtype == bool:
         return "?"
     elif dtype == float16:
@@ -1046,7 +1106,7 @@ def type_typestr(dtype):
         return "<i8"
     elif dtype == uint64:
         return "<u8"
-    elif isinstance(dtype, Struct):
+    elif isinstance(dtype, warp.codegen.Struct):
         return f"|V{ctypes.sizeof(dtype.ctype)}"
     elif issubclass(dtype, ctypes.Array):
         return type_typestr(dtype._wp_scalar_type_)
@@ -1060,9 +1120,16 @@ def type_repr(t):
         return str(f"array(ndim={t.ndim}, dtype={t.dtype})")
     if type_is_vector(t):
         return str(f"vector(length={t._shape_[0]}, dtype={t._wp_scalar_type_})")
-    elif type_is_matrix(t):
+    if type_is_matrix(t):
         return str(f"matrix(shape=({t._shape_[0]}, {t._shape_[1]}), dtype={t._wp_scalar_type_})")
-    else:
+    if isinstance(t, warp.codegen.Struct):
+        return type_repr(t.cls)
+    if t in scalar_types:
+        return t.__name__
+
+    try:
+        return t.__module__ + "." + t.__qualname__
+    except AttributeError:
         return str(t)
 
 
@@ -1080,15 +1147,6 @@ def type_is_float(t):
     return t in float_types
 
 
-def type_is_struct(dtype):
-    from warp.codegen import Struct
-
-    if isinstance(dtype, Struct):
-        return True
-    else:
-        return False
-
-
 # returns True if the passed *type* is a vector
 def type_is_vector(t):
     if hasattr(t, "_wp_generic_type_str_") and t._wp_generic_type_str_ == "vec_t":
@@ -1162,6 +1220,17 @@ def types_equal(a, b, match_generic=False):
         if p1 == Float and p2 == Float:
             return True
 
+        # convert to canonical types
+        if p1 == float:
+            p1 = float32
+        elif p1 == int:
+            p1 = int32
+
+        if p2 == float:
+            p2 = float32
+        elif b == int:
+            p2 = int32
+
         if p1 in compatible_bool_types and p2 in compatible_bool_types:
             return True
         else:
@@ -1173,7 +1242,7 @@ def types_equal(a, b, match_generic=False):
         and a._wp_generic_type_str_ == b._wp_generic_type_str_
     ):
         return all([are_equal(p1, p2) for p1, p2 in zip(a._wp_type_params_, b._wp_type_params_)])
-    if is_array(a) and type(a) == type(b):
+    if is_array(a) and type(a) is type(b):
         return True
     else:
         return are_equal(a, b)
@@ -1257,6 +1326,7 @@ class array(Array):
         self._grad = None
         # __array_interface__ or __cuda_array_interface__, evaluated lazily and cached
         self._array_interface = None
+        self.is_transposed = False
 
         # canonicalize dtype
         if dtype == int:
@@ -1801,6 +1871,7 @@ class array(Array):
         return array._vars
 
     def zero_(self):
+        """Zeroes-out the array entires."""
         if self.is_contiguous:
             # simple memset is usually faster than generic fill
             self.device.memset(self.ptr, 0, self.size * type_size_in_bytes(self.dtype))
@@ -1808,6 +1879,32 @@ class array(Array):
             self.fill_(0)
 
     def fill_(self, value):
+        """Set all array entries to `value`
+
+        args:
+            value: The value to set every array entry to. Must be convertible to the array's ``dtype``.
+
+        Raises:
+            ValueError: If `value` cannot be converted to the array's ``dtype``.
+
+        Examples:
+            ``fill_()`` can take lists or other sequences when filling arrays of vectors or matrices.
+
+            >>> arr = wp.zeros(2, dtype=wp.mat22)
+            >>> arr.numpy()
+            array([[[0., 0.],
+                    [0., 0.]],
+            <BLANKLINE>
+                   [[0., 0.],
+                    [0., 0.]]], dtype=float32)
+            >>> arr.fill_([[1, 2], [3, 4]])
+            >>> arr.numpy()
+            array([[[1., 2.],
+                    [3., 4.]],
+            <BLANKLINE>
+                   [[1., 2.],
+                    [3., 4.]]], dtype=float32)
+        """
         if self.size == 0:
             return
 
@@ -1854,15 +1951,18 @@ class array(Array):
         else:
             warp.context.runtime.core.array_fill_host(carr_ptr, ARRAY_TYPE_REGULAR, cvalue_ptr, cvalue_size)
 
-    # equivalent to wrapping src data in an array and copying to self
     def assign(self, src):
+        """Wraps ``src`` in an :class:`warp.array` if it is not already one and copies the contents to ``self``."""
         if is_array(src):
             warp.copy(self, src)
         else:
             warp.copy(self, array(data=src, dtype=self.dtype, copy=False, device="cpu"))
 
-    # convert array to ndarray (alias memory through array interface)
     def numpy(self):
+        """Converts the array to a :class:`numpy.ndarray` (aliasing memory through the array interface protocol)
+        If the array is on the GPU, a synchronous device-to-host copy (on the CUDA default stream) will be
+        automatically performed to ensure that any outstanding work is completed.
+        """
         if self.ptr:
             # use the CUDA default stream for synchronous behaviour with other streams
             with warp.ScopedStream(self.device.null_stream):
@@ -1883,12 +1983,16 @@ class array(Array):
             npshape = self.shape
         return np.empty(npshape, dtype=npdtype)
 
-    # return a ctypes cast of the array address
-    # note #1: only CPU arrays support this method
-    # note #2: the array must be contiguous
-    # note #3: accesses to this object are *not* bounds checked
-    # note #4: for float16 types, a pointer to the internal uint16 representation is returned
     def cptr(self):
+        """Return a ctypes cast of the array address.
+
+        Notes:
+
+        #. Only CPU arrays support this method.
+        #. The array must be contiguous.
+        #. Accesses to this object are **not** bounds checked.
+        #. For ``float16`` types, a pointer to the internal ``uint16`` representation is returned.
+        """
         if not self.ptr:
             return None
 
@@ -1907,8 +2011,8 @@ class array(Array):
 
         return p
 
-    # returns a flattened list of items in the array as a Python list
     def list(self):
+        """Returns a flattened list of items in the array as a Python list."""
         a = self.numpy()
 
         if isinstance(self.dtype, warp.codegen.Struct):
@@ -1927,8 +2031,8 @@ class array(Array):
             # scalar
             return list(a.flatten())
 
-    # convert data from one device to another, nop if already on device
     def to(self, device):
+        """Returns a Warp array with this array's data moved to the specified device, no-op if already on device."""
         device = warp.get_device(device)
         if self.device == device:
             return self
@@ -1936,6 +2040,7 @@ class array(Array):
             return warp.clone(self, device=device)
 
     def flatten(self):
+        """Returns a zero-copy view of the array collapsed to 1-D. Only supported for contiguous arrays."""
         if self.ndim == 1:
             return self
 
@@ -1958,6 +2063,11 @@ class array(Array):
         return a
 
     def reshape(self, shape):
+        """Returns a reshaped array. Only supported for contiguous arrays.
+
+        Args:
+            shape : An int or tuple of ints specifying the shape of the returned array.
+        """
         if not self.is_contiguous:
             raise RuntimeError("Reshaping non-contiguous arrays is unsupported.")
 
@@ -2015,6 +2125,9 @@ class array(Array):
         return a
 
     def view(self, dtype):
+        """Returns a zero-copy view of this array's memory with a different data type.
+        ``dtype`` must have the same byte size of the array's native ``dtype``.
+        """
         if type_size_in_bytes(dtype) != type_size_in_bytes(self.dtype):
             raise RuntimeError("Cannot cast dtypes of unequal byte size")
 
@@ -2035,6 +2148,7 @@ class array(Array):
         return a
 
     def contiguous(self):
+        """Returns a contiguous array with this array's data. No-op if array is already contiguous."""
         if self.is_contiguous:
             return self
 
@@ -2042,8 +2156,14 @@ class array(Array):
         warp.copy(a, self)
         return a
 
-    # note: transpose operation will return an array with a non-contiguous access pattern
     def transpose(self, axes=None):
+        """Returns an zero-copy view of the array with axes transposed.
+
+        Note: The transpose operation will return an array with a non-contiguous access pattern.
+
+        Args:
+            axes (optional): Specifies the how the axes are permuted. If not specified, the axes order will be reversed.
+        """
         # noop if 1d array
         if self.ndim == 1:
             return self
@@ -2076,6 +2196,8 @@ class array(Array):
             grad=None if self.grad is None else self.grad.transpose(axes=axes),
         )
 
+        a.is_transposed = not self.is_transposed
+
         a._ref = self
         return a
 
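The new is_transposed flag flips on every transpose() call, which lets consumers such as wp.matmul recognize transposed views (see the later hunks in this file). A small sketch, assuming a standard warp install:

    import numpy as np
    import warp as wp

    wp.init()

    a = wp.array(np.arange(6, dtype=np.float32).reshape(2, 3))
    t = a.transpose()                    # zero-copy view with shape (3, 2)
    print(t.is_contiguous)               # False: only the strides were permuted
    print(t.is_transposed)               # True
    print(t.transpose().is_transposed)   # False: the flag flips back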
@@ -2516,16 +2638,14 @@ class Mesh:
 
 
 class Volume:
+    #: Enum value to specify nearest-neighbor interpolation during sampling
     CLOSEST = constant(0)
+    #: Enum value to specify trilinear interpolation during sampling
     LINEAR = constant(1)
 
     def __init__(self, data: array):
         """Class representing a sparse grid.
 
-        Attributes:
-            CLOSEST (int): Enum value to specify nearest-neighbor interpolation during sampling
-            LINEAR (int): Enum value to specify trilinear interpolation during sampling
-
         Args:
             data (:class:`warp.array`): Array of bytes representing the volume in NanoVDB format
         """
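For context, these two constants select the interpolation mode in the volume sampling built-ins inside kernels. A hedged kernel sketch (sample_volume is a hypothetical name; wp.volume_world_to_index and wp.volume_sample_f are existing Warp built-ins):

    import warp as wp

    @wp.kernel
    def sample_volume(volume: wp.uint64, points: wp.array(dtype=wp.vec3), values: wp.array(dtype=float)):
        tid = wp.tid()
        uvw = wp.volume_world_to_index(volume, points[tid])              # world space -> index space
        values[tid] = wp.volume_sample_f(volume, uvw, wp.Volume.LINEAR)  # or wp.Volume.CLOSEST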
@@ -2570,7 +2690,8 @@ class Volume:
         except Exception:
             pass
 
-    def array(self):
+    def array(self) -> array:
+        """Returns the raw memory buffer of the Volume as an array"""
         buf = ctypes.c_void_p(0)
         size = ctypes.c_uint64(0)
         if self.device.is_cpu:
@@ -2579,7 +2700,7 @@ class Volume:
             self.context.core.volume_get_buffer_info_device(self.id, ctypes.byref(buf), ctypes.byref(size))
         return array(ptr=buf.value, dtype=uint8, shape=size.value, device=self.device, owner=False)
 
-    def get_tiles(self):
+    def get_tiles(self) -> array:
         if self.id == 0:
             raise RuntimeError("Invalid Volume")
 
@@ -2592,7 +2713,7 @@ class Volume:
         num_tiles = size.value // (3 * 4)
         return array(ptr=buf.value, dtype=int32, shape=(num_tiles, 3), device=self.device, owner=True)
 
-    def get_voxel_size(self):
+    def get_voxel_size(self) -> Tuple[float, float, float]:
         if self.id == 0:
             raise RuntimeError("Invalid Volume")
 
@@ -2601,7 +2722,7 @@ class Volume:
         return (dx.value, dy.value, dz.value)
 
     @classmethod
-    def load_from_nvdb(cls, file_or_buffer, device=None):
+    def load_from_nvdb(cls, file_or_buffer, device=None) -> Volume:
         """Creates a Volume object from a NanoVDB file or in-memory buffer.
 
         Returns:
@@ -2637,14 +2758,18 @@ class Volume:
         return cls(data_array)
 
     @classmethod
-    def load_from_numpy(cls, ndarray: np.array, min_world=(0.0, 0.0, 0.0), voxel_size=1.0, bg_value=0.0, device=None):
+    def load_from_numpy(
+        cls, ndarray: np.array, min_world=(0.0, 0.0, 0.0), voxel_size=1.0, bg_value=0.0, device=None
+    ) -> Volume:
         """Creates a Volume object from a dense 3D NumPy array.
 
+        This function is only supported for CUDA devices.
+
         Args:
-            min_world: The 3D coordinate of the lower corner of the volume
-            voxel_size: The size of each voxel in spatial coordinates
+            min_world: The 3D coordinate of the lower corner of the volume.
+            voxel_size: The size of each voxel in spatial coordinates.
             bg_value: Background value
-            device: The device to create the volume on, e.g.: "
+            device: The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
 
         Returns:
 
@@ -2699,7 +2824,7 @@ class Volume:
                 inputs=[volume.id, warp.array(padded_array, dtype=warp.vec3, device=device)],
                 device=device,
             )
-        elif type(bg_value) == int:
+        elif isinstance(bg_value, int):
             warp.launch(
                 warp.utils.copy_dense_volume_to_nano_vdb_i,
                 dim=shape,
@@ -2726,9 +2851,11 @@ class Volume:
         translation=(0.0, 0.0, 0.0),
         points_in_world_space=False,
         device=None,
-    ):
+    ) -> Volume:
         """Allocate a new Volume based on the bounding box defined by min and max.
 
+        This function is only supported for CUDA devices.
+
         Allocate a volume that is large enough to contain voxels [min[0], min[1], min[2]] - [max[0], max[1], max[2]], inclusive.
         If points_in_world_space is true, then min and max are first converted to index space with the given voxel size and
         translation, and the volume is allocated with those.
@@ -2737,12 +2864,12 @@ class Volume:
         the resulting tiles will be available in the new volume.
 
         Args:
-            min (array-like): Lower 3D
-            max (array-like): Upper 3D
-            voxel_size (float): Voxel size of the new volume
+            min (array-like): Lower 3D coordinates of the bounding box in index space or world space, inclusive.
+            max (array-like): Upper 3D coordinates of the bounding box in index space or world space, inclusive.
+            voxel_size (float): Voxel size of the new volume.
             bg_value (float or array-like): Value of unallocated voxels of the volume, also defines the volume's type, a :class:`warp.vec3` volume is created if this is `array-like`, otherwise a float volume is created
-            translation (array-like): translation between the index and world spaces
-            device (Devicelike):
+            translation (array-like): translation between the index and world spaces.
+            device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
 
         """
         if points_in_world_space:
@@ -2767,9 +2894,11 @@ class Volume:
     @classmethod
     def allocate_by_tiles(
         cls, tile_points: array, voxel_size: float, bg_value=0.0, translation=(0.0, 0.0, 0.0), device=None
-    ):
+    ) -> Volume:
         """Allocate a new Volume with active tiles for each point tile_points.
 
+        This function is only supported for CUDA devices.
+
         The smallest unit of allocation is a dense tile of 8x8x8 voxels.
         This is the primary method for allocating sparse volumes. It uses an array of points indicating the tiles that must be allocated.
 
@@ -2779,13 +2908,13 @@ class Volume:
 
         Args:
             tile_points (:class:`warp.array`): Array of positions that define the tiles to be allocated.
-                The array can be a
+                The array can be a 2D, N-by-3 array of :class:`warp.int32` values, indicating index space positions,
                 or can be a 1D array of :class:`warp.vec3` values, indicating world space positions.
                 Repeated points per tile are allowed and will be efficiently deduplicated.
-            voxel_size (float): Voxel size of the new volume
+            voxel_size (float): Voxel size of the new volume.
             bg_value (float or array-like): Value of unallocated voxels of the volume, also defines the volume's type, a :class:`warp.vec3` volume is created if this is `array-like`, otherwise a float volume is created
-            translation (array-like):
-            device (Devicelike):
+            translation (array-like): Translation between the index and world spaces.
+            device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
 
         """
         from warp.context import runtime
@@ -2822,7 +2951,7 @@ class Volume:
                 translation[2],
                 in_world_space,
             )
-        elif type(bg_value) == int:
+        elif isinstance(bg_value, int):
             volume.id = volume.context.core.volume_i_from_tiles_device(
                 volume.device.context,
                 ctypes.c_void_p(tile_points.ptr),
@@ -2853,6 +2982,67 @@ class Volume:
     return volume
 
 
+# definition just for kernel type (cannot be a parameter), see mesh.h
+# NOTE: its layout must match the corresponding struct defined in C.
+# NOTE: it needs to be defined after `indexedarray` to workaround a circular import issue.
+class mesh_query_point_t:
+    """Output for the mesh query point functions.
+
+    Attributes:
+        result (bool): Whether a point is found within the given constraints.
+        sign (float32): A value < 0 if query point is inside the mesh, >=0 otherwise.
+                        Note that mesh must be watertight for this to be robust
+        face (int32): Index of the closest face.
+        u (float32): Barycentric u coordinate of the closest point.
+        v (float32): Barycentric v coordinate of the closest point.
+
+    See Also:
+        :func:`mesh_query_point`, :func:`mesh_query_point_no_sign`,
+        :func:`mesh_query_furthest_point_no_sign`,
+        :func:`mesh_query_point_sign_normal`,
+        and :func:`mesh_query_point_sign_winding_number`.
+    """
+    from warp.codegen import Var
+
+    vars = {
+        "result": Var("result", bool),
+        "sign": Var("sign", float32),
+        "face": Var("face", int32),
+        "u": Var("u", float32),
+        "v": Var("v", float32),
+    }
+
+
+# definition just for kernel type (cannot be a parameter), see mesh.h
+# NOTE: its layout must match the corresponding struct defined in C.
+class mesh_query_ray_t:
+    """Output for the mesh query ray functions.
+
+    Attributes:
+        result (bool): Whether a hit is found within the given constraints.
+        sign (float32): A value > 0 if the ray hit in front of the face, returns < 0 otherwise.
+        face (int32): Index of the closest face.
+        t (float32): Distance of the closest hit along the ray.
+        u (float32): Barycentric u coordinate of the closest hit.
+        v (float32): Barycentric v coordinate of the closest hit.
+        normal (vec3f): Face normal.
+
+    See Also:
+        :func:`mesh_query_ray`.
+    """
+    from warp.codegen import Var
+
+    vars = {
+        "result": Var("result", bool),
+        "sign": Var("sign", float32),
+        "face": Var("face", int32),
+        "t": Var("t", float32),
+        "u": Var("u", float32),
+        "v": Var("v", float32),
+        "normal": Var("normal", vec3),
+    }
+
+
 def matmul(
     a: array2d,
     b: array2d,
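These two classes back the struct values returned by the mesh query built-ins in this release. A hedged kernel sketch using the result/face/u/v fields documented above (project_to_mesh is a hypothetical name; wp.mesh_query_point and wp.mesh_eval_position are existing built-ins):

    import warp as wp

    @wp.kernel
    def project_to_mesh(mesh: wp.uint64, points: wp.array(dtype=wp.vec3), closest: wp.array(dtype=wp.vec3)):
        tid = wp.tid()
        q = wp.mesh_query_point(mesh, points[tid], 1.0e6)  # returns a mesh_query_point_t
        if q.result:
            closest[tid] = wp.mesh_eval_position(mesh, q.face, q.u, q.v)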
@@ -2889,6 +3079,11 @@ def matmul(
             "wp.matmul currently only supports operation between {A, B, C, D} matrices of the same type."
         )
 
+    if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
+        )
+
     m = a.shape[0]
     n = b.shape[1]
     k = a.shape[1]
@@ -2923,13 +3118,13 @@ def matmul(
         ctypes.c_void_p(d.ptr),
         alpha,
         beta,
-        True,
-        True,
+        not a.is_transposed,
+        not b.is_transposed,
         allow_tf32x3_arith,
         1,
     )
     if not ret:
-        raise RuntimeError("
+        raise RuntimeError("matmul failed.")
 
 
 def adj_matmul(
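Together, the contiguity check and the is_transposed-driven layout flags allow transposed views to be passed directly as A and/or B. A usage sketch, assuming a CUDA device (the dense GEMM path goes through CUTLASS):

    import numpy as np
    import warp as wp

    wp.init()

    a = wp.array(np.random.rand(4, 3).astype(np.float32), device="cuda")
    b = wp.array(np.random.rand(4, 5).astype(np.float32), device="cuda")
    c = wp.zeros((3, 5), dtype=wp.float32, device="cuda")
    d = wp.zeros((3, 5), dtype=wp.float32, device="cuda")

    # D = alpha * (A^T @ B) + beta * C, using a zero-copy transposed view of A.
    wp.matmul(a.transpose(), b, c, d, alpha=1.0, beta=0.0)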
@@ -2993,6 +3188,19 @@ def adj_matmul(
             "wp.adj_matmul currently only supports operation between {A, B, C, adj_D, adj_A, adj_B, adj_C} matrices of the same type."
         )
 
+    if (
+        (not a.is_contiguous and not a.is_transposed)
+        or (not b.is_contiguous and not b.is_transposed)
+        or (not c.is_contiguous)
+        or (not adj_a.is_contiguous and not adj_a.is_transposed)
+        or (not adj_b.is_contiguous and not adj_b.is_transposed)
+        or (not adj_c.is_contiguous)
+        or (not adj_d.is_contiguous)
+    ):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
+        )
+
     m = a.shape[0]
     n = b.shape[1]
     k = a.shape[1]
@@ -3013,75 +3221,105 @@ def adj_matmul(
 
     # cpu fallback if no cuda devices found
     if device == "cpu":
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()))
-        adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()))
-        adj_c.assign(beta * adj_d.numpy())
+        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()) + adj_a.numpy())
+        adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()) + adj_b.numpy())
+        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
         return
 
     cc = device.arch
 
     # adj_a
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not a.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            m,
+            k,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(b.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            alpha,
+            1.0,
+            True,
+            b.is_transposed,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            m,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(b.ptr),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            ctypes.c_void_p(adj_a.ptr),
+            alpha,
+            1.0,
+            not b.is_transposed,
+            False,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
 
     # adj_b
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not b.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            n,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(a.ptr),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            alpha,
+            1.0,
+            a.is_transposed,
+            True,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            n,
+            k,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d.ptr),
+            ctypes.c_void_p(a.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            ctypes.c_void_p(adj_b.ptr),
+            alpha,
+            1.0,
+            False,
+            not a.is_transposed,
+            allow_tf32x3_arith,
+            1,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
 
     # adj_c
-    ret = runtime.core.cutlass_gemm(
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(adj_c.ptr),
-        0.0,
-        beta,
-        True,
-        True,
-        allow_tf32x3_arith,
-        1,
+    warp.launch(
+        kernel=warp.utils.add_kernel_2d,
+        dim=adj_c.shape,
+        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
+        device=device,
+        record_tape=False
     )
-    if not ret:
-        raise RuntimeError("adj_matmul failed.")
 
 
 def batched_matmul(
@@ -3120,6 +3358,11 @@ def batched_matmul(
             "wp.batched_matmul currently only supports operation between {A, B, C, D} matrices of the same type."
         )
 
+    if (not a.is_contiguous and not a.is_transposed) or (not b.is_contiguous and not b.is_transposed) or (not c.is_contiguous) or (not d.is_contiguous):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B may be transposed."
+        )
+
     m = a.shape[1]
     n = b.shape[2]
     k = a.shape[2]
@@ -3131,7 +3374,7 @@ def batched_matmul(
 
     if runtime.tape:
         runtime.tape.record_func(
-            backward=lambda: adj_matmul(
+            backward=lambda: adj_batched_matmul(
                a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith, device
            ),
            arrays=[a, b, c, d],
@@ -3142,26 +3385,55 @@ def batched_matmul(
         d.assign(alpha * np.matmul(a.numpy(), b.numpy()) + beta * c.numpy())
         return
 
+    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
+    max_batch_count = 65535
+    iters = int(batch_count / max_batch_count)
+    remainder = batch_count % max_batch_count
+
     cc = device.arch
+    for i in range(iters):
+        idx_start = i * max_batch_count
+        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            m,
+            n,
+            k,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
+            ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
+            ctypes.c_void_p(c[idx_start:idx_end,:,:].ptr),
+            ctypes.c_void_p(d[idx_start:idx_end,:,:].ptr),
+            alpha,
+            beta,
+            not a.is_transposed,
+            not b.is_transposed,
+            allow_tf32x3_arith,
+            max_batch_count,
+        )
+        if not ret:
+            raise RuntimeError("Batched matmul failed.")
+
+    idx_start = iters * max_batch_count
     ret = runtime.core.cutlass_gemm(
         cc,
         m,
         n,
         k,
         type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(c.ptr),
-        ctypes.c_void_p(d.ptr),
+        ctypes.c_void_p(a[idx_start:,:,:].ptr),
+        ctypes.c_void_p(b[idx_start:,:,:].ptr),
+        ctypes.c_void_p(c[idx_start:,:,:].ptr),
+        ctypes.c_void_p(d[idx_start:,:,:].ptr),
         alpha,
         beta,
-        True,
-        True,
+        not a.is_transposed,
+        not b.is_transposed,
        allow_tf32x3_arith,
-        batch_count,
+        remainder,
     )
     if not ret:
-        raise RuntimeError("Batched matmul failed.")
+        raise RuntimeError("Batched matmul failed.")
 
 
 def adj_batched_matmul(
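The splitting logic above caps each CUTLASS launch at 65535 batches. A worked example of how a large batch count is partitioned:

    batch_count = 200000
    max_batch_count = 65535

    iters = batch_count // max_batch_count     # 3 full chunks of 65535 batches
    remainder = batch_count % max_batch_count  # 3395 batches left for the final launch

    print(iters, remainder)  # 3 3395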
@@ -3241,78 +3513,215 @@ def adj_batched_matmul(
         )
     )
 
+    if (
+        (not a.is_contiguous and not a.is_transposed)
+        or (not b.is_contiguous and not b.is_transposed)
+        or (not c.is_contiguous)
+        or (not adj_a.is_contiguous and not adj_a.is_transposed)
+        or (not adj_b.is_contiguous and not adj_b.is_transposed)
+        or (not adj_c.is_contiguous)
+        or (not adj_d.is_contiguous)
+    ):
+        raise RuntimeError(
+            "wp.matmul is only valid for contiguous arrays, with the exception that A and/or B and their associated adjoints may be transposed."
+        )
+
     # cpu fallback if no cuda devices found
     if device == "cpu":
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))))
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()))
-        adj_c.assign(beta * adj_d.numpy())
+        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))) + adj_a.numpy())
+        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()) + adj_b.numpy())
+        adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
         return
 
+    # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
+    max_batch_count = 65535
+    iters = int(batch_count / max_batch_count)
+    remainder = batch_count % max_batch_count
+
     cc = device.arch
 
+    for i in range(iters):
+        idx_start = i * max_batch_count
+        idx_end = (i + 1) * max_batch_count if i < iters - 1 else batch_count
+
+        # adj_a
+        if not a.is_transposed:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                m,
+                k,
+                n,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                True,
+                b.is_transposed,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+        else:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                k,
+                m,
+                n,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_a[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                not b.is_transposed,
+                False,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+
+        # adj_b
+        if not b.is_transposed:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                k,
+                n,
+                m,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                a.is_transposed,
+                True,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+        else:
+            ret = runtime.core.cutlass_gemm(
+                cc,
+                n,
+                k,
+                m,
+                type_typestr(a.dtype).encode(),
+                ctypes.c_void_p(adj_d[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(a[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                ctypes.c_void_p(adj_b[idx_start:idx_end,:,:].ptr),
+                alpha,
+                1.0,
+                False,
+                not a.is_transposed,
+                allow_tf32x3_arith,
+                max_batch_count,
+            )
+            if not ret:
+                raise RuntimeError("adj_matmul failed.")
+
+    idx_start = iters * max_batch_count
+
     # adj_a
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not a.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            m,
+            k,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            True,
+            b.is_transposed,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            m,
+            n,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_a[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            not b.is_transposed,
+            False,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
 
     # adj_b
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    if not b.is_transposed:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            k,
+            n,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            a.is_transposed,
+            True,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
+    else:
+        ret = runtime.core.cutlass_gemm(
+            cc,
+            n,
+            k,
+            m,
+            type_typestr(a.dtype).encode(),
+            ctypes.c_void_p(adj_d[idx_start:,:,:].ptr),
+            ctypes.c_void_p(a[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            ctypes.c_void_p(adj_b[idx_start:,:,:].ptr),
+            alpha,
+            1.0,
+            False,
+            not a.is_transposed,
+            allow_tf32x3_arith,
+            remainder,
+        )
+        if not ret:
+            raise RuntimeError("adj_matmul failed.")
 
     # adj_c
-    ret = runtime.core.cutlass_gemm(
-        cc,
-        m,
-        n,
-        k,
-        type_typestr(a.dtype).encode(),
-        ctypes.c_void_p(a.ptr),
-        ctypes.c_void_p(b.ptr),
-        ctypes.c_void_p(adj_d.ptr),
-        ctypes.c_void_p(adj_c.ptr),
-        0.0,
-        beta,
-        True,
-        True,
-        allow_tf32x3_arith,
-        batch_count,
+    warp.launch(
+        kernel=warp.utils.add_kernel_3d,
+        dim=adj_c.shape,
+        inputs=[adj_c, adj_d, adj_d.dtype(beta)],
+        device=device,
+        record_tape=False
     )
-    if not ret:
-        raise RuntimeError("adj_matmul failed.")
-
 
 class HashGrid:
     def __init__(self, dim_x, dim_y, dim_z, device=None):
@@ -3511,7 +3920,7 @@ def type_matches_template(arg_type, template_type):
         return True
     elif is_array(template_type):
         # ensure the argument type is a non-generic array with matching dtype and dimensionality
-        if type(arg_type) != type(template_type):
+        if type(arg_type) is not type(template_type):
            return False
        if not type_matches_template(arg_type.dtype, template_type.dtype):
            return False
@@ -3567,7 +3976,7 @@ def infer_argument_types(args, template_types, arg_names=None):
             arg_types.append(arg._cls)
         # elif arg_type in [warp.types.launch_bounds_t, warp.types.shape_t, warp.types.range_t]:
         #     arg_types.append(arg_type)
-        # elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.bvh_query_t]:
+        # elif arg_type in [warp.hash_grid_query_t, warp.mesh_query_aabb_t, warp.mesh_query_point_t, warp.mesh_query_ray_t, warp.bvh_query_t]:
         #     arg_types.append(arg_type)
         elif arg is None:
             # allow passing None for arrays
@@ -3605,6 +4014,8 @@ simple_type_codes = {
     launch_bounds_t: "lb",
     hash_grid_query_t: "hgq",
     mesh_query_aabb_t: "mqa",
+    mesh_query_point_t: "mqp",
+    mesh_query_ray_t: "mqr",
     bvh_query_t: "bvhq",
 }