warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +882 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1077 -0
- warp/_src/build.py +620 -0
- warp/_src/build_dll.py +642 -0
- warp/{builtins.py → _src/builtins.py} +1435 -379
- warp/_src/codegen.py +4361 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +59 -0
- warp/_src/context.py +8352 -0
- warp/_src/dlpack.py +464 -0
- warp/_src/fabric.py +362 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +510 -0
- warp/_src/fem/cache.py +689 -0
- warp/_src/fem/dirichlet.py +190 -0
- warp/{fem → _src/fem}/domain.py +42 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +703 -0
- warp/{fem → _src/fem}/field/nodal_field.py +32 -15
- warp/{fem → _src/fem}/field/restriction.py +3 -1
- warp/{fem → _src/fem}/field/virtual.py +55 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +79 -163
- warp/_src/fem/geometry/closest_point.py +99 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +16 -22
- warp/{fem → _src/fem}/geometry/element.py +34 -10
- warp/{fem → _src/fem}/geometry/geometry.py +50 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +14 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +14 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +42 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +256 -247
- warp/{fem → _src/fem}/geometry/partition.py +123 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +28 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +42 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +28 -45
- warp/{fem → _src/fem}/integrate.py +166 -158
- warp/_src/fem/linalg.py +385 -0
- warp/_src/fem/operator.py +398 -0
- warp/_src/fem/polynomial.py +231 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +17 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +97 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +22 -11
- warp/_src/fem/space/basis_space.py +681 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +5 -3
- warp/{fem → _src/fem}/space/function_space.py +16 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +6 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +6 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +6 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +5 -9
- warp/{fem → _src/fem}/space/partition.py +119 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +6 -10
- warp/{fem → _src/fem}/space/restriction.py +68 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +11 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +10 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +8 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +5 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +5 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +5 -9
- warp/_src/fem/space/topology.py +461 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +5 -9
- warp/_src/fem/types.py +114 -0
- warp/_src/fem/utils.py +488 -0
- warp/_src/jax.py +188 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +389 -0
- warp/_src/jax_experimental/ffi.py +1286 -0
- warp/_src/jax_experimental/xla_ffi.py +658 -0
- warp/_src/marching_cubes.py +710 -0
- warp/_src/math.py +416 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +165 -0
- warp/_src/optim/linear.py +1608 -0
- warp/_src/optim/sgd.py +114 -0
- warp/_src/paddle.py +408 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +291 -0
- warp/_src/render/render_opengl.py +3638 -0
- warp/_src/render/render_usd.py +939 -0
- warp/_src/render/utils.py +162 -0
- warp/_src/sparse.py +2718 -0
- warp/_src/tape.py +1208 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +393 -0
- warp/_src/types.py +5888 -0
- warp/_src/utils.py +1695 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +3 -3
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +521 -250
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +18 -17
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +578 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/METADATA +46 -99
- warp_lang-1.10.0.dist-info/RECORD +468 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cu
CHANGED
|
@@ -38,6 +38,7 @@
|
|
|
38
38
|
#include <iterator>
|
|
39
39
|
#include <list>
|
|
40
40
|
#include <map>
|
|
41
|
+
#include <mutex>
|
|
41
42
|
#include <string>
|
|
42
43
|
#include <unordered_map>
|
|
43
44
|
#include <unordered_set>
|
|
@@ -176,11 +177,20 @@ struct ContextInfo
|
|
|
176
177
|
CUmodule conditional_module = NULL;
|
|
177
178
|
};
|
|
178
179
|
|
|
180
|
+
// Information used for freeing allocations.
|
|
181
|
+
struct FreeInfo
|
|
182
|
+
{
|
|
183
|
+
void* context = NULL;
|
|
184
|
+
void* ptr = NULL;
|
|
185
|
+
bool is_async = false;
|
|
186
|
+
};
|
|
187
|
+
|
|
179
188
|
struct CaptureInfo
|
|
180
189
|
{
|
|
181
190
|
CUstream stream = NULL; // the main stream where capture begins and ends
|
|
182
191
|
uint64_t id = 0; // unique capture id from CUDA
|
|
183
192
|
bool external = false; // whether this is an external capture
|
|
193
|
+
std::vector<FreeInfo> tmp_allocs; // temporary allocations owned by the graph (e.g., staged array fill values)
|
|
184
194
|
};
|
|
185
195
|
|
|
186
196
|
struct StreamInfo
|
|
@@ -189,9 +199,13 @@ struct StreamInfo
|
|
|
189
199
|
CaptureInfo* capture = NULL; // capture info (only if started on this stream)
|
|
190
200
|
};
|
|
191
201
|
|
|
192
|
-
|
|
202
|
+
// Extra resources tied to a graph, freed after the graph is released by CUDA.
|
|
203
|
+
// Used with the on_graph_destroy() callback.
|
|
204
|
+
struct GraphDestroyCallbackInfo
|
|
193
205
|
{
|
|
194
|
-
|
|
206
|
+
void* context = NULL; // graph CUDA context
|
|
207
|
+
std::vector<void*> unfreed_allocs; // graph allocations not freed by the graph
|
|
208
|
+
std::vector<FreeInfo> tmp_allocs; // temporary allocations owned by the graph (e.g., staged array fill values)
|
|
195
209
|
};
|
|
196
210
|
|
|
197
211
|
// Information for graph allocations that are not freed by the graph.
|
|
@@ -207,19 +221,19 @@ struct GraphAllocInfo
|
|
|
207
221
|
bool graph_destroyed = false; // whether graph instance was destroyed
|
|
208
222
|
};
|
|
209
223
|
|
|
210
|
-
// Information used when deferring
|
|
211
|
-
struct
|
|
224
|
+
// Information used when deferring module unloading.
|
|
225
|
+
struct ModuleInfo
|
|
212
226
|
{
|
|
213
227
|
void* context = NULL;
|
|
214
|
-
void*
|
|
215
|
-
bool is_async = false;
|
|
228
|
+
void* module = NULL;
|
|
216
229
|
};
|
|
217
230
|
|
|
218
|
-
// Information used when deferring
|
|
219
|
-
struct
|
|
231
|
+
// Information used when deferring graph destruction.
|
|
232
|
+
struct GraphDestroyInfo
|
|
220
233
|
{
|
|
221
234
|
void* context = NULL;
|
|
222
|
-
void*
|
|
235
|
+
void* graph = NULL;
|
|
236
|
+
void* graph_exec = NULL;
|
|
223
237
|
};
|
|
224
238
|
|
|
225
239
|
static std::unordered_map<CUfunction, std::string> g_kernel_names;
|
|
@@ -253,6 +267,15 @@ static std::vector<FreeInfo> g_deferred_free_list;
|
|
|
253
267
|
// Call unload_deferred_modules() to release.
|
|
254
268
|
static std::vector<ModuleInfo> g_deferred_module_list;
|
|
255
269
|
|
|
270
|
+
// Graphs that cannot be destroyed immediately get queued here.
|
|
271
|
+
// Call destroy_deferred_graphs() to release.
|
|
272
|
+
static std::vector<GraphDestroyInfo> g_deferred_graph_list;
|
|
273
|
+
|
|
274
|
+
// Data from on_graph_destroy() callbacks that run on a different thread.
|
|
275
|
+
static std::vector<GraphDestroyCallbackInfo*> g_deferred_graph_destroy_list;
|
|
276
|
+
static std::mutex g_graph_destroy_mutex;
|
|
277
|
+
|
|
278
|
+
|
|
256
279
|
void wp_cuda_set_context_restore_policy(bool always_restore)
|
|
257
280
|
{
|
|
258
281
|
ContextGuard::always_restore = always_restore;
|
|
@@ -338,7 +361,7 @@ int cuda_init()
|
|
|
338
361
|
}
|
|
339
362
|
|
|
340
363
|
|
|
341
|
-
|
|
364
|
+
CUcontext get_current_context()
|
|
342
365
|
{
|
|
343
366
|
CUcontext ctx;
|
|
344
367
|
if (check_cu(cuCtxGetCurrent_f(&ctx)))
|
|
@@ -408,6 +431,114 @@ static inline StreamInfo* get_stream_info(CUstream stream)
|
|
|
408
431
|
return NULL;
|
|
409
432
|
}
|
|
410
433
|
|
|
434
|
+
static inline CaptureInfo* get_capture_info(CUstream stream)
|
|
435
|
+
{
|
|
436
|
+
if (!g_captures.empty() && wp_cuda_stream_is_capturing(stream))
|
|
437
|
+
{
|
|
438
|
+
uint64_t capture_id = get_capture_id(stream);
|
|
439
|
+
auto capture_iter = g_captures.find(capture_id);
|
|
440
|
+
if (capture_iter != g_captures.end())
|
|
441
|
+
return capture_iter->second;
|
|
442
|
+
}
|
|
443
|
+
return NULL;
|
|
444
|
+
}
|
|
445
|
+
|
|
446
|
+
// helper function to copy a value to device memory in a graph-friendly way
|
|
447
|
+
static bool capturable_tmp_alloc(void* context, const void* data, size_t size, void** devptr_ret, bool* free_devptr_ret)
|
|
448
|
+
{
|
|
449
|
+
ContextGuard guard(context);
|
|
450
|
+
|
|
451
|
+
CUstream stream = get_current_stream();
|
|
452
|
+
CaptureInfo* capture_info = get_capture_info(stream);
|
|
453
|
+
int device_ordinal = wp_cuda_context_get_device_ordinal(context);
|
|
454
|
+
void* devptr = NULL;
|
|
455
|
+
bool free_devptr = true;
|
|
456
|
+
|
|
457
|
+
if (capture_info)
|
|
458
|
+
{
|
|
459
|
+
// ongoing graph capture - need to stage the fill value so that it persists with the graph
|
|
460
|
+
if (CUDA_VERSION >= 12040 && wp_cuda_driver_version() >= 12040)
|
|
461
|
+
{
|
|
462
|
+
// pause the capture so that the alloc/memcpy won't be captured
|
|
463
|
+
void* graph = NULL;
|
|
464
|
+
if (!wp_cuda_graph_pause_capture(WP_CURRENT_CONTEXT, stream, &graph))
|
|
465
|
+
return false;
|
|
466
|
+
|
|
467
|
+
// copy value to device memory
|
|
468
|
+
devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
|
|
469
|
+
if (!devptr)
|
|
470
|
+
{
|
|
471
|
+
fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
|
|
472
|
+
return false;
|
|
473
|
+
}
|
|
474
|
+
if (!check_cuda(cudaMemcpyAsync(devptr, data, size, cudaMemcpyHostToDevice, stream)))
|
|
475
|
+
return false;
|
|
476
|
+
|
|
477
|
+
// graph takes ownership of the value storage
|
|
478
|
+
FreeInfo free_info;
|
|
479
|
+
free_info.context = context ? context : get_current_context();
|
|
480
|
+
free_info.ptr = devptr;
|
|
481
|
+
free_info.is_async = wp_cuda_device_is_mempool_supported(device_ordinal);
|
|
482
|
+
|
|
483
|
+
// allocation will be freed when graph is destroyed
|
|
484
|
+
capture_info->tmp_allocs.push_back(free_info);
|
|
485
|
+
|
|
486
|
+
// resume the capture
|
|
487
|
+
if (!wp_cuda_graph_resume_capture(WP_CURRENT_CONTEXT, stream, graph))
|
|
488
|
+
return false;
|
|
489
|
+
|
|
490
|
+
free_devptr = false; // memory is owned by the graph, doesn't need to be freed
|
|
491
|
+
}
|
|
492
|
+
else
|
|
493
|
+
{
|
|
494
|
+
// older CUDA can't pause/resume the capture, so stage in CPU memory
|
|
495
|
+
void* hostptr = wp_alloc_host(size);
|
|
496
|
+
if (!hostptr)
|
|
497
|
+
{
|
|
498
|
+
fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cpu' (in function %s)\n", (unsigned long long)size, __FUNCTION__);
|
|
499
|
+
return false;
|
|
500
|
+
}
|
|
501
|
+
memcpy(hostptr, data, size);
|
|
502
|
+
|
|
503
|
+
// the device allocation and h2d copy will be captured in the graph
|
|
504
|
+
devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
|
|
505
|
+
if (!devptr)
|
|
506
|
+
{
|
|
507
|
+
fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
|
|
508
|
+
return false;
|
|
509
|
+
}
|
|
510
|
+
if (!check_cuda(cudaMemcpyAsync(devptr, hostptr, size, cudaMemcpyHostToDevice, stream)))
|
|
511
|
+
return false;
|
|
512
|
+
|
|
513
|
+
// graph takes ownership of the value storage
|
|
514
|
+
FreeInfo free_info;
|
|
515
|
+
free_info.context = NULL;
|
|
516
|
+
free_info.ptr = hostptr;
|
|
517
|
+
free_info.is_async = false;
|
|
518
|
+
|
|
519
|
+
// allocation will be freed when graph is destroyed
|
|
520
|
+
capture_info->tmp_allocs.push_back(free_info);
|
|
521
|
+
}
|
|
522
|
+
}
|
|
523
|
+
else
|
|
524
|
+
{
|
|
525
|
+
// not capturing, copy the value to device memory
|
|
526
|
+
devptr = wp_alloc_device(WP_CURRENT_CONTEXT, size);
|
|
527
|
+
if (!devptr)
|
|
528
|
+
{
|
|
529
|
+
fprintf(stderr, "Warp error: Failed to allocate %llu bytes on device 'cuda:%d' (in function %s)\n", (unsigned long long)size, device_ordinal, __FUNCTION__);
|
|
530
|
+
return false;
|
|
531
|
+
}
|
|
532
|
+
if (!check_cuda(cudaMemcpyAsync(devptr, data, size, cudaMemcpyHostToDevice, stream)))
|
|
533
|
+
return false;
|
|
534
|
+
}
|
|
535
|
+
|
|
536
|
+
*devptr_ret = devptr;
|
|
537
|
+
*free_devptr_ret = free_devptr;
|
|
538
|
+
|
|
539
|
+
return true;
|
|
540
|
+
}
|
|
541
|
+
|
|
411
542
|
static void deferred_free(void* ptr, void* context, bool is_async)
|
|
412
543
|
{
|
|
413
544
|
FreeInfo free_info;
|
|
@@ -495,34 +626,124 @@ static int unload_deferred_modules(void* context = NULL)
|
|
|
495
626
|
return num_unloaded_modules;
|
|
496
627
|
}
|
|
497
628
|
|
|
498
|
-
static
|
|
629
|
+
static int destroy_deferred_graphs(void* context = NULL)
|
|
499
630
|
{
|
|
500
|
-
if (!
|
|
501
|
-
return;
|
|
631
|
+
if (g_deferred_graph_list.empty() || !g_captures.empty())
|
|
632
|
+
return 0;
|
|
502
633
|
|
|
503
|
-
|
|
634
|
+
int num_destroyed_graphs = 0;
|
|
635
|
+
for (auto it = g_deferred_graph_list.begin(); it != g_deferred_graph_list.end(); /*noop*/)
|
|
636
|
+
{
|
|
637
|
+
// destroy the graph if it matches the given context or if the context is unspecified
|
|
638
|
+
const GraphDestroyInfo& graph_info = *it;
|
|
639
|
+
if (graph_info.context == context || !context)
|
|
640
|
+
{
|
|
641
|
+
if (graph_info.graph)
|
|
642
|
+
{
|
|
643
|
+
check_cuda(cudaGraphDestroy((cudaGraph_t)graph_info.graph));
|
|
644
|
+
}
|
|
645
|
+
if (graph_info.graph_exec)
|
|
646
|
+
{
|
|
647
|
+
check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_info.graph_exec));
|
|
648
|
+
}
|
|
649
|
+
++num_destroyed_graphs;
|
|
650
|
+
it = g_deferred_graph_list.erase(it);
|
|
651
|
+
}
|
|
652
|
+
else
|
|
653
|
+
{
|
|
654
|
+
++it;
|
|
655
|
+
}
|
|
656
|
+
}
|
|
504
657
|
|
|
505
|
-
|
|
658
|
+
return num_destroyed_graphs;
|
|
659
|
+
}
|
|
660
|
+
|
|
661
|
+
static int process_deferred_graph_destroy_callbacks(void* context = NULL)
|
|
662
|
+
{
|
|
663
|
+
int num_freed = 0;
|
|
664
|
+
|
|
665
|
+
std::lock_guard<std::mutex> lock(g_graph_destroy_mutex);
|
|
666
|
+
|
|
667
|
+
for (auto it = g_deferred_graph_destroy_list.begin(); it != g_deferred_graph_destroy_list.end(); /*noop*/)
|
|
506
668
|
{
|
|
507
|
-
|
|
508
|
-
if (
|
|
669
|
+
GraphDestroyCallbackInfo* graph_info = *it;
|
|
670
|
+
if (graph_info->context == context || !context)
|
|
509
671
|
{
|
|
510
|
-
|
|
511
|
-
|
|
672
|
+
// handle unfreed graph allocations (may have outstanding user references)
|
|
673
|
+
for (void* ptr : graph_info->unfreed_allocs)
|
|
512
674
|
{
|
|
513
|
-
|
|
514
|
-
|
|
675
|
+
auto alloc_iter = g_graph_allocs.find(ptr);
|
|
676
|
+
if (alloc_iter != g_graph_allocs.end())
|
|
677
|
+
{
|
|
678
|
+
GraphAllocInfo& alloc_info = alloc_iter->second;
|
|
679
|
+
if (alloc_info.ref_exists)
|
|
680
|
+
{
|
|
681
|
+
// unreference from graph so the pointer will be deallocated when the user reference goes away
|
|
682
|
+
alloc_info.graph_destroyed = true;
|
|
683
|
+
}
|
|
684
|
+
else
|
|
685
|
+
{
|
|
686
|
+
// the pointer can be freed, no references remain
|
|
687
|
+
wp_free_device_async(alloc_info.context, ptr);
|
|
688
|
+
g_graph_allocs.erase(alloc_iter);
|
|
689
|
+
}
|
|
690
|
+
}
|
|
515
691
|
}
|
|
516
|
-
|
|
692
|
+
|
|
693
|
+
// handle temporary allocations owned by the graph (no user references)
|
|
694
|
+
for (const FreeInfo& tmp_info : graph_info->tmp_allocs)
|
|
517
695
|
{
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
696
|
+
if (tmp_info.context)
|
|
697
|
+
{
|
|
698
|
+
// GPU alloc
|
|
699
|
+
if (tmp_info.is_async)
|
|
700
|
+
{
|
|
701
|
+
wp_free_device_async(tmp_info.context, tmp_info.ptr);
|
|
702
|
+
}
|
|
703
|
+
else
|
|
704
|
+
{
|
|
705
|
+
wp_free_device_default(tmp_info.context, tmp_info.ptr);
|
|
706
|
+
}
|
|
707
|
+
}
|
|
708
|
+
else
|
|
709
|
+
{
|
|
710
|
+
// CPU alloc
|
|
711
|
+
wp_free_host(tmp_info.ptr);
|
|
712
|
+
}
|
|
521
713
|
}
|
|
714
|
+
|
|
715
|
+
++num_freed;
|
|
716
|
+
delete graph_info;
|
|
717
|
+
it = g_deferred_graph_destroy_list.erase(it);
|
|
718
|
+
}
|
|
719
|
+
else
|
|
720
|
+
{
|
|
721
|
+
++it;
|
|
522
722
|
}
|
|
523
723
|
}
|
|
524
724
|
|
|
525
|
-
|
|
725
|
+
return num_freed;
|
|
726
|
+
}
|
|
727
|
+
|
|
728
|
+
static int run_deferred_actions(void* context = NULL)
|
|
729
|
+
{
|
|
730
|
+
int num_actions = 0;
|
|
731
|
+
num_actions += free_deferred_allocs(context);
|
|
732
|
+
num_actions += unload_deferred_modules(context);
|
|
733
|
+
num_actions += destroy_deferred_graphs(context);
|
|
734
|
+
num_actions += process_deferred_graph_destroy_callbacks(context);
|
|
735
|
+
return num_actions;
|
|
736
|
+
}
|
|
737
|
+
|
|
738
|
+
// Callback used when a graph is destroyed.
|
|
739
|
+
// NOTE: this runs on an internal CUDA thread and requires synchronization.
|
|
740
|
+
static void CUDART_CB on_graph_destroy(void* user_data)
|
|
741
|
+
{
|
|
742
|
+
if (user_data)
|
|
743
|
+
{
|
|
744
|
+
std::lock_guard<std::mutex> lock(g_graph_destroy_mutex);
|
|
745
|
+
g_deferred_graph_destroy_list.push_back(static_cast<GraphDestroyCallbackInfo*>(user_data));
|
|
746
|
+
}
|
|
526
747
|
}
|
|
527
748
|
|
|
528
749
|
static inline const char* get_cuda_kernel_name(void* kernel)
|
|
@@ -974,30 +1195,36 @@ void wp_memtile_device(void* context, void* dst, const void* src, size_t srcsize
|
|
|
974
1195
|
else
|
|
975
1196
|
{
|
|
976
1197
|
// generic version
|
|
1198
|
+
void* value_devptr = NULL; // fill value in device memory
|
|
1199
|
+
bool free_devptr = true; // whether we need to free the memory
|
|
977
1200
|
|
|
978
|
-
//
|
|
979
|
-
|
|
980
|
-
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
1201
|
+
// prepare the fill value in a graph-friendly way
|
|
1202
|
+
if (!capturable_tmp_alloc(WP_CURRENT_CONTEXT, src, srcsize, &value_devptr, &free_devptr))
|
|
1203
|
+
{
|
|
1204
|
+
fprintf(stderr, "Warp fill error: failed to copy value to device memory\n");
|
|
1205
|
+
return;
|
|
1206
|
+
}
|
|
984
1207
|
|
|
985
|
-
|
|
1208
|
+
wp_launch_device(WP_CURRENT_CONTEXT, memtile_kernel, n, (dst, value_devptr, srcsize, n));
|
|
986
1209
|
|
|
1210
|
+
if (free_devptr)
|
|
1211
|
+
{
|
|
1212
|
+
wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
|
|
1213
|
+
}
|
|
987
1214
|
}
|
|
988
1215
|
}
|
|
989
1216
|
|
|
990
1217
|
|
|
991
1218
|
static __global__ void array_copy_1d_kernel(void* dst, const void* src,
|
|
992
|
-
|
|
1219
|
+
size_t dst_stride, size_t src_stride,
|
|
993
1220
|
const int* dst_indices, const int* src_indices,
|
|
994
|
-
|
|
1221
|
+
size_t n, size_t elem_size)
|
|
995
1222
|
{
|
|
996
|
-
|
|
1223
|
+
size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
997
1224
|
if (i < n)
|
|
998
1225
|
{
|
|
999
|
-
|
|
1000
|
-
|
|
1226
|
+
size_t src_idx = src_indices ? src_indices[i] : i;
|
|
1227
|
+
size_t dst_idx = dst_indices ? dst_indices[i] : i;
|
|
1001
1228
|
const char* p = (const char*)src + src_idx * src_stride;
|
|
1002
1229
|
char* q = (char*)dst + dst_idx * dst_stride;
|
|
1003
1230
|
memcpy(q, p, elem_size);
|
|
@@ -1005,20 +1232,20 @@ static __global__ void array_copy_1d_kernel(void* dst, const void* src,
|
|
|
1005
1232
|
}
|
|
1006
1233
|
|
|
1007
1234
|
static __global__ void array_copy_2d_kernel(void* dst, const void* src,
|
|
1008
|
-
wp::vec_t<2,
|
|
1235
|
+
wp::vec_t<2, size_t> dst_strides, wp::vec_t<2, size_t> src_strides,
|
|
1009
1236
|
wp::vec_t<2, const int*> dst_indices, wp::vec_t<2, const int*> src_indices,
|
|
1010
|
-
wp::vec_t<2,
|
|
1237
|
+
wp::vec_t<2, size_t> shape, size_t elem_size)
|
|
1011
1238
|
{
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1239
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1240
|
+
size_t n = shape[1];
|
|
1241
|
+
size_t i = tid / n;
|
|
1242
|
+
size_t j = tid % n;
|
|
1016
1243
|
if (i < shape[0] /*&& j < shape[1]*/)
|
|
1017
1244
|
{
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1245
|
+
size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
|
|
1246
|
+
size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
|
|
1247
|
+
size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
|
|
1248
|
+
size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
|
|
1022
1249
|
const char* p = (const char*)src + src_idx0 * src_strides[0] + src_idx1 * src_strides[1];
|
|
1023
1250
|
char* q = (char*)dst + dst_idx0 * dst_strides[0] + dst_idx1 * dst_strides[1];
|
|
1024
1251
|
memcpy(q, p, elem_size);
|
|
@@ -1026,24 +1253,24 @@ static __global__ void array_copy_2d_kernel(void* dst, const void* src,
|
|
|
1026
1253
|
}
|
|
1027
1254
|
|
|
1028
1255
|
static __global__ void array_copy_3d_kernel(void* dst, const void* src,
|
|
1029
|
-
wp::vec_t<3,
|
|
1256
|
+
wp::vec_t<3, size_t> dst_strides, wp::vec_t<3, size_t> src_strides,
|
|
1030
1257
|
wp::vec_t<3, const int*> dst_indices, wp::vec_t<3, const int*> src_indices,
|
|
1031
|
-
wp::vec_t<3,
|
|
1032
|
-
{
|
|
1033
|
-
|
|
1034
|
-
|
|
1035
|
-
|
|
1036
|
-
|
|
1037
|
-
|
|
1038
|
-
|
|
1258
|
+
wp::vec_t<3, size_t> shape, size_t elem_size)
|
|
1259
|
+
{
|
|
1260
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1261
|
+
size_t n = shape[1];
|
|
1262
|
+
size_t o = shape[2];
|
|
1263
|
+
size_t i = tid / (n * o);
|
|
1264
|
+
size_t j = tid % (n * o) / o;
|
|
1265
|
+
size_t k = tid % o;
|
|
1039
1266
|
if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
|
|
1040
1267
|
{
|
|
1041
|
-
|
|
1042
|
-
|
|
1043
|
-
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1268
|
+
size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
|
|
1269
|
+
size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
|
|
1270
|
+
size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
|
|
1271
|
+
size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
|
|
1272
|
+
size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
|
|
1273
|
+
size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
|
|
1047
1274
|
const char* p = (const char*)src + src_idx0 * src_strides[0]
|
|
1048
1275
|
+ src_idx1 * src_strides[1]
|
|
1049
1276
|
+ src_idx2 * src_strides[2];
|
|
@@ -1055,28 +1282,28 @@ static __global__ void array_copy_3d_kernel(void* dst, const void* src,
|
|
|
1055
1282
|
}
|
|
1056
1283
|
|
|
1057
1284
|
static __global__ void array_copy_4d_kernel(void* dst, const void* src,
|
|
1058
|
-
wp::vec_t<4,
|
|
1285
|
+
wp::vec_t<4, size_t> dst_strides, wp::vec_t<4, size_t> src_strides,
|
|
1059
1286
|
wp::vec_t<4, const int*> dst_indices, wp::vec_t<4, const int*> src_indices,
|
|
1060
|
-
wp::vec_t<4,
|
|
1061
|
-
{
|
|
1062
|
-
|
|
1063
|
-
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
|
|
1067
|
-
|
|
1068
|
-
|
|
1069
|
-
|
|
1287
|
+
wp::vec_t<4, size_t> shape, size_t elem_size)
|
|
1288
|
+
{
|
|
1289
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1290
|
+
size_t n = shape[1];
|
|
1291
|
+
size_t o = shape[2];
|
|
1292
|
+
size_t p = shape[3];
|
|
1293
|
+
size_t i = tid / (n * o * p);
|
|
1294
|
+
size_t j = tid % (n * o * p) / (o * p);
|
|
1295
|
+
size_t k = tid % (o * p) / p;
|
|
1296
|
+
size_t l = tid % p;
|
|
1070
1297
|
if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
|
|
1071
1298
|
{
|
|
1072
|
-
|
|
1073
|
-
|
|
1074
|
-
|
|
1075
|
-
|
|
1076
|
-
|
|
1077
|
-
|
|
1078
|
-
|
|
1079
|
-
|
|
1299
|
+
size_t src_idx0 = src_indices[0] ? src_indices[0][i] : i;
|
|
1300
|
+
size_t dst_idx0 = dst_indices[0] ? dst_indices[0][i] : i;
|
|
1301
|
+
size_t src_idx1 = src_indices[1] ? src_indices[1][j] : j;
|
|
1302
|
+
size_t dst_idx1 = dst_indices[1] ? dst_indices[1][j] : j;
|
|
1303
|
+
size_t src_idx2 = src_indices[2] ? src_indices[2][k] : k;
|
|
1304
|
+
size_t dst_idx2 = dst_indices[2] ? dst_indices[2][k] : k;
|
|
1305
|
+
size_t src_idx3 = src_indices[3] ? src_indices[3][l] : l;
|
|
1306
|
+
size_t dst_idx3 = dst_indices[3] ? dst_indices[3][l] : l;
|
|
1080
1307
|
const char* p = (const char*)src + src_idx0 * src_strides[0]
|
|
1081
1308
|
+ src_idx1 * src_strides[1]
|
|
1082
1309
|
+ src_idx2 * src_strides[2]
|
|
@@ -1091,14 +1318,14 @@ static __global__ void array_copy_4d_kernel(void* dst, const void* src,
|
|
|
1091
1318
|
|
|
1092
1319
|
|
|
1093
1320
|
static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
|
|
1094
|
-
void* dst_data,
|
|
1095
|
-
|
|
1321
|
+
void* dst_data, size_t dst_stride, const int* dst_indices,
|
|
1322
|
+
size_t elem_size)
|
|
1096
1323
|
{
|
|
1097
|
-
|
|
1324
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1098
1325
|
|
|
1099
1326
|
if (tid < src.size)
|
|
1100
1327
|
{
|
|
1101
|
-
|
|
1328
|
+
size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
|
|
1102
1329
|
void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
|
|
1103
1330
|
const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
|
|
1104
1331
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1106,15 +1333,15 @@ static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src
|
|
|
1106
1333
|
}
|
|
1107
1334
|
|
|
1108
1335
|
static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
|
|
1109
|
-
void* dst_data,
|
|
1110
|
-
|
|
1336
|
+
void* dst_data, size_t dst_stride, const int* dst_indices,
|
|
1337
|
+
size_t elem_size)
|
|
1111
1338
|
{
|
|
1112
|
-
|
|
1339
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1113
1340
|
|
|
1114
1341
|
if (tid < src.size)
|
|
1115
1342
|
{
|
|
1116
|
-
|
|
1117
|
-
|
|
1343
|
+
size_t src_index = src.indices[tid];
|
|
1344
|
+
size_t dst_idx = dst_indices ? dst_indices[tid] : tid;
|
|
1118
1345
|
void* dst_ptr = (char*)dst_data + dst_idx * dst_stride;
|
|
1119
1346
|
const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
|
|
1120
1347
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1122,14 +1349,14 @@ static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricar
|
|
|
1122
1349
|
}
|
|
1123
1350
|
|
|
1124
1351
|
static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
|
|
1125
|
-
const void* src_data,
|
|
1126
|
-
|
|
1352
|
+
const void* src_data, size_t src_stride, const int* src_indices,
|
|
1353
|
+
size_t elem_size)
|
|
1127
1354
|
{
|
|
1128
|
-
|
|
1355
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1129
1356
|
|
|
1130
1357
|
if (tid < dst.size)
|
|
1131
1358
|
{
|
|
1132
|
-
|
|
1359
|
+
size_t src_idx = src_indices ? src_indices[tid] : tid;
|
|
1133
1360
|
const void* src_ptr = (const char*)src_data + src_idx * src_stride;
|
|
1134
1361
|
void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
|
|
1135
1362
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1137,25 +1364,25 @@ static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
|
|
|
1137
1364
|
}
|
|
1138
1365
|
|
|
1139
1366
|
static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
|
|
1140
|
-
const void* src_data,
|
|
1141
|
-
|
|
1367
|
+
const void* src_data, size_t src_stride, const int* src_indices,
|
|
1368
|
+
size_t elem_size)
|
|
1142
1369
|
{
|
|
1143
|
-
|
|
1370
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1144
1371
|
|
|
1145
1372
|
if (tid < dst.size)
|
|
1146
1373
|
{
|
|
1147
|
-
|
|
1374
|
+
size_t src_idx = src_indices ? src_indices[tid] : tid;
|
|
1148
1375
|
const void* src_ptr = (const char*)src_data + src_idx * src_stride;
|
|
1149
|
-
|
|
1376
|
+
size_t dst_idx = dst.indices[tid];
|
|
1150
1377
|
void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
|
|
1151
1378
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
1152
1379
|
}
|
|
1153
1380
|
}
|
|
1154
1381
|
|
|
1155
1382
|
|
|
1156
|
-
static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src,
|
|
1383
|
+
static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
|
|
1157
1384
|
{
|
|
1158
|
-
|
|
1385
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1159
1386
|
|
|
1160
1387
|
if (tid < dst.size)
|
|
1161
1388
|
{
|
|
@@ -1166,27 +1393,27 @@ static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void
|
|
|
1166
1393
|
}
|
|
1167
1394
|
|
|
1168
1395
|
|
|
1169
|
-
static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src,
|
|
1396
|
+
static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, size_t elem_size)
|
|
1170
1397
|
{
|
|
1171
|
-
|
|
1398
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1172
1399
|
|
|
1173
1400
|
if (tid < dst.size)
|
|
1174
1401
|
{
|
|
1175
1402
|
const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
|
|
1176
|
-
|
|
1403
|
+
size_t dst_index = dst.indices[tid];
|
|
1177
1404
|
void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
|
|
1178
1405
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
1179
1406
|
}
|
|
1180
1407
|
}
|
|
1181
1408
|
|
|
1182
1409
|
|
|
1183
|
-
static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src,
|
|
1410
|
+
static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
|
|
1184
1411
|
{
|
|
1185
|
-
|
|
1412
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1186
1413
|
|
|
1187
1414
|
if (tid < dst.size)
|
|
1188
1415
|
{
|
|
1189
|
-
|
|
1416
|
+
size_t src_index = src.indices[tid];
|
|
1190
1417
|
const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
|
|
1191
1418
|
void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
|
|
1192
1419
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1194,14 +1421,14 @@ static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarra
|
|
|
1194
1421
|
}
|
|
1195
1422
|
|
|
1196
1423
|
|
|
1197
|
-
static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src,
|
|
1424
|
+
static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, size_t elem_size)
|
|
1198
1425
|
{
|
|
1199
|
-
|
|
1426
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1200
1427
|
|
|
1201
1428
|
if (tid < dst.size)
|
|
1202
1429
|
{
|
|
1203
|
-
|
|
1204
|
-
|
|
1430
|
+
size_t src_index = src.indices[tid];
|
|
1431
|
+
size_t dst_index = dst.indices[tid];
|
|
1205
1432
|
const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
|
|
1206
1433
|
void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_index, elem_size);
|
|
1207
1434
|
memcpy(dst_ptr, src_ptr, elem_size);
|
|
@@ -1440,9 +1667,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
|
|
|
1440
1667
|
}
|
|
1441
1668
|
case 2:
|
|
1442
1669
|
{
|
|
1443
|
-
wp::vec_t<2,
|
|
1444
|
-
wp::vec_t<2,
|
|
1445
|
-
wp::vec_t<2,
|
|
1670
|
+
wp::vec_t<2, size_t> shape_v(src_shape[0], src_shape[1]);
|
|
1671
|
+
wp::vec_t<2, size_t> src_strides_v(src_strides[0], src_strides[1]);
|
|
1672
|
+
wp::vec_t<2, size_t> dst_strides_v(dst_strides[0], dst_strides[1]);
|
|
1446
1673
|
wp::vec_t<2, const int*> src_indices_v(src_indices[0], src_indices[1]);
|
|
1447
1674
|
wp::vec_t<2, const int*> dst_indices_v(dst_indices[0], dst_indices[1]);
|
|
1448
1675
|
|
|
@@ -1454,9 +1681,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
|
|
|
1454
1681
|
}
|
|
1455
1682
|
case 3:
|
|
1456
1683
|
{
|
|
1457
|
-
wp::vec_t<3,
|
|
1458
|
-
wp::vec_t<3,
|
|
1459
|
-
wp::vec_t<3,
|
|
1684
|
+
wp::vec_t<3, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2]);
|
|
1685
|
+
wp::vec_t<3, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2]);
|
|
1686
|
+
wp::vec_t<3, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2]);
|
|
1460
1687
|
wp::vec_t<3, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2]);
|
|
1461
1688
|
wp::vec_t<3, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2]);
|
|
1462
1689
|
|
|
@@ -1468,9 +1695,9 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
|
|
|
1468
1695
|
}
|
|
1469
1696
|
case 4:
|
|
1470
1697
|
{
|
|
1471
|
-
wp::vec_t<4,
|
|
1472
|
-
wp::vec_t<4,
|
|
1473
|
-
wp::vec_t<4,
|
|
1698
|
+
wp::vec_t<4, size_t> shape_v(src_shape[0], src_shape[1], src_shape[2], src_shape[3]);
|
|
1699
|
+
wp::vec_t<4, size_t> src_strides_v(src_strides[0], src_strides[1], src_strides[2], src_strides[3]);
|
|
1700
|
+
wp::vec_t<4, size_t> dst_strides_v(dst_strides[0], dst_strides[1], dst_strides[2], dst_strides[3]);
|
|
1474
1701
|
wp::vec_t<4, const int*> src_indices_v(src_indices[0], src_indices[1], src_indices[2], src_indices[3]);
|
|
1475
1702
|
wp::vec_t<4, const int*> dst_indices_v(dst_indices[0], dst_indices[1], dst_indices[2], dst_indices[3]);
|
|
1476
1703
|
|
|
@@ -1490,94 +1717,94 @@ WP_API bool wp_array_copy_device(void* context, void* dst, void* src, int dst_ty
|
|
|
1490
1717
|
|
|
1491
1718
|
|
|
1492
1719
|
static __global__ void array_fill_1d_kernel(void* data,
|
|
1493
|
-
|
|
1494
|
-
|
|
1720
|
+
size_t n,
|
|
1721
|
+
size_t stride,
|
|
1495
1722
|
const int* indices,
|
|
1496
1723
|
const void* value,
|
|
1497
|
-
|
|
1724
|
+
size_t value_size)
|
|
1498
1725
|
{
|
|
1499
|
-
|
|
1726
|
+
size_t i = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1500
1727
|
if (i < n)
|
|
1501
1728
|
{
|
|
1502
|
-
|
|
1729
|
+
size_t idx = indices ? indices[i] : i;
|
|
1503
1730
|
char* p = (char*)data + idx * stride;
|
|
1504
1731
|
memcpy(p, value, value_size);
|
|
1505
1732
|
}
|
|
1506
1733
|
}
|
|
1507
1734
|
|
|
1508
1735
|
static __global__ void array_fill_2d_kernel(void* data,
|
|
1509
|
-
wp::vec_t<2,
|
|
1510
|
-
wp::vec_t<2,
|
|
1736
|
+
wp::vec_t<2, size_t> shape,
|
|
1737
|
+
wp::vec_t<2, size_t> strides,
|
|
1511
1738
|
wp::vec_t<2, const int*> indices,
|
|
1512
1739
|
const void* value,
|
|
1513
|
-
|
|
1740
|
+
size_t value_size)
|
|
1514
1741
|
{
|
|
1515
|
-
|
|
1516
|
-
|
|
1517
|
-
|
|
1518
|
-
|
|
1742
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1743
|
+
size_t n = shape[1];
|
|
1744
|
+
size_t i = tid / n;
|
|
1745
|
+
size_t j = tid % n;
|
|
1519
1746
|
if (i < shape[0] /*&& j < shape[1]*/)
|
|
1520
1747
|
{
|
|
1521
|
-
|
|
1522
|
-
|
|
1748
|
+
size_t idx0 = indices[0] ? indices[0][i] : i;
|
|
1749
|
+
size_t idx1 = indices[1] ? indices[1][j] : j;
|
|
1523
1750
|
char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1];
|
|
1524
1751
|
memcpy(p, value, value_size);
|
|
1525
1752
|
}
|
|
1526
1753
|
}
|
|
1527
1754
|
|
|
1528
1755
|
static __global__ void array_fill_3d_kernel(void* data,
|
|
1529
|
-
wp::vec_t<3,
|
|
1530
|
-
wp::vec_t<3,
|
|
1756
|
+
wp::vec_t<3, size_t> shape,
|
|
1757
|
+
wp::vec_t<3, size_t> strides,
|
|
1531
1758
|
wp::vec_t<3, const int*> indices,
|
|
1532
1759
|
const void* value,
|
|
1533
|
-
|
|
1534
|
-
{
|
|
1535
|
-
|
|
1536
|
-
|
|
1537
|
-
|
|
1538
|
-
|
|
1539
|
-
|
|
1540
|
-
|
|
1760
|
+
size_t value_size)
|
|
1761
|
+
{
|
|
1762
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1763
|
+
size_t n = shape[1];
|
|
1764
|
+
size_t o = shape[2];
|
|
1765
|
+
size_t i = tid / (n * o);
|
|
1766
|
+
size_t j = tid % (n * o) / o;
|
|
1767
|
+
size_t k = tid % o;
|
|
1541
1768
|
if (i < shape[0] && j < shape[1] /*&& k < shape[2]*/)
|
|
1542
1769
|
{
|
|
1543
|
-
|
|
1544
|
-
|
|
1545
|
-
|
|
1770
|
+
size_t idx0 = indices[0] ? indices[0][i] : i;
|
|
1771
|
+
size_t idx1 = indices[1] ? indices[1][j] : j;
|
|
1772
|
+
size_t idx2 = indices[2] ? indices[2][k] : k;
|
|
1546
1773
|
char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2];
|
|
1547
1774
|
memcpy(p, value, value_size);
|
|
1548
1775
|
}
|
|
1549
1776
|
}
|
|
1550
1777
|
|
|
1551
1778
|
static __global__ void array_fill_4d_kernel(void* data,
|
|
1552
|
-
wp::vec_t<4,
|
|
1553
|
-
wp::vec_t<4,
|
|
1779
|
+
wp::vec_t<4, size_t> shape,
|
|
1780
|
+
wp::vec_t<4, size_t> strides,
|
|
1554
1781
|
wp::vec_t<4, const int*> indices,
|
|
1555
1782
|
const void* value,
|
|
1556
|
-
|
|
1557
|
-
{
|
|
1558
|
-
|
|
1559
|
-
|
|
1560
|
-
|
|
1561
|
-
|
|
1562
|
-
|
|
1563
|
-
|
|
1564
|
-
|
|
1565
|
-
|
|
1783
|
+
size_t value_size)
|
|
1784
|
+
{
|
|
1785
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1786
|
+
size_t n = shape[1];
|
|
1787
|
+
size_t o = shape[2];
|
|
1788
|
+
size_t p = shape[3];
|
|
1789
|
+
size_t i = tid / (n * o * p);
|
|
1790
|
+
size_t j = tid % (n * o * p) / (o * p);
|
|
1791
|
+
size_t k = tid % (o * p) / p;
|
|
1792
|
+
size_t l = tid % p;
|
|
1566
1793
|
if (i < shape[0] && j < shape[1] && k < shape[2] /*&& l < shape[3]*/)
|
|
1567
1794
|
{
|
|
1568
|
-
|
|
1569
|
-
|
|
1570
|
-
|
|
1571
|
-
|
|
1795
|
+
size_t idx0 = indices[0] ? indices[0][i] : i;
|
|
1796
|
+
size_t idx1 = indices[1] ? indices[1][j] : j;
|
|
1797
|
+
size_t idx2 = indices[2] ? indices[2][k] : k;
|
|
1798
|
+
size_t idx3 = indices[3] ? indices[3][l] : l;
|
|
1572
1799
|
char* p = (char*)data + idx0 * strides[0] + idx1 * strides[1] + idx2 * strides[2] + idx3 * strides[3];
|
|
1573
1800
|
memcpy(p, value, value_size);
|
|
1574
1801
|
}
|
|
1575
1802
|
}
|
|
1576
1803
|
|
|
1577
1804
|
|
|
1578
|
-
static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value,
|
|
1805
|
+
static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, size_t value_size)
|
|
1579
1806
|
{
|
|
1580
|
-
|
|
1807
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1581
1808
|
if (tid < fa.size)
|
|
1582
1809
|
{
|
|
1583
1810
|
void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
|
|
@@ -1586,9 +1813,9 @@ static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, cons
|
|
|
1586
1813
|
}
|
|
1587
1814
|
|
|
1588
1815
|
|
|
1589
|
-
static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value,
|
|
1816
|
+
static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, size_t value_size)
|
|
1590
1817
|
{
|
|
1591
|
-
|
|
1818
|
+
size_t tid = size_t(blockIdx.x) * size_t(blockDim.x) + size_t(threadIdx.x);
|
|
1592
1819
|
if (tid < ifa.size)
|
|
1593
1820
|
{
|
|
1594
1821
|
size_t idx = size_t(ifa.indices[tid]);
|
|
@@ -1655,67 +1882,76 @@ WP_API void wp_array_fill_device(void* context, void* arr_ptr, int arr_type, con
|
|
|
1655
1882
|
|
|
1656
1883
|
ContextGuard guard(context);
|
|
1657
1884
|
|
|
1658
|
-
//
|
|
1659
|
-
|
|
1660
|
-
|
|
1661
|
-
|
|
1885
|
+
void* value_devptr = NULL; // fill value in device memory
|
|
1886
|
+
bool free_devptr = true; // whether we need to free the memory
|
|
1887
|
+
|
|
1888
|
+
// prepare the fill value in a graph-friendly way
|
|
1889
|
+
if (!capturable_tmp_alloc(WP_CURRENT_CONTEXT, value_ptr, value_size, &value_devptr, &free_devptr))
|
|
1890
|
+
{
|
|
1891
|
+
fprintf(stderr, "Warp fill error: failed to copy value to device memory\n");
|
|
1892
|
+
return;
|
|
1893
|
+
}
|
|
1662
1894
|
|
|
1663
|
-
// handle fabric arrays
|
|
1664
1895
|
if (fa)
|
|
1665
1896
|
{
|
|
1897
|
+
// handle fabric arrays
|
|
1666
1898
|
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_kernel, n,
|
|
1667
1899
|
(*fa, value_devptr, value_size));
|
|
1668
|
-
return;
|
|
1669
1900
|
}
|
|
1670
1901
|
else if (ifa)
|
|
1671
1902
|
{
|
|
1903
|
+
// handle indexed fabric arrays
|
|
1672
1904
|
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_indexed_kernel, n,
|
|
1673
1905
|
(*ifa, value_devptr, value_size));
|
|
1674
|
-
return;
|
|
1675
1906
|
}
|
|
1676
|
-
|
|
1677
|
-
// handle regular or indexed arrays
|
|
1678
|
-
switch (ndim)
|
|
1679
|
-
{
|
|
1680
|
-
case 1:
|
|
1681
|
-
{
|
|
1682
|
-
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_1d_kernel, n,
|
|
1683
|
-
(data, shape[0], strides[0], indices[0], value_devptr, value_size));
|
|
1684
|
-
break;
|
|
1685
|
-
}
|
|
1686
|
-
case 2:
|
|
1687
|
-
{
|
|
1688
|
-
wp::vec_t<2, int> shape_v(shape[0], shape[1]);
|
|
1689
|
-
wp::vec_t<2, int> strides_v(strides[0], strides[1]);
|
|
1690
|
-
wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
|
|
1691
|
-
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
|
|
1692
|
-
(data, shape_v, strides_v, indices_v, value_devptr, value_size));
|
|
1693
|
-
break;
|
|
1694
|
-
}
|
|
1695
|
-
case 3:
|
|
1907
|
+
else
|
|
1696
1908
|
{
|
|
1697
|
-
|
|
1698
|
-
|
|
1699
|
-
|
|
1700
|
-
|
|
1701
|
-
|
|
1702
|
-
|
|
1909
|
+
// handle regular or indexed arrays
|
|
1910
|
+
switch (ndim)
|
|
1911
|
+
{
|
|
1912
|
+
case 1:
|
|
1913
|
+
{
|
|
1914
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_1d_kernel, n,
|
|
1915
|
+
(data, shape[0], strides[0], indices[0], value_devptr, value_size));
|
|
1916
|
+
break;
|
|
1917
|
+
}
|
|
1918
|
+
case 2:
|
|
1919
|
+
{
|
|
1920
|
+
wp::vec_t<2, size_t> shape_v(shape[0], shape[1]);
|
|
1921
|
+
wp::vec_t<2, size_t> strides_v(strides[0], strides[1]);
|
|
1922
|
+
wp::vec_t<2, const int*> indices_v(indices[0], indices[1]);
|
|
1923
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_2d_kernel, n,
|
|
1924
|
+
(data, shape_v, strides_v, indices_v, value_devptr, value_size));
|
|
1925
|
+
break;
|
|
1926
|
+
}
|
|
1927
|
+
case 3:
|
|
1928
|
+
{
|
|
1929
|
+
wp::vec_t<3, size_t> shape_v(shape[0], shape[1], shape[2]);
|
|
1930
|
+
wp::vec_t<3, size_t> strides_v(strides[0], strides[1], strides[2]);
|
|
1931
|
+
wp::vec_t<3, const int*> indices_v(indices[0], indices[1], indices[2]);
|
|
1932
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_3d_kernel, n,
|
|
1933
|
+
(data, shape_v, strides_v, indices_v, value_devptr, value_size));
|
|
1934
|
+
break;
|
|
1935
|
+
}
|
|
1936
|
+
case 4:
|
|
1937
|
+
{
|
|
1938
|
+
wp::vec_t<4, size_t> shape_v(shape[0], shape[1], shape[2], shape[3]);
|
|
1939
|
+
wp::vec_t<4, size_t> strides_v(strides[0], strides[1], strides[2], strides[3]);
|
|
1940
|
+
wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
|
|
1941
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
|
|
1942
|
+
(data, shape_v, strides_v, indices_v, value_devptr, value_size));
|
|
1943
|
+
break;
|
|
1944
|
+
}
|
|
1945
|
+
default:
|
|
1946
|
+
fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
|
|
1947
|
+
break;
|
|
1948
|
+
}
|
|
1703
1949
|
}
|
|
1704
|
-
|
|
1950
|
+
|
|
1951
|
+
if (free_devptr)
|
|
1705
1952
|
{
|
|
1706
|
-
|
|
1707
|
-
wp::vec_t<4, int> strides_v(strides[0], strides[1], strides[2], strides[3]);
|
|
1708
|
-
wp::vec_t<4, const int*> indices_v(indices[0], indices[1], indices[2], indices[3]);
|
|
1709
|
-
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_4d_kernel, n,
|
|
1710
|
-
(data, shape_v, strides_v, indices_v, value_devptr, value_size));
|
|
1711
|
-
break;
|
|
1712
|
-
}
|
|
1713
|
-
default:
|
|
1714
|
-
fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
|
|
1715
|
-
return;
|
|
1953
|
+
wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
|
|
1716
1954
|
}
|
|
1717
|
-
|
|
1718
|
-
wp_free_device(WP_CURRENT_CONTEXT, value_devptr);
|
|
1719
1955
|
}
|
|
1720
1956
|
|
|
1721
1957
|
void wp_array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
|
|
@@ -2072,14 +2308,15 @@ void wp_cuda_context_synchronize(void* context)
|
|
|
2072
2308
|
|
|
2073
2309
|
check_cu(cuCtxSynchronize_f());
|
|
2074
2310
|
|
|
2075
|
-
if (
|
|
2311
|
+
if (!context)
|
|
2312
|
+
context = get_current_context();
|
|
2313
|
+
|
|
2314
|
+
if (run_deferred_actions(context) > 0)
|
|
2076
2315
|
{
|
|
2077
|
-
// ensure deferred asynchronous
|
|
2316
|
+
// ensure deferred asynchronous operations complete
|
|
2078
2317
|
check_cu(cuCtxSynchronize_f());
|
|
2079
2318
|
}
|
|
2080
2319
|
|
|
2081
|
-
unload_deferred_modules(context);
|
|
2082
|
-
|
|
2083
2320
|
// check_cuda(cudaDeviceGraphMemTrim(wp_cuda_context_get_device_ordinal(context)));
|
|
2084
2321
|
}
|
|
2085
2322
|
|
|
@@ -2514,15 +2751,36 @@ void wp_cuda_stream_synchronize(void* stream)
|
|
|
2514
2751
|
check_cu(cuStreamSynchronize_f(static_cast<CUstream>(stream)));
|
|
2515
2752
|
}
|
|
2516
2753
|
|
|
2517
|
-
void wp_cuda_stream_wait_event(void* stream, void* event)
|
|
2754
|
+
void wp_cuda_stream_wait_event(void* stream, void* event, bool external)
|
|
2518
2755
|
{
|
|
2519
|
-
|
|
2756
|
+
// the external flag can only be used during graph capture
|
|
2757
|
+
if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
|
|
2758
|
+
{
|
|
2759
|
+
// wait for an external event during graph capture
|
|
2760
|
+
check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_EXTERNAL));
|
|
2761
|
+
}
|
|
2762
|
+
else
|
|
2763
|
+
{
|
|
2764
|
+
check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), CU_EVENT_WAIT_DEFAULT));
|
|
2765
|
+
}
|
|
2520
2766
|
}
|
|
2521
2767
|
|
|
2522
|
-
void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event)
|
|
2768
|
+
void wp_cuda_stream_wait_stream(void* stream, void* other_stream, void* event, bool external)
|
|
2523
2769
|
{
|
|
2524
|
-
|
|
2525
|
-
|
|
2770
|
+
unsigned record_flags = CU_EVENT_RECORD_DEFAULT;
|
|
2771
|
+
unsigned wait_flags = CU_EVENT_WAIT_DEFAULT;
|
|
2772
|
+
|
|
2773
|
+
// the external flag can only be used during graph capture
|
|
2774
|
+
if (external && !g_captures.empty())
|
|
2775
|
+
{
|
|
2776
|
+
if (wp_cuda_stream_is_capturing(other_stream))
|
|
2777
|
+
record_flags = CU_EVENT_RECORD_EXTERNAL;
|
|
2778
|
+
if (wp_cuda_stream_is_capturing(stream))
|
|
2779
|
+
wait_flags = CU_EVENT_WAIT_EXTERNAL;
|
|
2780
|
+
}
|
|
2781
|
+
|
|
2782
|
+
check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(other_stream), record_flags));
|
|
2783
|
+
check_cu(cuStreamWaitEvent_f(static_cast<CUstream>(stream), static_cast<CUevent>(event), wait_flags));
|
|
2526
2784
|
}
|
|
2527
2785
|
|
|
2528
2786
|
int wp_cuda_stream_is_capturing(void* stream)
|
|
@@ -2575,11 +2833,12 @@ int wp_cuda_event_query(void* event)
|
|
|
2575
2833
|
return res;
|
|
2576
2834
|
}
|
|
2577
2835
|
|
|
2578
|
-
void wp_cuda_event_record(void* event, void* stream, bool
|
|
2836
|
+
void wp_cuda_event_record(void* event, void* stream, bool external)
|
|
2579
2837
|
{
|
|
2580
|
-
|
|
2838
|
+
// the external flag can only be used during graph capture
|
|
2839
|
+
if (external && !g_captures.empty() && wp_cuda_stream_is_capturing(stream))
|
|
2581
2840
|
{
|
|
2582
|
-
// record
|
|
2841
|
+
// record external event during graph capture (e.g., for timing or when explicitly specified by the user)
|
|
2583
2842
|
check_cu(cuEventRecordWithFlags_f(static_cast<CUevent>(event), static_cast<CUstream>(stream), CU_EVENT_RECORD_EXTERNAL));
|
|
2584
2843
|
}
|
|
2585
2844
|
else
|
|
@@ -2629,7 +2888,7 @@ bool wp_cuda_graph_begin_capture(void* context, void* stream, int external)
|
|
|
2629
2888
|
else
|
|
2630
2889
|
{
|
|
2631
2890
|
// start the capture
|
|
2632
|
-
if (!check_cuda(cudaStreamBeginCapture(cuda_stream,
|
|
2891
|
+
if (!check_cuda(cudaStreamBeginCapture(cuda_stream, cudaStreamCaptureModeThreadLocal)))
|
|
2633
2892
|
return false;
|
|
2634
2893
|
}
|
|
2635
2894
|
|
|
@@ -2673,6 +2932,7 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
|
|
|
2673
2932
|
// get capture info
|
|
2674
2933
|
bool external = capture->external;
|
|
2675
2934
|
uint64_t capture_id = capture->id;
|
|
2935
|
+
std::vector<FreeInfo> tmp_allocs = capture->tmp_allocs;
|
|
2676
2936
|
|
|
2677
2937
|
// clear capture info
|
|
2678
2938
|
stream_info->capture = NULL;
|
|
@@ -2742,15 +3002,17 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
|
|
|
2742
3002
|
unfreed_allocs.push_back(it->first);
|
|
2743
3003
|
}
|
|
2744
3004
|
|
|
2745
|
-
if (!unfreed_allocs.empty())
|
|
3005
|
+
if (!unfreed_allocs.empty() || !tmp_allocs.empty())
|
|
2746
3006
|
{
|
|
2747
3007
|
// Create a user object that will notify us when the instantiated graph is destroyed.
|
|
2748
3008
|
// This works for external captures also, since we wouldn't otherwise know when
|
|
2749
3009
|
// the externally-created graph instance gets deleted.
|
|
2750
3010
|
// This callback is guaranteed to arrive after the graph has finished executing on the device,
|
|
2751
3011
|
// not necessarily when cudaGraphExecDestroy() is called.
|
|
2752
|
-
|
|
3012
|
+
GraphDestroyCallbackInfo* graph_info = new GraphDestroyCallbackInfo;
|
|
3013
|
+
graph_info->context = context ? context : get_current_context();
|
|
2753
3014
|
graph_info->unfreed_allocs = unfreed_allocs;
|
|
3015
|
+
graph_info->tmp_allocs = tmp_allocs;
|
|
2754
3016
|
cudaUserObject_t user_object;
|
|
2755
3017
|
check_cuda(cudaUserObjectCreate(&user_object, graph_info, on_graph_destroy, 1, cudaUserObjectNoDestructorSync));
|
|
2756
3018
|
check_cuda(cudaGraphRetainUserObject(graph, user_object, 1, cudaGraphUserObjectMove));
|
|
@@ -2774,8 +3036,7 @@ bool wp_cuda_graph_end_capture(void* context, void* stream, void** graph_ret)
|
|
|
2774
3036
|
// process deferred free list if no more captures are ongoing
|
|
2775
3037
|
if (g_captures.empty())
|
|
2776
3038
|
{
|
|
2777
|
-
|
|
2778
|
-
unload_deferred_modules();
|
|
3039
|
+
run_deferred_actions();
|
|
2779
3040
|
}
|
|
2780
3041
|
|
|
2781
3042
|
if (graph_ret)
|
|
@@ -2996,7 +3257,7 @@ bool wp_cuda_graph_resume_capture(void* context, void* stream, void* graph)
|
|
|
2996
3257
|
leaf_nodes.data(),
|
|
2997
3258
|
nullptr,
|
|
2998
3259
|
leaf_nodes.size(),
|
|
2999
|
-
|
|
3260
|
+
cudaStreamCaptureModeThreadLocal)))
|
|
3000
3261
|
return false;
|
|
3001
3262
|
|
|
3002
3263
|
return true;
|
|
@@ -3455,16 +3716,38 @@ bool wp_cuda_graph_launch(void* graph_exec, void* stream)
|
|
|
3455
3716
|
|
|
3456
3717
|
bool wp_cuda_graph_destroy(void* context, void* graph)
|
|
3457
3718
|
{
|
|
3458
|
-
|
|
3459
|
-
|
|
3460
|
-
|
|
3719
|
+
// ensure there are no graph captures in progress
|
|
3720
|
+
if (g_captures.empty())
|
|
3721
|
+
{
|
|
3722
|
+
ContextGuard guard(context);
|
|
3723
|
+
return check_cuda(cudaGraphDestroy((cudaGraph_t)graph));
|
|
3724
|
+
}
|
|
3725
|
+
else
|
|
3726
|
+
{
|
|
3727
|
+
GraphDestroyInfo info;
|
|
3728
|
+
info.context = context ? context : get_current_context();
|
|
3729
|
+
info.graph = graph;
|
|
3730
|
+
g_deferred_graph_list.push_back(info);
|
|
3731
|
+
return true;
|
|
3732
|
+
}
|
|
3461
3733
|
}
|
|
3462
3734
|
|
|
3463
3735
|
bool wp_cuda_graph_exec_destroy(void* context, void* graph_exec)
|
|
3464
3736
|
{
|
|
3465
|
-
|
|
3466
|
-
|
|
3467
|
-
|
|
3737
|
+
// ensure there are no graph captures in progress
|
|
3738
|
+
if (g_captures.empty())
|
|
3739
|
+
{
|
|
3740
|
+
ContextGuard guard(context);
|
|
3741
|
+
return check_cuda(cudaGraphExecDestroy((cudaGraphExec_t)graph_exec));
|
|
3742
|
+
}
|
|
3743
|
+
else
|
|
3744
|
+
{
|
|
3745
|
+
GraphDestroyInfo info;
|
|
3746
|
+
info.context = context ? context : get_current_context();
|
|
3747
|
+
info.graph_exec = graph_exec;
|
|
3748
|
+
g_deferred_graph_list.push_back(info);
|
|
3749
|
+
return true;
|
|
3750
|
+
}
|
|
3468
3751
|
}
|
|
3469
3752
|
|
|
3470
3753
|
bool write_file(const char* data, size_t size, std::string filename, const char* mode)
|
|
@@ -4317,17 +4600,5 @@ void wp_cuda_timing_end(timing_result_t* results, int size)
|
|
|
4317
4600
|
g_cuda_timing_state = parent_state;
|
|
4318
4601
|
}
|
|
4319
4602
|
|
|
4320
|
-
// impl. files
|
|
4321
|
-
#include "bvh.cu"
|
|
4322
|
-
#include "mesh.cu"
|
|
4323
|
-
#include "sort.cu"
|
|
4324
|
-
#include "hashgrid.cu"
|
|
4325
|
-
#include "reduce.cu"
|
|
4326
|
-
#include "runlength_encode.cu"
|
|
4327
|
-
#include "scan.cu"
|
|
4328
|
-
#include "sparse.cu"
|
|
4329
|
-
#include "volume.cu"
|
|
4330
|
-
#include "volume_builder.cu"
|
|
4331
|
-
|
|
4332
4603
|
//#include "spline.inl"
|
|
4333
4604
|
//#include "volume.inl"
|