warp-lang 1.9.0-py3-none-win_amd64.whl → 1.10.0rc2-py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of warp-lang might be problematic.
- warp/__init__.py +301 -287
- warp/__init__.pyi +2220 -313
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1497 -226
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -471
- warp/codegen.py +6 -4246
- warp/constants.py +6 -39
- warp/context.py +12 -7851
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +3 -2
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -342
- warp/jax_experimental/ffi.py +17 -853
- warp/jax_experimental/xla_ffi.py +5 -596
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +316 -39
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sort.cu +22 -13
- warp/native/sort.h +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +837 -70
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -53
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +60 -32
- warp/native/warp.cu +313 -201
- warp/native/warp.h +14 -11
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3616
- warp/render/render_usd.py +6 -918
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +1382 -79
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +529 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +34 -15
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +60 -14
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +49 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +82 -9
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +239 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5750
- warp/utils.py +10 -1659
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +47 -103
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.0.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/native/warp.cu
CHANGED
@@ -19,6 +19,7 @@
 #include "scan.h"
 #include "cuda_util.h"
 #include "error.h"
+#include "sort.h"

 #include <cstdlib>
 #include <fstream>
@@ -221,6 +222,14 @@ struct ModuleInfo
     void* module = NULL;
 };

+// Information used when deferring graph destruction.
+struct GraphDestroyInfo
+{
+    void* context = NULL;
+    void* graph = NULL;
+    void* graph_exec = NULL;
+};
+
 static std::unordered_map<CUfunction, std::string> g_kernel_names;

 // cached info for all devices, indexed by ordinal
@@ -252,6 +261,11 @@ static std::vector<FreeInfo> g_deferred_free_list;
 // Call unload_deferred_modules() to release.
 static std::vector<ModuleInfo> g_deferred_module_list;

+// Graphs that cannot be destroyed immediately get queued here.
+// Call destroy_deferred_graphs() to release.
+static std::vector<GraphDestroyInfo> g_deferred_graph_list;
+
+
 void wp_cuda_set_context_restore_policy(bool always_restore)
 {
     ContextGuard::always_restore = always_restore;
@@ -337,7 +351,7 @@ int cuda_init()
 }


-
+CUcontext get_current_context()
 {
     CUcontext ctx;
     if (check_cu(cuCtxGetCurrent_f(&ctx)))
@@ -494,6 +508,38 @@ static int unload_deferred_modules(void* context = NULL)
     return num_unloaded_modules;
 }

+static int destroy_deferred_graphs(void* context = NULL)
+{
+    if (g_deferred_graph_list.empty() || !g_captures.empty())
+        return 0;
+
+    int num_destroyed_graphs = 0;
+    for (auto it = g_deferred_graph_list.begin(); it != g_deferred_graph_list.end(); /*noop*/)
+    {
+        // destroy the graph if it matches the given context or if the context is unspecified
+        const GraphDestroyInfo& graph_info = *it;
+        if (graph_info.context == context || !context)
+        {
+            if (graph_info.graph)
+            {
+                check_cuda(cudaGraphDestroy((cudaGraph_t)graph_info.graph));
+            }
+            if (graph_info.graph_exec)
+            {
+                check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_info.graph_exec));
+            }
+            ++num_destroyed_graphs;
+            it = g_deferred_graph_list.erase(it);
+        }
+        else
+        {
+            ++it;
+        }
+    }
+
+    return num_destroyed_graphs;
+}
+
 static void CUDART_CB on_graph_destroy(void* user_data)
 {
     if (!user_data)
@@ -988,15 +1034,15 @@ void wp_memtile_device(void* context, void* dst, const void* src, size_t srcsize


 static __global__ void array_copy_1d_kernel(void* dst, const void* src,
-
+                                            size_t dst_stride, size_t src_stride,
                                             const int* dst_indices, const int* src_indices,
-
+                                            size_t n, size_t elem_size)
 {
-
+    size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
     if (i < n)
     {
-
-
+        size_t src_idx = src_indices ? src_indices[i] : i;
+        size_t dst_idx = dst_indices ? dst_indices[i] : i;
         const char* p = (const char*)src + src_idx * src_stride;
         char* q = (char*)dst + dst_idx * dst_stride;
         memcpy(q, p, elem_size);
@@ -1004,20 +1050,20 @@ static __global__ void array_copy_1d_kernel(void* dst, const void* src,
 }

 static __global__ void array_copy_2d_kernel(void* dst, const void* src,
-                                            wp::vec_t<2,
+                                            wp::vec_t<2, size_t> dst_strides, wp::vec_t<2, size_t> src_strides,
                                             wp::vec_t<2, const int*> dst_indices, wp::vec_t<2, const int*> src_indices,
-                                            wp::vec_t<2,
+                                            wp::vec_t<2, size_t> shape, size_t elem_size)
 {
-
-
-
-
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
+    size_t n = shape[1];
+    size_t i = tid / n;
+    size_t j = tid % n;
     if (i < shape[0] /*&& j < shape[1]*/)
     {
-
-
-
-
+        size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
+        size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
+        size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
+        size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
         const char* p = (const char*)src + src_idx0 * src_strides[0] + src_idx1 * src_strides[1];
         char* q = (char*)dst + dst_idx0 * dst_strides[0] + dst_idx1 * dst_strides[1];
         memcpy(q, p, elem_size);
@@ -1025,24 +1071,24 @@ static __global__ void array_copy_2d_kernel(void* dst, const void* src,
 }

 static __global__ void array_copy_3d_kernel(void* dst, const void* src,
-                                            wp::vec_t<3,
+                                            wp::vec_t<3, size_t> dst_strides, wp::vec_t<3, size_t> src_strides,
                                             wp::vec_t<3, const int*> dst_indices, wp::vec_t<3, const int*> src_indices,
-                                            wp::vec_t<3,
-{
-
-
-
-
-
-
+                                            wp::vec_t<3, size_t> shape, size_t elem_size)
+{
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
+    size_t n = shape[1];
+    size_t o = shape[2];
+    size_t i = tid / (n * o);
+    size_t j = tid % (n * o) / o;
+    size_t k = tid % o;
     if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
     {
-
-
-
-
-
-
+        size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
+        size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
+        size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
+        size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
+        size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
+        size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
         const char* p = (const char*)src + src_idx0 * src_strides[0]
             + src_idx1 * src_strides[1]
             + src_idx2 * src_strides[2];
@@ -1054,28 +1100,28 @@ static __global__ void array_copy_3d_kernel(void* dst, const void* src,
 }

 static __global__ void array_copy_4d_kernel(void* dst, const void* src,
-                                            wp::vec_t<4,
+                                            wp::vec_t<4, size_t> dst_strides, wp::vec_t<4, size_t> src_strides,
                                             wp::vec_t<4, const int*> dst_indices, wp::vec_t<4, const int*> src_indices,
-                                            wp::vec_t<4,
-{
-
-
-
-
-
-
-
-
+                                            wp::vec_t<4, size_t> shape, size_t elem_size)
+{
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
+    size_t n = shape[1];
+    size_t o = shape[2];
+    size_t p = shape[3];
+    size_t i = tid / (n * o * p);
+    size_t j = tid % (n * o * p) / (o * p);
+    size_t k = tid % (o * p) / p;
+    size_t l = tid % p;
     if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
     {
-
-
-
-
-
-
-
-
+        size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
+        size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
+        size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
+        size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
+        size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
+        size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
+        size_t src_idx3 = src_indices[3] ? src_indices[3][l] : l;
+        size_t dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
         const char* p = (const char*)src + src_idx0 * src_strides[0]
             + src_idx1 * src_strides[1]
             + src_idx2 * src_strides[2]
@@ -1090,14 +1136,14 @@ static __global__ void array_copy_4d_kernel(void* dst, const void* src,


 static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
-                                                     void* dst_data,
-
+                                                     void* dst_data, size_t dst_stride, const int* dst_indices,
+                                                     size_t elem_size)
 {
-
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

     if (tid < src.size)
     {
-
+        size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
         void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
         const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
         memcpy(dst_ptr, src_ptr, elem_size);
@@ -1105,15 +1151,15 @@ static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src
 }

 static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
-                                                             void* dst_data,
-
+                                                             void* dst_data, size_t dst_stride, const int* dst_indices,
+                                                             size_t elem_size)
 {
-
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

     if (tid < src.size)
     {
-
-
+        size_t src_index = src.indices[tid];
+        size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
         void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
         const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
         memcpy(dst_ptr, src_ptr, elem_size);
@@ -1121,14 +1167,14 @@ static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricar
 }

 static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
-                                                   const void* src_data,
-
+                                                   const void* src_data, size_t src_stride, const int* src_indices,
+                                                   size_t elem_size)
 {
-
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

     if (tid < dst.size)
     {
-
+        size_t src_idx = src_indices ? src_indices[tid] : tid;
         const void* src_ptr = (const char*)src_data + src_idx * src_stride;
         void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
         memcpy(dst_ptr, src_ptr, elem_size);
@@ -1136,25 +1182,25 @@ static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
 }

 static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
-                                                           const void* src_data,
-
+                                                           const void* src_data, size_t src_stride, const int* src_indices,
+                                                           size_t elem_size)
 {
-
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

     if (tid < dst.size)
     {
-
+        size_t src_idx = src_indices ? src_indices[tid] : tid;
         const void* src_ptr = (const char*)src_data + src_idx * src_stride;
-
+        size_t dst_idx = dst.indices[tid];
         void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
         memcpy(dst_ptr, src_ptr, elem_size);
     }
 }


-static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src,
+static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
 {
-
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

     if (tid < dst.size)
     {
@@ -1165,27 +1211,27 @@ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void
     }
 }

-static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src,
+static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
 {
-
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

     if (tid < dst.size)
     {
         const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
-
+        size_t dst_index = dst.indices[tid];
         void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
         memcpy(dst_ptr, src_ptr, elem_size);
     }
 }


-static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src,
+static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
 {
-
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

     if (tid < dst.size)
     {
-
+        size_t src_index = src.indices[tid];
         const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
         void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
         memcpy(dst_ptr, src_ptr, elem_size);
@@ -1193,14 +1239,14 @@ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarra
     }
 }

-static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src,
+static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
 {
-
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);

     if (tid < dst.size)
     {
-
-
+        size_t src_index = src.indices[tid];
+        size_t dst_index = dst.indices[tid];
         const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
         void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
         memcpy(dst_ptr, src_ptr, elem_size);
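The copy kernels above switch their index arithmetic from 32-bit int to size_t, so arrays whose flattened element count exceeds 2^31 - 1 still index correctly. A minimal standalone CUDA sketch of the same flat-index decomposition (the kernel below is illustrative only and is not part of Warp):

    #include <cstdio>

    // Decompose a flat thread id into 3-D coordinates with 64-bit arithmetic,
    // mirroring the tid -> (i, j, k) pattern used by the copy/fill kernels above.
    __global__ void index_3d_demo(size_t dim0, size_t dim1, size_t dim2)
    {
        size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
        size_t i = tid / (dim1 * dim2);        // slowest-varying coordinate
        size_t j = tid % (dim1 * dim2) / dim2; // middle coordinate
        size_t k = tid % dim2;                 // fastest-varying coordinate
        if (i < dim0)
        {
            // With int indexing, tid and the products above overflow once the total
            // element count exceeds 2^31 - 1; size_t keeps the arithmetic exact.
            printf("tid=%llu -> (%llu, %llu, %llu)\n",
                   (unsigned long long)tid, (unsigned long long)i,
                   (unsigned long long)j, (unsigned long long)k);
        }
    }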
@@ -1439,9 +1485,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
     }
     case 2:
     {
-        wp::vec_t<2,
-        wp::vec_t<2,
-        wp::vec_t<2,
+        wp::vec_t<2, size_t> shape_v(src_shape[0], src_shape[1]);
+        wp::vec_t<2, size_t> src_strides_v(src_strides[0], src_strides[1]);
+        wp::vec_t<2, size_t> dst_strides_v(dst_strides[0], dst_strides[1]);
         wp::vec_t<2, const int*> src_indices_v(src_indices[0], src_indices[1]);
         wp::vec_t<2, const int*> dst_indices_v(dst_indices[0], dst_indices[1]);

@@ -1453,9 +1499,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
     }
     case 3:
     {
-        wp::vec_t<3,
-        wp::vec_t<3,
-        wp::vec_t<3,
+        wp::vec_t<3, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2]);
+        wp::vec_t<3, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
+        wp::vec_t<3, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
         wp::vec_t<3, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2]);
         wp::vec_t<3, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2]);

@@ -1467,9 +1513,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
     }
     case 4:
     {
-        wp::vec_t<4,
-        wp::vec_t<4,
-        wp::vec_t<4,
+        wp::vec_t<4, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
+        wp::vec_t<4, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
+        wp::vec_t<4, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
         wp::vec_t<4, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2], src_indices[3]);
         wp::vec_t<4, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2], dst_indices[3]);

@@ -1489,94 +1535,94 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty


 static __global__ void array_fill_1d_kernel(void* data,
-
-
+                                            size_t n,
+                                            size_t stride,
                                             const int* indices,
                                             const void* value,
-
+                                            size_t value_size)
 {
-
+    size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
     if (i < n)
     {
-
+        size_t idx = indices ? indices[i] : i;
         char* p = (char*)data + idx * stride;
         memcpy(p, value, value_size);
     }
 }

 static __global__ void array_fill_2d_kernel(void* data,
-                                            wp::vec_t<2,
-                                            wp::vec_t<2,
+                                            wp::vec_t<2, size_t> shape,
+                                            wp::vec_t<2, size_t> strides,
                                             wp::vec_t<2, const int*> indices,
                                             const void* value,
-
+                                            size_t value_size)
 {
-
-
-
-
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
+    size_t n = shape[1];
+    size_t i = tid / n;
+    size_t j = tid % n;
     if (i < shape[0] /*&& j < shape[1]*/)
     {
-
-
+        size_t idx0 = indices[0] ? indices[0][i] : i;
+        size_t idx1 = indices[1] ? indices[1][j] : j;
         char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1];
         memcpy(p, value, value_size);
     }
 }

 static __global__ void array_fill_3d_kernel(void* data,
-                                            wp::vec_t<3,
-                                            wp::vec_t<3,
+                                            wp::vec_t<3, size_t> shape,
+                                            wp::vec_t<3, size_t> strides,
                                             wp::vec_t<3, const int*> indices,
                                             const void* value,
-
-{
-
-
-
-
-
-
+                                            size_t value_size)
+{
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
+    size_t n = shape[1];
+    size_t o = shape[2];
+    size_t i = tid / (n * o);
+    size_t j = tid % (n * o) / o;
+    size_t k = tid % o;
     if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
     {
-
-
-
+        size_t idx0 = indices[0] ? indices[0][i] : i;
+        size_t idx1 = indices[1] ? indices[1][j] : j;
+        size_t idx2 = indices[2] ? indices[2][k] : k;
         char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2];
         memcpy(p, value, value_size);
     }
 }

 static __global__ void array_fill_4d_kernel(void* data,
-                                            wp::vec_t<4,
-                                            wp::vec_t<4,
+                                            wp::vec_t<4, size_t> shape,
+                                            wp::vec_t<4, size_t> strides,
                                             wp::vec_t<4, const int*> indices,
                                             const void* value,
-
-{
-
-
-
-
-
-
-
-
+                                            size_t value_size)
+{
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
+    size_t n = shape[1];
+    size_t o = shape[2];
+    size_t p = shape[3];
+    size_t i = tid / (n * o * p);
+    size_t j = tid % (n * o * p) / (o * p);
+    size_t k = tid % (o * p) / p;
+    size_t l = tid % p;
     if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
     {
-
-
-
-
+        size_t idx0 = indices[0] ? indices[0][i] : i;
+        size_t idx1 = indices[1] ? indices[1][j] : j;
+        size_t idx2 = indices[2] ? indices[2][k] : k;
+        size_t idx3 = indices[3] ? indices[3][l] : l;
         char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2] + idx3 * strides[3];
         memcpy(p, value, value_size);
     }
 }


-static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value,
+static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, size_t value_size)
 {
-
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
     if (tid < fa.size)
     {
         void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
@@ -1585,9 +1631,9 @@ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, cons
     }
 }

-static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value,
+static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, size_t value_size)
 {
-
+    size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
     if (tid < ifa.size)
     {
         size_t idx = size_t(ifa.indices[tid]);
@@ -1684,8 +1730,8 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
     }
     case 2:
     {
-        wp::vec_t<2,
-        wp::vec_t<2,
+        wp::vec_t<2, size_t> shape_v(shape[0], shape[1]);
+        wp::vec_t<2, size_t> strides_v(strides[0], strides[1]);
         wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
         wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
             (data, shape_v, strides_v, indices_v, value_devptr, value_size));
@@ -1693,8 +1739,8 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
     }
     case 3:
     {
-        wp::vec_t<3,
-        wp::vec_t<3,
+        wp::vec_t<3, size_t> shape_v(shape[0], shape[1], shape[2]);
+        wp::vec_t<3, size_t> strides_v(strides[0], strides[1], strides[2]);
         wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
         wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
             (data, shape_v, strides_v, indices_v, value_devptr, value_size));
@@ -1702,8 +1748,8 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
     }
     case 4:
     {
-        wp::vec_t<4,
-        wp::vec_t<4,
+        wp::vec_t<4, size_t> shape_v(shape[0], shape[1], shape[2], shape[3]);
+        wp::vec_t<4, size_t> strides_v(strides[0], strides[1], strides[2], strides[3]);
         wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
         wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
             (data, shape_v, strides_v, indices_v, value_devptr, value_size));
@@ -2071,13 +2117,17 @@ void wp_cuda_context_synchronize(void* context)

     check_cu(cuCtxSynchronize_f());

-    if (
+    if (!context)
+        context = get_current_context();
+
+    if (free_deferred_allocs(context) > 0)
     {
         // ensure deferred asynchronous deallocations complete
         check_cu(cuCtxSynchronize_f());
     }

     unload_deferred_modules(context);
+    destroy_deferred_graphs(context);

     // check_cuda(cudaDeviceGraphMemTrim(wp_cuda_context_get_device_ordinal(context)));
 }
@@ -2448,6 +2498,9 @@ void wp_cuda_stream_destroy(void* context, void* stream)

     wp_cuda_stream_unregister(context, stream);

+    // release temporary radix sort buffer associated with this stream
+    radix_sort_release(context, stream);
+
     check_cu(cuStreamDestroy_f(static_cast<CUstream>(stream)));
 }

@@ -2510,15 +2563,36 @@ void wp_cuda_stream_synchronize(void* stream)
     check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
 }

-void wp_cuda_stream_wait_event(void* stream, void* event)
+void wp_cuda_stream_wait_event(void* stream, void* event, bool external)
 {
-
+    // the external flag can only be used during graph capture
+    if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
+    {
+        // wait for an external event during graph capture
+        check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_EXTERNAL));
+    }
+    else
+    {
+        check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_DEFAULT));
+    }
 }

-void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
+void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event, bool external)
 {
-
-
+    unsigned record_flags = CU_EVENT_RECORD_DEFAULT;
+    unsigned wait_flags = CU_EVENT_WAIT_DEFAULT;
+
+    // the external flag can only be used during graph capture
+    if (external && !g_captures.empty())
+    {
+        if (wp_cuda_stream_is_capturing(other_stream))
+            record_flags = CU_EVENT_RECORD_EXTERNAL;
+        if (wp_cuda_stream_is_capturing(stream))
+            wait_flags = CU_EVENT_WAIT_EXTERNAL;
+    }
+
+    check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream), record_flags));
+    check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), wait_flags));
 }

 int wp_cuda_stream_is_capturing(void* stream)
@@ -2571,11 +2645,12 @@ int wp_cuda_event_query(void* event)
     return res;
 }

-void wp_cuda_event_record(void* event, void* stream, bool
+void wp_cuda_event_record(void* event, void* stream, bool external)
 {
-
+    // the external flag can only be used during graph capture
+    if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
     {
-        // record
+        // record external event during graph capture (e.g., for timing or when explicitly specified by the user)
         check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
     }
     else
@@ -2625,7 +2700,7 @@ bool wp_cuda_graph_begin_capture(void* context, void* stream, int external)
     else
     {
         // start the capture
-        if (!check_cuda(cudaStreamBeginCapture(cuda_stream,
+        if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeThreadLocal)))
             return false;
     }

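The new external parameter threaded through wp_cuda_event_record, wp_cuda_stream_wait_event, and wp_cuda_stream_wait_stream maps onto the CU_EVENT_RECORD_EXTERNAL / CU_EVENT_WAIT_EXTERNAL flags, which mark an event dependency as crossing the boundary of an ongoing graph capture instead of being folded into the captured graph. A rough standalone illustration using the CUDA runtime API (stream and event names are placeholders and error checking is omitted; which flag combinations apply depends on which streams are capturing, as the branching in wp_cuda_stream_wait_stream above shows):

    #include <cuda_runtime.h>

    void wait_on_external_work_sketch()
    {
        cudaStream_t capture_stream, other_stream;
        cudaEvent_t ev;
        cudaStreamCreate(&capture_stream);
        cudaStreamCreate(&other_stream);
        cudaEventCreate(&ev);

        // Thread-local capture, matching the cudaStreamCaptureModeThreadLocal change above.
        cudaStreamBeginCapture(capture_stream, cudaStreamCaptureModeThreadLocal);

        // other_stream is not capturing, so the event is recorded normally there; the
        // capturing stream waits on it with the "external" flag so the dependency is
        // represented as an external event node rather than as captured graph work.
        cudaEventRecordWithFlags(ev, other_stream, cudaEventRecordDefault);
        cudaStreamWaitEvent(capture_stream, ev, cudaEventWaitExternal);

        cudaGraph_t graph;
        cudaStreamEndCapture(capture_stream, &graph);
    }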
@@ -2772,6 +2847,7 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
     {
         free_deferred_allocs();
         unload_deferred_modules();
+        destroy_deferred_graphs();
     }

     if (graph_ret)
@@ -2811,11 +2887,12 @@ bool wp_cuda_graph_create_exec(void* context, void* stream, void* graph, void**
 // Support for conditional graph nodes available with CUDA 12.4+.
 #if CUDA_VERSION >= 12040

-// CUBIN data for compiled conditional modules, loaded on demand, keyed on device architecture
-
+// CUBIN or PTX data for compiled conditional modules, loaded on demand, keyed on device architecture
+using ModuleKey = std::pair<int, bool>; // <arch, use_ptx>
+static std::map<ModuleKey, void*> g_conditional_modules;

 // Compile module with conditional helper kernels
-static void* compile_conditional_module(int arch)
+static void* compile_conditional_module(int arch, bool use_ptx)
 {
     static const char* kernel_source = R"(
 typedef __device_builtin__ unsigned long long cudaGraphConditionalHandle;
@@ -2844,8 +2921,9 @@ static void* compile_conditional_module(int arch)
 )";

     // avoid recompilation
-
-
+    ModuleKey key = {arch, use_ptx};
+    auto it = g_conditional_modules.find(key);
+    if (it != g_conditional_modules.end())
         return it->second;

     nvrtcProgram prog;
@@ -2853,11 +2931,23 @@ static void* compile_conditional_module(int arch)
         return NULL;

     char arch_opt[128];
-
+    if (use_ptx)
+        snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=compute_%d", arch);
+    else
+        snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=sm_%d", arch);

     std::vector<const char*> opts;
     opts.push_back(arch_opt);

+    const bool print_debug = (std::getenv("WARP_DEBUG") != nullptr);
+    if (print_debug)
+    {
+        printf("NVRTC options (conditional module, arch=%d, use_ptx=%s):\n", arch, use_ptx ? "true" : "false");
+        for(auto o: opts) {
+            printf("%s\n", o);
+        }
+    }
+
     if (!check_nvrtc(nvrtcCompileProgram(prog, int(opts.size()), opts.data())))
     {
         size_t log_size;
@@ -2874,23 +2964,37 @@ static void* compile_conditional_module(int arch)
     // get output
     char* output = NULL;
     size_t output_size = 0;
-
-    if (
+
+    if (use_ptx)
+    {
+        check_nvrtc(nvrtcGetPTXSize(prog, &output_size));
+        if (output_size > 0)
+        {
+            output = new char[output_size];
+            if (check_nvrtc(nvrtcGetPTX(prog, output)))
+                g_conditional_modules[key] = output;
+        }
+    }
+    else
     {
-
-        if (
-
+        check_nvrtc(nvrtcGetCUBINSize(prog, &output_size));
+        if (output_size > 0)
+        {
+            output = new char[output_size];
+            if (check_nvrtc(nvrtcGetCUBIN(prog, output)))
+                g_conditional_modules[key] = output;
+        }
     }

     nvrtcDestroyProgram(&prog);

-    // return CUBIN data
+    // return CUBIN or PTX data
     return output;
 }


 // Load module with conditional helper kernels
-static CUmodule load_conditional_module(void* context)
+static CUmodule load_conditional_module(void* context, int arch, bool use_ptx)
 {
     ContextInfo* context_info = get_context_info(context);
     if (!context_info)
@@ -2900,17 +3004,15 @@ static CUmodule load_conditional_module(void* context)
     if (context_info->conditional_module)
         return context_info->conditional_module;

-    int arch = context_info->device_info->arch;
-
     // compile if needed
-    void* compiled_module = compile_conditional_module(arch);
+    void* compiled_module = compile_conditional_module(arch, use_ptx);
     if (!compiled_module)
     {
         fprintf(stderr, "Warp error: Failed to compile conditional kernels\n");
         return NULL;
     }

-    // load module
+    // load module (handles both PTX and CUBIN data automatically)
     CUmodule module = NULL;
     if (!check_cu(cuModuleLoadDataEx_f(&module, compiled_module, 0, NULL, NULL)))
     {
@@ -2923,10 +3025,10 @@ static CUmodule load_conditional_module(void* context)
     return module;
 }

-static CUfunction get_conditional_kernel(void* context, const char* name)
+static CUfunction get_conditional_kernel(void* context, int arch, bool use_ptx, const char* name)
 {
     // load module if needed
-    CUmodule module = load_conditional_module(context);
+    CUmodule module = load_conditional_module(context, arch, use_ptx);
     if (!module)
         return NULL;

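The conditional-kernel compile path now supports both NVRTC output formats: --gpu-architecture=compute_<arch> produces PTX, which the driver JIT-compiles at module load time and which stays forward-compatible with newer GPUs, while sm_<arch> produces a CUBIN bound to that exact architecture. A condensed, self-contained sketch of the same choice (minimal error handling; the source string and function name are placeholders, not Warp code):

    #include <nvrtc.h>
    #include <cstdio>

    // Compile `src` for `arch`, returning PTX when use_ptx is true and CUBIN otherwise.
    char* compile_for_arch(const char* src, int arch, bool use_ptx)
    {
        nvrtcProgram prog;
        if (nvrtcCreateProgram(&prog, src, "demo.cu", 0, NULL, NULL) != NVRTC_SUCCESS)
            return NULL;

        char arch_opt[64];
        snprintf(arch_opt, sizeof(arch_opt), "--gpu-architecture=%s_%d", use_ptx ? "compute" : "sm", arch);
        const char* opts[] = { arch_opt };

        char* output = NULL;
        if (nvrtcCompileProgram(prog, 1, opts) == NVRTC_SUCCESS)
        {
            size_t size = 0;
            if (use_ptx)
            {
                nvrtcGetPTXSize(prog, &size);   // PTX: JIT-compiled by the driver on load
                output = new char[size];
                nvrtcGetPTX(prog, output);
            }
            else
            {
                nvrtcGetCUBINSize(prog, &size); // CUBIN: targets exactly sm_<arch>
                output = new char[size];
                nvrtcGetCUBIN(prog, output);
            }
        }
        nvrtcDestroyProgram(&prog);
        return output;  // either form can be passed to cuModuleLoadDataEx
    }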
@@ -2966,7 +3068,7 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
         leaf_nodes.data(),
         nullptr,
         leaf_nodes.size(),
-
+        cudaStreamCaptureModeThreadLocal)))
         return false;

     return true;
@@ -2976,7 +3078,7 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
 // https://developer.nvidia.com/blog/dynamic-control-flow-in-cuda-graphs-with-conditional-nodes/
 // condition is a gpu pointer
 // if_graph_ret and else_graph_ret should be NULL if not needed
-bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
 {
     bool has_if = if_graph_ret != NULL;
     bool has_else = else_graph_ret != NULL;
@@ -3019,9 +3121,9 @@ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, v
     // (need to negate the condition if only the else branch is used)
     CUfunction kernel;
     if (has_if)
-        kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+        kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
     else
-        kernel = get_conditional_kernel(context, "set_conditional_else_handle_kernel");
+        kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_else_handle_kernel");

     if (!kernel)
     {
@@ -3072,7 +3174,7 @@ bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, v
     check_cuda(cudaGraphConditionalHandleCreate(&if_handle, cuda_graph));
     check_cuda(cudaGraphConditionalHandleCreate(&else_handle, cuda_graph));

-    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_else_handles_kernel");
+    CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_else_handles_kernel");
     if (!kernel)
     {
         wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3273,7 +3375,7 @@ bool wp_cuda_graph_insert_child_graph(void* context, void* stream, void* child_g
     return true;
 }

-bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
 {
     // if there's no body, it's a no-op
     if (!body_graph_ret)
@@ -3303,7 +3405,7 @@ bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, voi
         return false;

     // launch a kernel to set the condition handle from condition pointer
-    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
     if (!kernel)
     {
         wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3339,14 +3441,14 @@ bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, voi
     return true;
 }

-bool wp_cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
 {
     ContextGuard guard(context);

     CUstream cuda_stream = static_cast<CUstream>(stream);

     // launch a kernel to set the condition handle from condition pointer
-    CUfunction kernel = get_conditional_kernel(context, "set_conditional_if_handle_kernel");
+    CUfunction kernel = get_conditional_kernel(context, arch, use_ptx, "set_conditional_if_handle_kernel");
     if (!kernel)
     {
         wp::set_error_string("Failed to get built-in conditional kernel");
@@ -3378,19 +3480,19 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
     return false;
 }

-bool wp_cuda_graph_insert_if_else(void* context, void* stream, int* condition, void** if_graph_ret, void** else_graph_ret)
+bool wp_cuda_graph_insert_if_else(void* context, void* stream, int arch, bool use_ptx, int* condition, void** if_graph_ret, void** else_graph_ret)
 {
     wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
     return false;
 }

-bool wp_cuda_graph_insert_while(void* context, void* stream, int* condition, void** body_graph_ret, uint64_t* handle_ret)
+bool wp_cuda_graph_insert_while(void* context, void* stream, int arch, bool use_ptx, int* condition, void** body_graph_ret, uint64_t* handle_ret)
 {
     wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
     return false;
 }

-bool wp_cuda_graph_set_condition(void* context, void* stream, int* condition, uint64_t handle)
+bool wp_cuda_graph_set_condition(void* context, void* stream, int arch, bool use_ptx, int* condition, uint64_t handle)
 {
     wp::set_error_string("Warp error: Warp must be built with CUDA Toolkit 12.4+ to enable conditional graph nodes");
     return false;
@@ -3425,16 +3527,38 @@ bool wp_cuda_graph_launch(void* graph_exec, void* stream)

 bool wp_cuda_graph_destroy(void* context, void* graph)
 {
-
-
-
+    // ensure there are no graph captures in progress
+    if (g_captures.empty())
+    {
+        ContextGuard guard(context);
+        return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
+    }
+    else
+    {
+        GraphDestroyInfo info;
+        info.context = context ? context : get_current_context();
+        info.graph = graph;
+        g_deferred_graph_list.push_back(info);
+        return true;
+    }
 }

 bool wp_cuda_graph_exec_destroy(void* context, void* graph_exec)
 {
-
-
-
+    // ensure there are no graph captures in progress
+    if (g_captures.empty())
+    {
+        ContextGuard guard(context);
+        return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
+    }
+    else
+    {
+        GraphDestroyInfo info;
+        info.context = context ? context : get_current_context();
+        info.graph_exec = graph_exec;
+        g_deferred_graph_list.push_back(info);
+        return true;
+    }
 }

 bool write_file(const char* data, size_t size, std::string filename, const char* mode)
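Both destroy entry points above now defer the actual CUDA call whenever a graph capture is in flight and rely on destroy_deferred_graphs() to flush the queue later. The same pattern, reduced to its essentials (the names and types below are simplified stand-ins, not Warp's real internals):

    #include <vector>
    #include <cuda_runtime.h>

    static std::vector<cudaGraph_t> g_pending_graphs;  // graphs queued while a capture is active
    static int g_active_captures = 0;                  // maintained around capture begin/end

    // Destroy immediately when safe, otherwise queue the graph for a later flush.
    void destroy_graph_deferred(cudaGraph_t graph)
    {
        if (g_active_captures == 0)
            cudaGraphDestroy(graph);
        else
            g_pending_graphs.push_back(graph);
    }

    // Called once no capture is in progress, e.g. after capture end or a synchronize.
    void flush_pending_graphs()
    {
        if (g_active_captures != 0)
            return;
        for (cudaGraph_t g : g_pending_graphs)
            cudaGraphDestroy(g);
        g_pending_graphs.clear();
    }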
@@ -4287,17 +4411,5 @@ void wp_cuda_timing_end(timing_result_t* results, int size)
     g_cuda_timing_state = parent_state;
 }

-// impl. files
-#include "bvh.cu"
-#include "mesh.cu"
-#include "sort.cu"
-#include "hashgrid.cu"
-#include "reduce.cu"
-#include "runlength_encode.cu"
-#include "scan.cu"
-#include "sparse.cu"
-#include "volume.cu"
-#include "volume_builder.cu"
-
 //#include "spline.inl"
 //#include "volume.inl"