warp-lang 1.9.1__py3-none-manylinux_2_34_aarch64.whl → 1.10.0rc2__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +794 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1382 -377
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +253 -171
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +527 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/native/warp.cu
CHANGED
|
@@ -222,6 +222,14 @@ struct ModuleInfo
|
|
|
222
222
|
void* module = NULL;
|
|
223
223
|
};
|
|
224
224
|
|
|
225
|
+
// Information used when deferring graph destruction.
|
|
226
|
+
struct GraphDestroyInfo
|
|
227
|
+
{
|
|
228
|
+
void* context = NULL;
|
|
229
|
+
void* graph = NULL;
|
|
230
|
+
void* graph_exec = NULL;
|
|
231
|
+
};
|
|
232
|
+
|
|
225
233
|
static std::unordered_map<CUfunction, std::string> g_kernel_names;
|
|
226
234
|
|
|
227
235
|
// cached info for all devices, indexed by ordinal
|
|
@@ -253,6 +261,11 @@ static std::vector<FreeInfo> g_deferred_free_list;
|
|
|
253
261
|
// Call unload_deferred_modules() to release.
|
|
254
262
|
static std::vector<ModuleInfo> g_deferred_module_list;
|
|
255
263
|
|
|
264
|
+
// Graphs that cannot be destroyed immediately get queued here.
|
|
265
|
+
// Call destroy_deferred_graphs() to release.
|
|
266
|
+
static std::vector<GraphDestroyInfo> g_deferred_graph_list;
|
|
267
|
+
|
|
268
|
+
|
|
256
269
|
void wp_cuda_set_context_restore_policy(bool always_restore)
|
|
257
270
|
{
|
|
258
271
|
ContextGuard::always_restore = always_restore;
|
|
@@ -338,7 +351,7 @@ int cuda_init()
|
|
|
338
351
|
}
|
|
339
352
|
|
|
340
353
|
|
|
341
|
-
|
|
354
|
+
CUcontext get_current_context()
|
|
342
355
|
{
|
|
343
356
|
CUcontext ctx;
|
|
344
357
|
if (check_cu(cuCtxGetCurrent_f(&ctx)))
|
|
@@ -495,6 +508,38 @@ static int unload_deferred_modules(void* context = NULL)
|
|
|
495
508
|
return num_unloaded_modules;
|
|
496
509
|
}
|
|
497
510
|
|
|
511
|
+
static int destroy_deferred_graphs(void* context = NULL)
|
|
512
|
+
{
|
|
513
|
+
if (g_deferred_graph_list.empty() || !g_captures.empty())
|
|
514
|
+
return 0;
|
|
515
|
+
|
|
516
|
+
int num_destroyed_graphs = 0;
|
|
517
|
+
for (auto it = g_deferred_graph_list.begin(); it != g_deferred_graph_list.end(); /*noop*/)
|
|
518
|
+
{
|
|
519
|
+
// destroy the graph if it matches the given context or if the context is unspecified
|
|
520
|
+
const GraphDestroyInfo& graph_info = *it;
|
|
521
|
+
if (graph_info.context == context || !context)
|
|
522
|
+
{
|
|
523
|
+
if (graph_info.graph)
|
|
524
|
+
{
|
|
525
|
+
check_cuda(cudaGraphDestroy((cudaGraph_t)graph_info.graph));
|
|
526
|
+
}
|
|
527
|
+
if (graph_info.graph_exec)
|
|
528
|
+
{
|
|
529
|
+
check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_info.graph_exec));
|
|
530
|
+
}
|
|
531
|
+
++num_destroyed_graphs;
|
|
532
|
+
it = g_deferred_graph_list.erase(it);
|
|
533
|
+
}
|
|
534
|
+
else
|
|
535
|
+
{
|
|
536
|
+
++it;
|
|
537
|
+
}
|
|
538
|
+
}
|
|
539
|
+
|
|
540
|
+
return num_destroyed_graphs;
|
|
541
|
+
}
|
|
542
|
+
|
|
498
543
|
static void CUDART_CB on_graph_destroy(void* user_data)
|
|
499
544
|
{
|
|
500
545
|
if (!user_data)
|
|
@@ -989,15 +1034,15 @@ void wp_memtile_device(void* context, void* dst, const void* src, size_t srcsize
|
|
|
989
1034
|
|
|
990
1035
|
|
|
991
1036
|
static __global__ void array_copy_1d_kernel(void* dst, const void* src,
|
|
992
|
-
|
|
1037
|
+
size_t dst_stride, size_t src_stride,
|
|
993
1038
|
const int* dst_indices, const int* src_indices,
|
|
994
|
-
|
|
1039
|
+
size_t n, size_t elem_size)
|
|
995
1040
|
{
|
|
996
|
-
|
|
1041
|
+
size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
997
1042
|
if (i < n)
|
|
998
1043
|
{
|
|
999
|
-
|
|
1000
|
-
|
|
1044
|
+
size_t src_idx = src_indices ? src_indices[i] : i;
|
|
1045
|
+
size_t dst_idx = dst_indices ? dst_indices[i] : i;
|
|
1001
1046
|
const char* p = (const char*)src + src_idx * src_stride;
|
|
1002
1047
|
char* q = (char*)dst + dst_idx * dst_stride;
|
|
1003
1048
|
memcpy(q, p, elem_size);
|
|
@@ -1005,20 +1050,20 @@ static __global__ void array_copy_1d_kernel(void* dst, const void* src,
|
|
|
1005
1050
|
}
|
|
1006
1051
|
|
|
1007
1052
|
static __global__ void array_copy_2d_kernel(void* dst, const void* src,
|
|
1008
|
-
wp::vec_t<2,
|
|
1053
|
+
wp::vec_t<2, size_t> dst_strides, wp::vec_t<2, size_t> src_strides,
|
|
1009
1054
|
wp::vec_t<2, const int*> dst_indices, wp::vec_t<2, const int*> src_indices,
|
|
1010
|
-
wp::vec_t<2,
|
|
1055
|
+
wp::vec_t<2, size_t> shape, size_t elem_size)
|
|
1011
1056
|
{
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1057
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1058
|
+
size_t n = shape[1];
|
|
1059
|
+
size_t i = tid / n;
|
|
1060
|
+
size_t j = tid % n;
|
|
1016
1061
|
if (i < shape[0] /*&& j < shape[1]*/)
|
|
1017
1062
|
{
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1063
|
+
size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
|
|
1064
|
+
size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
|
|
1065
|
+
size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
|
|
1066
|
+
size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
|
|
1022
1067
|
const char* p = (const char*)src + src_idx0 * src_strides[0] + src_idx1 * src_strides[1];
|
|
1023
1068
|
char* q = (char*)dst + dst_idx0 * dst_strides[0] + dst_idx1 * dst_strides[1];
|
|
1024
1069
|
memcpy(q, p, elem_size);
|
|
@@ -1026,24 +1071,24 @@ static __global__ void array_copy_2d_kernel(void* dst, const void* src,
|
|
|
1026
1071
|
}
|
|
1027
1072
|
|
|
1028
1073
|
static __global__ void array_copy_3d_kernel(void* dst, const void* src,
|
|
1029
|
-
wp::vec_t<3,
|
|
1074
|
+
wp::vec_t<3, size_t> dst_strides, wp::vec_t<3, size_t> src_strides,
|
|
1030
1075
|
wp::vec_t<3, const int*> dst_indices, wp::vec_t<3, const int*> src_indices,
|
|
1031
|
-
wp::vec_t<3,
|
|
1032
|
-
{
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1076
|
+
wp::vec_t<3, size_t> shape, size_t elem_size)
|
|
1077
|
+
{
|
|
1078
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1079
|
+
size_t n = shape[1];
|
|
1080
|
+
size_t o = shape[2];
|
|
1081
|
+
size_t i = tid / (n * o);
|
|
1082
|
+
size_t j = tid % (n * o) / o;
|
|
1083
|
+
size_t k = tid % o;
|
|
1039
1084
|
if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
|
|
1040
1085
|
{
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1086
|
+
size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
|
|
1087
|
+
size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
|
|
1088
|
+
size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
|
|
1089
|
+
size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
|
|
1090
|
+
size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
|
|
1091
|
+
size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
|
|
1047
1092
|
const char* p = (const char*)src + src_idx0 * src_strides[0]
|
|
1048
1093
|
+ src_idx1 * src_strides[1]
|
|
1049
1094
|
+ src_idx2 * src_strides[2];
|
|
@@ -1055,28 +1100,28 @@ static __global__ void array_copy_3d_kernel(void* dst, const void* src,
|
|
|
1055
1100
|
}
|
|
1056
1101
|
|
|
1057
1102
|
static __global__ void array_copy_4d_kernel(void* dst, const void* src,
|
|
1058
|
-
wp::vec_t<4,
|
|
1103
|
+
wp::vec_t<4, size_t> dst_strides, wp::vec_t<4, size_t> src_strides,
|
|
1059
1104
|
wp::vec_t<4, const int*> dst_indices, wp::vec_t<4, const int*> src_indices,
|
|
1060
|
-
wp::vec_t<4,
|
|
1061
|
-
{
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1105
|
+
wp::vec_t<4, size_t> shape, size_t elem_size)
|
|
1106
|
+
{
|
|
1107
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1108
|
+
size_t n = shape[1];
|
|
1109
|
+
size_t o = shape[2];
|
|
1110
|
+
size_t p = shape[3];
|
|
1111
|
+
size_t i = tid / (n * o * p);
|
|
1112
|
+
size_t j = tid % (n * o * p) / (o * p);
|
|
1113
|
+
size_t k = tid % (o * p) / p;
|
|
1114
|
+
size_t l = tid % p;
|
|
1070
1115
|
if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
|
|
1071
1116
|
{
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1117
|
+
size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
|
|
1118
|
+
size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
|
|
1119
|
+
size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
|
|
1120
|
+
size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
|
|
1121
|
+
size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
|
|
1122
|
+
size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
|
|
1123
|
+
size_t src_idx3 = src_indices[3] ? src_indices[3][l] : l;
|
|
1124
|
+
size_t dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
|
|
1080
1125
|
const char* p = (const char*)src + src_idx0 * src_strides[0]
|
|
1081
1126
|
+ src_idx1 * src_strides[1]
|
|
1082
1127
|
+ src_idx2 * src_strides[2]
|
|
@@ -1091,14 +1136,14 @@ static __global__ void array_copy_4d_kernel(void* dst, const void* src,
|
|
|
1091
1136
|
|
|
1092
1137
|
|
|
1093
1138
|
static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
|
|
1094
|
-
void* dst_data,
|
|
1095
|
-
|
|
1139
|
+
void* dst_data, size_t dst_stride, const int* dst_indices,
|
|
1140
|
+
size_t elem_size)
|
|
1096
1141
|
{
|
|
1097
|
-
|
|
1142
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1098
1143
|
|
|
1099
1144
|
if (tid < src.size)
|
|
1100
1145
|
{
|
|
1101
|
-
|
|
1146
|
+
size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
|
|
1102
1147
|
void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
|
|
1103
1148
|
const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
|
|
1104
1149
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1106,15 +1151,15 @@ static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src
|
|
|
1106
1151
|
}
|
|
1107
1152
|
|
|
1108
1153
|
static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
|
|
1109
|
-
void* dst_data,
|
|
1110
|
-
|
|
1154
|
+
void* dst_data, size_t dst_stride, const int* dst_indices,
|
|
1155
|
+
size_t elem_size)
|
|
1111
1156
|
{
|
|
1112
|
-
|
|
1157
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1113
1158
|
|
|
1114
1159
|
if (tid < src.size)
|
|
1115
1160
|
{
|
|
1116
|
-
|
|
1117
|
-
|
|
1161
|
+
size_t src_index = src.indices[tid];
|
|
1162
|
+
size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
|
|
1118
1163
|
void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
|
|
1119
1164
|
const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
|
|
1120
1165
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1122,14 +1167,14 @@ static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricar
|
|
|
1122
1167
|
}
|
|
1123
1168
|
|
|
1124
1169
|
static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
|
|
1125
|
-
const void* src_data,
|
|
1126
|
-
|
|
1170
|
+
const void* src_data, size_t src_stride, const int* src_indices,
|
|
1171
|
+
size_t elem_size)
|
|
1127
1172
|
{
|
|
1128
|
-
|
|
1173
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1129
1174
|
|
|
1130
1175
|
if (tid < dst.size)
|
|
1131
1176
|
{
|
|
1132
|
-
|
|
1177
|
+
size_t src_idx = src_indices ? src_indices[tid] : tid;
|
|
1133
1178
|
const void* src_ptr = (const char*)src_data + src_idx * src_stride;
|
|
1134
1179
|
void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
|
|
1135
1180
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1137,25 +1182,25 @@ static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
|
|
|
1137
1182
|
}
|
|
1138
1183
|
|
|
1139
1184
|
static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
|
|
1140
|
-
const void* src_data,
|
|
1141
|
-
|
|
1185
|
+
const void* src_data, size_t src_stride, const int* src_indices,
|
|
1186
|
+
size_t elem_size)
|
|
1142
1187
|
{
|
|
1143
|
-
|
|
1188
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1144
1189
|
|
|
1145
1190
|
if (tid < dst.size)
|
|
1146
1191
|
{
|
|
1147
|
-
|
|
1192
|
+
size_t src_idx = src_indices ? src_indices[tid] : tid;
|
|
1148
1193
|
const void* src_ptr = (const char*)src_data + src_idx * src_stride;
|
|
1149
|
-
|
|
1194
|
+
size_t dst_idx = dst.indices[tid];
|
|
1150
1195
|
void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
|
|
1151
1196
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
1152
1197
|
}
|
|
1153
1198
|
}
|
|
1154
1199
|
|
|
1155
1200
|
|
|
1156
|
-
static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src,
|
|
1201
|
+
static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
|
|
1157
1202
|
{
|
|
1158
|
-
|
|
1203
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1159
1204
|
|
|
1160
1205
|
if (tid < dst.size)
|
|
1161
1206
|
{
|
|
@@ -1166,27 +1211,27 @@ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void
|
|
|
1166
1211
|
}
|
|
1167
1212
|
|
|
1168
1213
|
|
|
1169
|
-
static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src,
|
|
1214
|
+
static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
|
|
1170
1215
|
{
|
|
1171
|
-
|
|
1216
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1172
1217
|
|
|
1173
1218
|
if (tid < dst.size)
|
|
1174
1219
|
{
|
|
1175
1220
|
const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
|
|
1176
|
-
|
|
1221
|
+
size_t dst_index = dst.indices[tid];
|
|
1177
1222
|
void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
|
|
1178
1223
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
1179
1224
|
}
|
|
1180
1225
|
}
|
|
1181
1226
|
|
|
1182
1227
|
|
|
1183
|
-
static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src,
|
|
1228
|
+
static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
|
|
1184
1229
|
{
|
|
1185
|
-
|
|
1230
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1186
1231
|
|
|
1187
1232
|
if (tid < dst.size)
|
|
1188
1233
|
{
|
|
1189
|
-
|
|
1234
|
+
size_t src_index = src.indices[tid];
|
|
1190
1235
|
const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
|
|
1191
1236
|
void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
|
|
1192
1237
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1194,14 +1239,14 @@ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarra
|
|
|
1194
1239
|
}
|
|
1195
1240
|
|
|
1196
1241
|
|
|
1197
|
-
static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src,
|
|
1242
|
+
static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
|
|
1198
1243
|
{
|
|
1199
|
-
|
|
1244
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1200
1245
|
|
|
1201
1246
|
if (tid < dst.size)
|
|
1202
1247
|
{
|
|
1203
|
-
|
|
1204
|
-
|
|
1248
|
+
size_t src_index = src.indices[tid];
|
|
1249
|
+
size_t dst_index = dst.indices[tid];
|
|
1205
1250
|
const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
|
|
1206
1251
|
void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
|
|
1207
1252
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1440,9 +1485,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
|
|
|
1440
1485
|
}
|
|
1441
1486
|
case 2:
|
|
1442
1487
|
{
|
|
1443
|
-
wp::vec_t<2,
|
|
1444
|
-
wp::vec_t<2,
|
|
1445
|
-
wp::vec_t<2,
|
|
1488
|
+
wp::vec_t<2, size_t> shape_v(src_shape[0], src_shape[1]);
|
|
1489
|
+
wp::vec_t<2, size_t> src_strides_v(src_strides[0], src_strides[1]);
|
|
1490
|
+
wp::vec_t<2, size_t> dst_strides_v(dst_strides[0], dst_strides[1]);
|
|
1446
1491
|
wp::vec_t<2, const int*> src_indices_v(src_indices[0], src_indices[1]);
|
|
1447
1492
|
wp::vec_t<2, const int*> dst_indices_v(dst_indices[0], dst_indices[1]);
|
|
1448
1493
|
|
|
@@ -1454,9 +1499,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
|
|
|
1454
1499
|
}
|
|
1455
1500
|
case 3:
|
|
1456
1501
|
{
|
|
1457
|
-
wp::vec_t<3,
|
|
1458
|
-
wp::vec_t<3,
|
|
1459
|
-
wp::vec_t<3,
|
|
1502
|
+
wp::vec_t<3, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2]);
|
|
1503
|
+
wp::vec_t<3, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
|
|
1504
|
+
wp::vec_t<3, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
|
|
1460
1505
|
wp::vec_t<3, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2]);
|
|
1461
1506
|
wp::vec_t<3, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2]);
|
|
1462
1507
|
|
|
@@ -1468,9 +1513,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
|
|
|
1468
1513
|
}
|
|
1469
1514
|
case 4:
|
|
1470
1515
|
{
|
|
1471
|
-
wp::vec_t<4,
|
|
1472
|
-
wp::vec_t<4,
|
|
1473
|
-
wp::vec_t<4,
|
|
1516
|
+
wp::vec_t<4, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
|
|
1517
|
+
wp::vec_t<4, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
|
|
1518
|
+
wp::vec_t<4, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
|
|
1474
1519
|
wp::vec_t<4, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2], src_indices[3]);
|
|
1475
1520
|
wp::vec_t<4, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2], dst_indices[3]);
|
|
1476
1521
|
|
|
@@ -1490,94 +1535,94 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
|
|
|
1490
1535
|
|
|
1491
1536
|
|
|
1492
1537
|
static __global__ void array_fill_1d_kernel(void* data,
|
|
1493
|
-
|
|
1494
|
-
|
|
1538
|
+
size_t n,
|
|
1539
|
+
size_t stride,
|
|
1495
1540
|
const int* indices,
|
|
1496
1541
|
const void* value,
|
|
1497
|
-
|
|
1542
|
+
size_t value_size)
|
|
1498
1543
|
{
|
|
1499
|
-
|
|
1544
|
+
size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1500
1545
|
if (i < n)
|
|
1501
1546
|
{
|
|
1502
|
-
|
|
1547
|
+
size_t idx = indices ? indices[i] : i;
|
|
1503
1548
|
char* p = (char*)data + idx * stride;
|
|
1504
1549
|
memcpy(p, value, value_size);
|
|
1505
1550
|
}
|
|
1506
1551
|
}
|
|
1507
1552
|
|
|
1508
1553
|
static __global__ void array_fill_2d_kernel(void* data,
|
|
1509
|
-
wp::vec_t<2,
|
|
1510
|
-
wp::vec_t<2,
|
|
1554
|
+
wp::vec_t<2, size_t> shape,
|
|
1555
|
+
wp::vec_t<2, size_t> strides,
|
|
1511
1556
|
wp::vec_t<2, const int*> indices,
|
|
1512
1557
|
const void* value,
|
|
1513
|
-
|
|
1558
|
+
size_t value_size)
|
|
1514
1559
|
{
|
|
1515
|
-
|
|
1516
|
-
|
|
1517
|
-
|
|
1518
|
-
|
|
1560
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1561
|
+
size_t n = shape[1];
|
|
1562
|
+
size_t i = tid / n;
|
|
1563
|
+
size_t j = tid % n;
|
|
1519
1564
|
if (i < shape[0] /*&& j < shape[1]*/)
|
|
1520
1565
|
{
|
|
1521
|
-
|
|
1522
|
-
|
|
1566
|
+
size_t idx0 = indices[0] ? indices[0][i] : i;
|
|
1567
|
+
size_t idx1 = indices[1] ? indices[1][j] : j;
|
|
1523
1568
|
char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1];
|
|
1524
1569
|
memcpy(p, value, value_size);
|
|
1525
1570
|
}
|
|
1526
1571
|
}
|
|
1527
1572
|
|
|
1528
1573
|
static __global__ void array_fill_3d_kernel(void* data,
|
|
1529
|
-
wp::vec_t<3,
|
|
1530
|
-
wp::vec_t<3,
|
|
1574
|
+
wp::vec_t<3, size_t> shape,
|
|
1575
|
+
wp::vec_t<3, size_t> strides,
|
|
1531
1576
|
wp::vec_t<3, const int*> indices,
|
|
1532
1577
|
const void* value,
|
|
1533
|
-
|
|
1534
|
-
{
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
|
|
1539
|
-
|
|
1540
|
-
|
|
1578
|
+
size_t value_size)
|
|
1579
|
+
{
|
|
1580
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1581
|
+
size_t n = shape[1];
|
|
1582
|
+
size_t o = shape[2];
|
|
1583
|
+
size_t i = tid / (n * o);
|
|
1584
|
+
size_t j = tid % (n * o) / o;
|
|
1585
|
+
size_t k = tid % o;
|
|
1541
1586
|
if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
|
|
1542
1587
|
{
|
|
1543
|
-
|
|
1544
|
-
|
|
1545
|
-
|
|
1588
|
+
size_t idx0 = indices[0] ? indices[0][i] : i;
|
|
1589
|
+
size_t idx1 = indices[1] ? indices[1][j] : j;
|
|
1590
|
+
size_t idx2 = indices[2] ? indices[2][k] : k;
|
|
1546
1591
|
char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2];
|
|
1547
1592
|
memcpy(p, value, value_size);
|
|
1548
1593
|
}
|
|
1549
1594
|
}
|
|
1550
1595
|
|
|
1551
1596
|
static __global__ void array_fill_4d_kernel(void* data,
|
|
1552
|
-
wp::vec_t<4,
|
|
1553
|
-
wp::vec_t<4,
|
|
1597
|
+
wp::vec_t<4, size_t> shape,
|
|
1598
|
+
wp::vec_t<4, size_t> strides,
|
|
1554
1599
|
wp::vec_t<4, const int*> indices,
|
|
1555
1600
|
const void* value,
|
|
1556
|
-
|
|
1557
|
-
{
|
|
1558
|
-
|
|
1559
|
-
|
|
1560
|
-
|
|
1561
|
-
|
|
1562
|
-
|
|
1563
|
-
|
|
1564
|
-
|
|
1565
|
-
|
|
1601
|
+
size_t value_size)
|
|
1602
|
+
{
|
|
1603
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1604
|
+
size_t n = shape[1];
|
|
1605
|
+
size_t o = shape[2];
|
|
1606
|
+
size_t p = shape[3];
|
|
1607
|
+
size_t i = tid / (n * o * p);
|
|
1608
|
+
size_t j = tid % (n * o * p) / (o * p);
|
|
1609
|
+
size_t k = tid % (o * p) / p;
|
|
1610
|
+
size_t l = tid % p;
|
|
1566
1611
|
if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
|
|
1567
1612
|
{
|
|
1568
|
-
|
|
1569
|
-
|
|
1570
|
-
|
|
1571
|
-
|
|
1613
|
+
size_t idx0 = indices[0] ? indices[0][i] : i;
|
|
1614
|
+
size_t idx1 = indices[1] ? indices[1][j] : j;
|
|
1615
|
+
size_t idx2 = indices[2] ? indices[2][k] : k;
|
|
1616
|
+
size_t idx3 = indices[3] ? indices[3][l] : l;
|
|
1572
1617
|
char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2] + idx3 * strides[3];
|
|
1573
1618
|
memcpy(p, value, value_size);
|
|
1574
1619
|
}
|
|
1575
1620
|
}
|
|
1576
1621
|
|
|
1577
1622
|
|
|
1578
|
-
static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value,
|
|
1623
|
+
static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, size_t value_size)
|
|
1579
1624
|
{
|
|
1580
|
-
|
|
1625
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1581
1626
|
if (tid < fa.size)
|
|
1582
1627
|
{
|
|
1583
1628
|
void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
|
|
@@ -1586,9 +1631,9 @@ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, cons
|
|
|
1586
1631
|
}
|
|
1587
1632
|
|
|
1588
1633
|
|
|
1589
|
-
static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value,
|
|
1634
|
+
static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, size_t value_size)
|
|
1590
1635
|
{
|
|
1591
|
-
|
|
1636
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1592
1637
|
if (tid < ifa.size)
|
|
1593
1638
|
{
|
|
1594
1639
|
size_t idx = size_t(ifa.indices[tid]);
|
|
@@ -1685,8 +1730,8 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
|
|
|
1685
1730
|
}
|
|
1686
1731
|
case 2:
|
|
1687
1732
|
{
|
|
1688
|
-
wp::vec_t<2,
|
|
1689
|
-
wp::vec_t<2,
|
|
1733
|
+
wp::vec_t<2, size_t> shape_v(shape[0], shape[1]);
|
|
1734
|
+
wp::vec_t<2, size_t> strides_v(strides[0], strides[1]);
|
|
1690
1735
|
wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
|
|
1691
1736
|
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
|
|
1692
1737
|
(data, shape_v, strides_v, indices_v, value_devptr, value_size));
|
|
@@ -1694,8 +1739,8 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
|
|
|
1694
1739
|
}
|
|
1695
1740
|
case 3:
|
|
1696
1741
|
{
|
|
1697
|
-
wp::vec_t<3,
|
|
1698
|
-
wp::vec_t<3,
|
|
1742
|
+
wp::vec_t<3, size_t> shape_v(shape[0], shape[1], shape[2]);
|
|
1743
|
+
wp::vec_t<3, size_t> strides_v(strides[0], strides[1], strides[2]);
|
|
1699
1744
|
wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
|
|
1700
1745
|
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
|
|
1701
1746
|
(data, shape_v, strides_v, indices_v, value_devptr, value_size));
|
|
@@ -1703,8 +1748,8 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
|
|
|
1703
1748
|
}
|
|
1704
1749
|
case 4:
|
|
1705
1750
|
{
|
|
1706
|
-
wp::vec_t<4,
|
|
1707
|
-
wp::vec_t<4,
|
|
1751
|
+
wp::vec_t<4, size_t> shape_v(shape[0], shape[1], shape[2], shape[3]);
|
|
1752
|
+
wp::vec_t<4, size_t> strides_v(strides[0], strides[1], strides[2], strides[3]);
|
|
1708
1753
|
wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
|
|
1709
1754
|
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
|
|
1710
1755
|
(data, shape_v, strides_v, indices_v, value_devptr, value_size));
|
|
@@ -2072,13 +2117,17 @@ void wp_cuda_context_synchronize(void* context)
|
|
|
2072
2117
|
|
|
2073
2118
|
check_cu(cuCtxSynchronize_f());
|
|
2074
2119
|
|
|
2075
|
-
if (
|
|
2120
|
+
if (!context)
|
|
2121
|
+
context = get_current_context();
|
|
2122
|
+
|
|
2123
|
+
if (free_deferred_allocs(context) > 0)
|
|
2076
2124
|
{
|
|
2077
2125
|
// ensure deferred asynchronous deallocations complete
|
|
2078
2126
|
check_cu(cuCtxSynchronize_f());
|
|
2079
2127
|
}
|
|
2080
2128
|
|
|
2081
2129
|
unload_deferred_modules(context);
|
|
2130
|
+
destroy_deferred_graphs(context);
|
|
2082
2131
|
|
|
2083
2132
|
// check_cuda(cudaDeviceGraphMemTrim(wp_cuda_context_get_device_ordinal(context)));
|
|
2084
2133
|
}
|
|
@@ -2514,15 +2563,36 @@ void wp_cuda_stream_synchronize(void* stream)
|
|
|
2514
2563
|
check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
|
|
2515
2564
|
}
|
|
2516
2565
|
|
|
2517
|
-
void wp_cuda_stream_wait_event(void* stream, void* event)
|
|
2566
|
+
void wp_cuda_stream_wait_event(void* stream, void* event, bool external)
|
|
2518
2567
|
{
|
|
2519
|
-
|
|
2568
|
+
// the external flag can only be used during graph capture
|
|
2569
|
+
if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
|
|
2570
|
+
{
|
|
2571
|
+
// wait for an external event during graph capture
|
|
2572
|
+
check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_EXTERNAL));
|
|
2573
|
+
}
|
|
2574
|
+
else
|
|
2575
|
+
{
|
|
2576
|
+
check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_DEFAULT));
|
|
2577
|
+
}
|
|
2520
2578
|
}
|
|
2521
2579
|
|
|
2522
|
-
void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
|
|
2580
|
+
void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event, bool external)
|
|
2523
2581
|
{
|
|
2524
|
-
|
|
2525
|
-
|
|
2582
|
+
unsigned record_flags = CU_EVENT_RECORD_DEFAULT;
|
|
2583
|
+
unsigned wait_flags = CU_EVENT_WAIT_DEFAULT;
|
|
2584
|
+
|
|
2585
|
+
// the external flag can only be used during graph capture
|
|
2586
|
+
if (external && !g_captures.empty())
|
|
2587
|
+
{
|
|
2588
|
+
if (wp_cuda_stream_is_capturing(other_stream))
|
|
2589
|
+
record_flags = CU_EVENT_RECORD_EXTERNAL;
|
|
2590
|
+
if (wp_cuda_stream_is_capturing(stream))
|
|
2591
|
+
wait_flags = CU_EVENT_WAIT_EXTERNAL;
|
|
2592
|
+
}
|
|
2593
|
+
|
|
2594
|
+
check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream), record_flags));
|
|
2595
|
+
check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), wait_flags));
|
|
2526
2596
|
}
|
|
2527
2597
|
|
|
2528
2598
|
int wp_cuda_stream_is_capturing(void* stream)
|
|
@@ -2575,11 +2645,12 @@ int wp_cuda_event_query(void* event)
|
|
|
2575
2645
|
return res;
|
|
2576
2646
|
}
|
|
2577
2647
|
|
|
2578
|
-
void wp_cuda_event_record(void* event, void* stream, bool
|
|
2648
|
+
void wp_cuda_event_record(void* event, void* stream, bool external)
|
|
2579
2649
|
{
|
|
2580
|
-
|
|
2650
|
+
// the external flag can only be used during graph capture
|
|
2651
|
+
if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
|
|
2581
2652
|
{
|
|
2582
|
-
// record
|
|
2653
|
+
// record external event during graph capture (e.g., for timing or when explicitly specified by the user)
|
|
2583
2654
|
check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
|
|
2584
2655
|
}
|
|
2585
2656
|
else
|
|
@@ -2629,7 +2700,7 @@ bool wp_cuda_graph_begin_capture(void* context, void* stream, int external)
|
|
|
2629
2700
|
else
|
|
2630
2701
|
{
|
|
2631
2702
|
// start the capture
|
|
2632
|
-
if (!check_cuda(cudaStreamBeginCapture(cuda_stream,
|
|
2703
|
+
if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeThreadLocal)))
|
|
2633
2704
|
return false;
|
|
2634
2705
|
}
|
|
2635
2706
|
|
|
@@ -2776,6 +2847,7 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
|
|
|
2776
2847
|
{
|
|
2777
2848
|
free_deferred_allocs();
|
|
2778
2849
|
unload_deferred_modules();
|
|
2850
|
+
destroy_deferred_graphs();
|
|
2779
2851
|
}
|
|
2780
2852
|
|
|
2781
2853
|
if (graph_ret)
|
|
@@ -2996,7 +3068,7 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
|
|
|
2996
3068
|
leaf_nodes.data(),
|
|
2997
3069
|
nullptr,
|
|
2998
3070
|
leaf_nodes.size(),
|
|
2999
|
-
|
|
3071
|
+
cudaStreamCaptureModeThreadLocal)))
|
|
3000
3072
|
return false;
|
|
3001
3073
|
|
|
3002
3074
|
return true;
|
|
@@ -3455,16 +3527,38 @@ bool wp_cuda_graph_launch(void* graph_exec, void* stream)
|
|
|
3455
3527
|
|
|
3456
3528
|
bool wp_cuda_graph_destroy(void* context, void* graph)
|
|
3457
3529
|
{
|
|
3458
|
-
|
|
3459
|
-
|
|
3460
|
-
|
|
3530
|
+
// ensure there are no graph captures in progress
|
|
3531
|
+
if (g_captures.empty())
|
|
3532
|
+
{
|
|
3533
|
+
ContextGuard guard(context);
|
|
3534
|
+
return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
|
|
3535
|
+
}
|
|
3536
|
+
else
|
|
3537
|
+
{
|
|
3538
|
+
GraphDestroyInfo info;
|
|
3539
|
+
info.context = context ? context : get_current_context();
|
|
3540
|
+
info.graph = graph;
|
|
3541
|
+
g_deferred_graph_list.push_back(info);
|
|
3542
|
+
return true;
|
|
3543
|
+
}
|
|
3461
3544
|
}
|
|
3462
3545
|
|
|
3463
3546
|
bool wp_cuda_graph_exec_destroy(void* context, void* graph_exec)
|
|
3464
3547
|
{
|
|
3465
|
-
|
|
3466
|
-
|
|
3467
|
-
|
|
3548
|
+
// ensure there are no graph captures in progress
|
|
3549
|
+
if (g_captures.empty())
|
|
3550
|
+
{
|
|
3551
|
+
ContextGuard guard(context);
|
|
3552
|
+
return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
|
|
3553
|
+
}
|
|
3554
|
+
else
|
|
3555
|
+
{
|
|
3556
|
+
GraphDestroyInfo info;
|
|
3557
|
+
info.context = context ? context : get_current_context();
|
|
3558
|
+
info.graph_exec = graph_exec;
|
|
3559
|
+
g_deferred_graph_list.push_back(info);
|
|
3560
|
+
return true;
|
|
3561
|
+
}
|
|
3468
3562
|
}
|
|
3469
3563
|
|
|
3470
3564
|
bool write_file(const char* data, size_t size, std::string filename, const char* mode)
|
|
@@ -4317,17 +4411,5 @@ void wp_cuda_timing_end(timing_result_t* results, int size)
|
|
|
4317
4411
|
g_cuda_timing_state = parent_state;
|
|
4318
4412
|
}
|
|
4319
4413
|
|
|
4320
|
-
// impl. files
|
|
4321
|
-
#include "bvh.cu"
|
|
4322
|
-
#include "mesh.cu"
|
|
4323
|
-
#include "sort.cu"
|
|
4324
|
-
#include "hashgrid.cu"
|
|
4325
|
-
#include "reduce.cu"
|
|
4326
|
-
#include "runlength_encode.cu"
|
|
4327
|
-
#include "scan.cu"
|
|
4328
|
-
#include "sparse.cu"
|
|
4329
|
-
#include "volume.cu"
|
|
4330
|
-
#include "volume_builder.cu"
|
|
4331
|
-
|
|
4332
4414
|
//#include "spline.inl"
|
|
4333
4415
|
//#include "volume.inl"
|