warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +10 -4
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +868 -507
- warp/codegen.py +1074 -638
- warp/config.py +3 -3
- warp/constants.py +6 -0
- warp/context.py +715 -222
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +147 -44
- warp/native/builtin.h +122 -149
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +34 -43
- warp/native/clang/clang.cpp +13 -8
- warp/native/crt.h +2 -0
- warp/native/cuda_crt.h +5 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -952
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +1 -1
- warp/native/marching.cu +157 -161
- warp/native/mat.h +80 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -23
- warp/native/mesh.h +446 -46
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +1 -1
- warp/native/reduce.cu +10 -12
- warp/native/runlength_encode.cu +6 -10
- warp/native/scan.cu +8 -11
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +164 -154
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +14 -30
- warp/native/vec.h +107 -23
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +560 -30
- warp/native/warp.cu +431 -44
- warp/native/warp.h +13 -4
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +335 -119
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +8 -0
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +158 -16
- warp/sim/model.py +795 -291
- warp/sim/render.py +3 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +640 -150
- warp/stubs.py +606 -267
- warp/tape.py +61 -10
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +212 -97
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +42 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +208 -130
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +178 -109
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +32 -31
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +140 -22
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +9 -6
- warp/types.py +1089 -366
- warp/utils.py +93 -387
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -219
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.10.1.dist-info/METADATA +0 -21
- warp_lang-0.10.1.dist-info/RECORD +0 -188
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cu
CHANGED
|
@@ -73,10 +73,15 @@ struct DeviceInfo
|
|
|
73
73
|
static constexpr int kNameLen = 128;
|
|
74
74
|
|
|
75
75
|
CUdevice device = -1;
|
|
76
|
+
CUuuid uuid = {0};
|
|
76
77
|
int ordinal = -1;
|
|
78
|
+
int pci_domain_id = -1;
|
|
79
|
+
int pci_bus_id = -1;
|
|
80
|
+
int pci_device_id = -1;
|
|
77
81
|
char name[kNameLen] = "";
|
|
78
82
|
int arch = 0;
|
|
79
83
|
int is_uva = 0;
|
|
84
|
+
int is_memory_pool_supported = 0;
|
|
80
85
|
};
|
|
81
86
|
|
|
82
87
|
struct ContextInfo
|
|
@@ -125,7 +130,12 @@ int cuda_init()
|
|
|
125
130
|
g_devices[i].device = device;
|
|
126
131
|
g_devices[i].ordinal = i;
|
|
127
132
|
check_cu(cuDeviceGetName_f(g_devices[i].name, DeviceInfo::kNameLen, device));
|
|
133
|
+
check_cu(cuDeviceGetUuid_f(&g_devices[i].uuid, device));
|
|
134
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_domain_id, CU_DEVICE_ATTRIBUTE_PCI_DOMAIN_ID, device));
|
|
135
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_bus_id, CU_DEVICE_ATTRIBUTE_PCI_BUS_ID, device));
|
|
136
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].pci_device_id, CU_DEVICE_ATTRIBUTE_PCI_DEVICE_ID, device));
|
|
128
137
|
check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_uva, CU_DEVICE_ATTRIBUTE_UNIFIED_ADDRESSING, device));
|
|
138
|
+
check_cu(cuDeviceGetAttribute_f(&g_devices[i].is_memory_pool_supported, CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device));
|
|
129
139
|
int major = 0;
|
|
130
140
|
int minor = 0;
|
|
131
141
|
check_cu(cuDeviceGetAttribute_f(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device));
|
|
@@ -216,6 +226,26 @@ void* alloc_device(void* context, size_t s)
|
|
|
216
226
|
return ptr;
|
|
217
227
|
}
|
|
218
228
|
|
|
229
|
+
// Allocate `s` bytes of temporary device memory for `context`.
//
// If cuda_context_is_memory_pool_supported(context) reports support, the allocation is
// stream-ordered (cudaMallocAsync on get_current_stream()). Otherwise it falls back to a
// plain cudaMalloc. free_temp_device() in this file chooses the matching release path
// using the same predicate.
//
// `ptr` starts out as NULL, so a failed allocation (when check_cuda returns rather than
// aborting) yields NULL instead of an indeterminate pointer.
void* alloc_temp_device(void* context, size_t s)
{
    // "cudaMallocAsync ignores the current device/context when determining where the allocation will reside. Instead,
    // cudaMallocAsync determines the resident device based on the specified memory pool or the supplied stream."
    // The guard makes `context` current, presumably so get_current_stream() yields a stream
    // belonging to this context's device -- verify against get_current_stream().
    ContextGuard guard(context);

    void* ptr = NULL;

    if (cuda_context_is_memory_pool_supported(context))
    {
        check_cuda(cudaMallocAsync(&ptr, s, get_current_stream()));
    }
    else
    {
        check_cuda(cudaMalloc(&ptr, s));
    }

    return ptr;
}
|
|
248
|
+
|
|
219
249
|
void free_device(void* context, void* ptr)
|
|
220
250
|
{
|
|
221
251
|
ContextGuard guard(context);
|
|
@@ -223,6 +253,20 @@ void free_device(void* context, void* ptr)
|
|
|
223
253
|
check_cuda(cudaFree(ptr));
|
|
224
254
|
}
|
|
225
255
|
|
|
256
|
+
// Release memory obtained from alloc_temp_device().
//
// Mirrors the allocation path: contexts without memory-pool support get a plain cudaFree.
// Pool-capable contexts get a stream-ordered cudaFreeAsync on the current stream.
void free_temp_device(void* context, void* ptr)
{
    ContextGuard guard(context);

    const bool use_pool = cuda_context_is_memory_pool_supported(context);
    if (!use_pool)
    {
        check_cuda(cudaFree(ptr));
        return;
    }

    check_cuda(cudaFreeAsync(ptr, get_current_stream()));
}
|
|
269
|
+
|
|
226
270
|
void memcpy_h2d(void* context, void* dest, void* src, size_t n)
|
|
227
271
|
{
|
|
228
272
|
ContextGuard guard(context);
|
|
@@ -266,7 +310,7 @@ void memset_device(void* context, void* dest, int value, size_t n)
|
|
|
266
310
|
{
|
|
267
311
|
ContextGuard guard(context);
|
|
268
312
|
|
|
269
|
-
if ((n%4) > 0)
|
|
313
|
+
if (true)// ((n%4) > 0)
|
|
270
314
|
{
|
|
271
315
|
// for unaligned lengths fallback to CUDA memset
|
|
272
316
|
check_cuda(cudaMemsetAsync(dest, value, n, get_current_stream()));
|
|
@@ -448,6 +492,125 @@ static __global__ void array_copy_4d_kernel(void* dst, const void* src,
|
|
|
448
492
|
}
|
|
449
493
|
|
|
450
494
|
|
|
495
|
+
// Copy every element of a fabric array into a regular 1D array (optionally indexed).
//
// There is one thread per source element. Thread `tid` reads element `tid` of `src` and
// writes it to destination slot `dst_indices[tid]`, or slot `tid` when `dst_indices` is
// NULL. Destination slots are `dst_stride` bytes apart starting at `dst_data`, and each
// element is `elem_size` bytes.
static __global__ void array_copy_from_fabric_kernel(wp::fabricarray_t<void> src,
                                                     void* dst_data, int dst_stride, const int* dst_indices,
                                                     int elem_size)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid < src.size)
    {
        int dst_idx = dst_indices ? dst_indices[tid] : tid;
        // Widen before multiplying: an int*int product overflows once the byte offset exceeds INT_MAX (~2 GiB).
        void* dst_ptr = (char*)dst_data + (long long)dst_idx * dst_stride;
        const void* src_ptr = fabricarray_element_ptr(src, tid, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}
|
|
509
|
+
|
|
510
|
+
// Copy the elements selected by an indexed fabric array into a regular 1D array (optionally indexed).
//
// There is one thread per selected element. Thread `tid` reads element `src.indices[tid]`
// of the underlying fabric array `src.fa`. It writes that element to destination slot
// `dst_indices[tid]`, or slot `tid` when `dst_indices` is NULL. Destination slots are
// `dst_stride` bytes apart starting at `dst_data`, and each element is `elem_size` bytes.
static __global__ void array_copy_from_fabric_indexed_kernel(wp::indexedfabricarray_t<void> src,
                                                             void* dst_data, int dst_stride, const int* dst_indices,
                                                             int elem_size)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid < src.size)
    {
        int src_index = src.indices[tid];
        int dst_idx = dst_indices ? dst_indices[tid] : tid;
        // Widen before multiplying: an int*int product overflows once the byte offset exceeds INT_MAX (~2 GiB).
        void* dst_ptr = (char*)dst_data + (long long)dst_idx * dst_stride;
        const void* src_ptr = fabricarray_element_ptr(src.fa, src_index, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}
|
|
525
|
+
|
|
526
|
+
// Copy a regular 1D array (optionally indexed) into a fabric array.
//
// There is one thread per destination element. Thread `tid` reads source slot
// `src_indices[tid]`, or slot `tid` when `src_indices` is NULL. Source slots are
// `src_stride` bytes apart starting at `src_data`. The thread writes the value to element
// `tid` of `dst`, and each element is `elem_size` bytes.
static __global__ void array_copy_to_fabric_kernel(wp::fabricarray_t<void> dst,
                                                   const void* src_data, int src_stride, const int* src_indices,
                                                   int elem_size)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid < dst.size)
    {
        int src_idx = src_indices ? src_indices[tid] : tid;
        // Widen before multiplying: an int*int product overflows once the byte offset exceeds INT_MAX (~2 GiB).
        const void* src_ptr = (const char*)src_data + (long long)src_idx * src_stride;
        void* dst_ptr = fabricarray_element_ptr(dst, tid, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}
|
|
540
|
+
|
|
541
|
+
// Copy a regular 1D array (optionally indexed) into the elements selected by an indexed fabric array.
//
// There is one thread per selected destination element. Thread `tid` reads source slot
// `src_indices[tid]`, or slot `tid` when `src_indices` is NULL. Source slots are
// `src_stride` bytes apart starting at `src_data`. The thread writes the value to element
// `dst.indices[tid]` of the underlying fabric array `dst.fa`, and each element is
// `elem_size` bytes.
static __global__ void array_copy_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst,
                                                           const void* src_data, int src_stride, const int* src_indices,
                                                           int elem_size)
{
    int tid = blockIdx.x * blockDim.x + threadIdx.x;

    if (tid < dst.size)
    {
        int src_idx = src_indices ? src_indices[tid] : tid;
        // Widen before multiplying: an int*int product overflows once the byte offset exceeds INT_MAX (~2 GiB).
        const void* src_ptr = (const char*)src_data + (long long)src_idx * src_stride;
        int dst_idx = dst.indices[tid];
        void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
        memcpy(dst_ptr, src_ptr, elem_size);
    }
}
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
// Element-wise copy between two fabric arrays. Thread i copies element i of `src` into
// element i of `dst`, `elem_size` bytes per element. Only `dst.size` bounds the launch;
// the host side (array_copy_device) checks that both sizes match before launching.
static __global__ void array_copy_fabric_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;

    // threads beyond the destination's extent have no work
    if (i >= dst.size)
        return;

    const void* from = fabricarray_element_ptr(src, i, elem_size);
    void* to = fabricarray_element_ptr(dst, i, elem_size);
    memcpy(to, from, elem_size);
}
|
|
569
|
+
|
|
570
|
+
|
|
571
|
+
// Copy a fabric array into the elements selected by an indexed fabric array.
// Thread i reads element i of `src` densely and writes it to element `dst.indices[i]`
// of the underlying fabric array `dst.fa`, `elem_size` bytes per element.
static __global__ void array_copy_fabric_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::fabricarray_t<void> src, int elem_size)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;

    // threads beyond the destination's extent have no work
    if (i >= dst.size)
        return;

    const void* from = fabricarray_element_ptr(src, i, elem_size);
    const int dst_slot = dst.indices[i];
    void* to = fabricarray_element_ptr(dst.fa, dst_slot, elem_size);
    memcpy(to, from, elem_size);
}
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
// Gather the elements selected by an indexed fabric array into a dense fabric array.
// Thread i reads element `src.indices[i]` of the underlying fabric array `src.fa` and
// writes it to element i of `dst`, `elem_size` bytes per element.
static __global__ void array_copy_fabric_indexed_to_fabric_kernel(wp::fabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;

    // threads beyond the destination's extent have no work
    if (i >= dst.size)
        return;

    const int src_slot = src.indices[i];
    const void* from = fabricarray_element_ptr(src.fa, src_slot, elem_size);
    void* to = fabricarray_element_ptr(dst, i, elem_size);
    memcpy(to, from, elem_size);
}
|
|
597
|
+
|
|
598
|
+
|
|
599
|
+
// Copy between two indexed fabric arrays. Thread i reads element `src.indices[i]` of
// `src.fa` and writes it to element `dst.indices[i]` of `dst.fa`, `elem_size` bytes per
// element.
static __global__ void array_copy_fabric_indexed_to_fabric_indexed_kernel(wp::indexedfabricarray_t<void> dst, wp::indexedfabricarray_t<void> src, int elem_size)
{
    const int i = blockIdx.x * blockDim.x + threadIdx.x;

    // threads beyond the destination's extent have no work
    if (i >= dst.size)
        return;

    const int src_slot = src.indices[i];
    const int dst_slot = dst.indices[i];
    const void* from = fabricarray_element_ptr(src.fa, src_slot, elem_size);
    void* to = fabricarray_element_ptr(dst.fa, dst_slot, elem_size);
    memcpy(to, from, elem_size);
}
|
|
612
|
+
|
|
613
|
+
|
|
451
614
|
WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_type, int src_type, int elem_size)
|
|
452
615
|
{
|
|
453
616
|
if (!src || !dst)
|
|
@@ -466,6 +629,12 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
|
|
|
466
629
|
const int*const* src_indices = NULL;
|
|
467
630
|
const int*const* dst_indices = NULL;
|
|
468
631
|
|
|
632
|
+
const wp::fabricarray_t<void>* src_fabricarray = NULL;
|
|
633
|
+
wp::fabricarray_t<void>* dst_fabricarray = NULL;
|
|
634
|
+
|
|
635
|
+
const wp::indexedfabricarray_t<void>* src_indexedfabricarray = NULL;
|
|
636
|
+
wp::indexedfabricarray_t<void>* dst_indexedfabricarray = NULL;
|
|
637
|
+
|
|
469
638
|
const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };
|
|
470
639
|
|
|
471
640
|
if (src_type == wp::ARRAY_TYPE_REGULAR)
|
|
@@ -487,9 +656,19 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
|
|
|
487
656
|
src_strides = src_arr.arr.strides;
|
|
488
657
|
src_indices = src_arr.indices;
|
|
489
658
|
}
|
|
659
|
+
else if (src_type == wp::ARRAY_TYPE_FABRIC)
|
|
660
|
+
{
|
|
661
|
+
src_fabricarray = static_cast<const wp::fabricarray_t<void>*>(src);
|
|
662
|
+
src_ndim = 1;
|
|
663
|
+
}
|
|
664
|
+
else if (src_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
|
|
665
|
+
{
|
|
666
|
+
src_indexedfabricarray = static_cast<const wp::indexedfabricarray_t<void>*>(src);
|
|
667
|
+
src_ndim = 1;
|
|
668
|
+
}
|
|
490
669
|
else
|
|
491
670
|
{
|
|
492
|
-
fprintf(stderr, "Warp error: Invalid array type (%d)\n", src_type);
|
|
671
|
+
fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", src_type);
|
|
493
672
|
return 0;
|
|
494
673
|
}
|
|
495
674
|
|
|
@@ -512,33 +691,149 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
|
|
|
512
691
|
dst_strides = dst_arr.arr.strides;
|
|
513
692
|
dst_indices = dst_arr.indices;
|
|
514
693
|
}
|
|
694
|
+
else if (dst_type == wp::ARRAY_TYPE_FABRIC)
|
|
695
|
+
{
|
|
696
|
+
dst_fabricarray = static_cast<wp::fabricarray_t<void>*>(dst);
|
|
697
|
+
dst_ndim = 1;
|
|
698
|
+
}
|
|
699
|
+
else if (dst_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
|
|
700
|
+
{
|
|
701
|
+
dst_indexedfabricarray = static_cast<wp::indexedfabricarray_t<void>*>(dst);
|
|
702
|
+
dst_ndim = 1;
|
|
703
|
+
}
|
|
515
704
|
else
|
|
516
705
|
{
|
|
517
|
-
fprintf(stderr, "Warp error: Invalid array type (%d)\n", dst_type);
|
|
706
|
+
fprintf(stderr, "Warp copy error: Invalid array type (%d)\n", dst_type);
|
|
518
707
|
return 0;
|
|
519
708
|
}
|
|
520
709
|
|
|
521
710
|
if (src_ndim != dst_ndim)
|
|
522
711
|
{
|
|
523
|
-
fprintf(stderr, "Warp error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
|
|
712
|
+
fprintf(stderr, "Warp copy error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
|
|
524
713
|
return 0;
|
|
525
714
|
}
|
|
526
715
|
|
|
527
|
-
|
|
528
|
-
|
|
716
|
+
ContextGuard guard(context);
|
|
717
|
+
|
|
718
|
+
// handle fabric arrays
|
|
719
|
+
if (dst_fabricarray)
|
|
720
|
+
{
|
|
721
|
+
size_t n = dst_fabricarray->size;
|
|
722
|
+
if (src_fabricarray)
|
|
723
|
+
{
|
|
724
|
+
// copy from fabric to fabric
|
|
725
|
+
if (src_fabricarray->size != n)
|
|
726
|
+
{
|
|
727
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
728
|
+
return 0;
|
|
729
|
+
}
|
|
730
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_kernel, n,
|
|
731
|
+
(*dst_fabricarray, *src_fabricarray, elem_size));
|
|
732
|
+
return n;
|
|
733
|
+
}
|
|
734
|
+
else if (src_indexedfabricarray)
|
|
735
|
+
{
|
|
736
|
+
// copy from fabric indexed to fabric
|
|
737
|
+
if (src_indexedfabricarray->size != n)
|
|
738
|
+
{
|
|
739
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
740
|
+
return 0;
|
|
741
|
+
}
|
|
742
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_kernel, n,
|
|
743
|
+
(*dst_fabricarray, *src_indexedfabricarray, elem_size));
|
|
744
|
+
return n;
|
|
745
|
+
}
|
|
746
|
+
else
|
|
747
|
+
{
|
|
748
|
+
// copy to fabric
|
|
749
|
+
if (size_t(src_shape[0]) != n)
|
|
750
|
+
{
|
|
751
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
752
|
+
return 0;
|
|
753
|
+
}
|
|
754
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_kernel, n,
|
|
755
|
+
(*dst_fabricarray, src_data, src_strides[0], src_indices[0], elem_size));
|
|
756
|
+
return n;
|
|
757
|
+
}
|
|
758
|
+
}
|
|
759
|
+
if (dst_indexedfabricarray)
|
|
760
|
+
{
|
|
761
|
+
size_t n = dst_indexedfabricarray->size;
|
|
762
|
+
if (src_fabricarray)
|
|
763
|
+
{
|
|
764
|
+
// copy from fabric to fabric indexed
|
|
765
|
+
if (src_fabricarray->size != n)
|
|
766
|
+
{
|
|
767
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
768
|
+
return 0;
|
|
769
|
+
}
|
|
770
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_to_fabric_indexed_kernel, n,
|
|
771
|
+
(*dst_indexedfabricarray, *src_fabricarray, elem_size));
|
|
772
|
+
return n;
|
|
773
|
+
}
|
|
774
|
+
else if (src_indexedfabricarray)
|
|
775
|
+
{
|
|
776
|
+
// copy from fabric indexed to fabric indexed
|
|
777
|
+
if (src_indexedfabricarray->size != n)
|
|
778
|
+
{
|
|
779
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
780
|
+
return 0;
|
|
781
|
+
}
|
|
782
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_copy_fabric_indexed_to_fabric_indexed_kernel, n,
|
|
783
|
+
(*dst_indexedfabricarray, *src_indexedfabricarray, elem_size));
|
|
784
|
+
return n;
|
|
785
|
+
}
|
|
786
|
+
else
|
|
787
|
+
{
|
|
788
|
+
// copy to fabric indexed
|
|
789
|
+
if (size_t(src_shape[0]) != n)
|
|
790
|
+
{
|
|
791
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
792
|
+
return 0;
|
|
793
|
+
}
|
|
794
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_copy_to_fabric_indexed_kernel, n,
|
|
795
|
+
(*dst_indexedfabricarray, src_data, src_strides[0], src_indices[0], elem_size));
|
|
796
|
+
return n;
|
|
797
|
+
}
|
|
798
|
+
}
|
|
799
|
+
else if (src_fabricarray)
|
|
800
|
+
{
|
|
801
|
+
// copy from fabric
|
|
802
|
+
size_t n = src_fabricarray->size;
|
|
803
|
+
if (size_t(dst_shape[0]) != n)
|
|
804
|
+
{
|
|
805
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
806
|
+
return 0;
|
|
807
|
+
}
|
|
808
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_kernel, n,
|
|
809
|
+
(*src_fabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
|
|
810
|
+
return n;
|
|
811
|
+
}
|
|
812
|
+
else if (src_indexedfabricarray)
|
|
813
|
+
{
|
|
814
|
+
// copy from fabric indexed
|
|
815
|
+
size_t n = src_indexedfabricarray->size;
|
|
816
|
+
if (size_t(dst_shape[0]) != n)
|
|
817
|
+
{
|
|
818
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
819
|
+
return 0;
|
|
820
|
+
}
|
|
821
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_copy_from_fabric_indexed_kernel, n,
|
|
822
|
+
(*src_indexedfabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size));
|
|
823
|
+
return n;
|
|
824
|
+
}
|
|
529
825
|
|
|
826
|
+
size_t n = 1;
|
|
530
827
|
for (int i = 0; i < src_ndim; i++)
|
|
531
828
|
{
|
|
532
829
|
if (src_shape[i] != dst_shape[i])
|
|
533
830
|
{
|
|
534
|
-
fprintf(stderr, "Warp error: Incompatible array shapes\n");
|
|
831
|
+
fprintf(stderr, "Warp copy error: Incompatible array shapes\n");
|
|
535
832
|
return 0;
|
|
536
833
|
}
|
|
537
834
|
n *= src_shape[i];
|
|
538
835
|
}
|
|
539
836
|
|
|
540
|
-
ContextGuard guard(context);
|
|
541
|
-
|
|
542
837
|
switch (src_ndim)
|
|
543
838
|
{
|
|
544
839
|
case 1:
|
|
@@ -547,13 +842,6 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
|
|
|
547
842
|
dst_strides[0], src_strides[0],
|
|
548
843
|
dst_indices[0], src_indices[0],
|
|
549
844
|
src_shape[0], elem_size));
|
|
550
|
-
if (has_grad)
|
|
551
|
-
{
|
|
552
|
-
wp_launch_device(WP_CURRENT_CONTEXT, array_copy_1d_kernel, n, (dst_grad, src_grad,
|
|
553
|
-
dst_strides[0], src_strides[0],
|
|
554
|
-
dst_indices[0], src_indices[0],
|
|
555
|
-
src_shape[0], elem_size));
|
|
556
|
-
}
|
|
557
845
|
break;
|
|
558
846
|
}
|
|
559
847
|
case 2:
|
|
@@ -568,13 +856,6 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
|
|
|
568
856
|
dst_strides_v, src_strides_v,
|
|
569
857
|
dst_indices_v, src_indices_v,
|
|
570
858
|
shape_v, elem_size));
|
|
571
|
-
if (has_grad)
|
|
572
|
-
{
|
|
573
|
-
wp_launch_device(WP_CURRENT_CONTEXT, array_copy_2d_kernel, n, (dst_grad, src_grad,
|
|
574
|
-
dst_strides_v, src_strides_v,
|
|
575
|
-
dst_indices_v, src_indices_v,
|
|
576
|
-
shape_v, elem_size));
|
|
577
|
-
}
|
|
578
859
|
break;
|
|
579
860
|
}
|
|
580
861
|
case 3:
|
|
@@ -589,13 +870,6 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
|
|
|
589
870
|
dst_strides_v, src_strides_v,
|
|
590
871
|
dst_indices_v, src_indices_v,
|
|
591
872
|
shape_v, elem_size));
|
|
592
|
-
if (has_grad)
|
|
593
|
-
{
|
|
594
|
-
wp_launch_device(WP_CURRENT_CONTEXT, array_copy_3d_kernel, n, (dst_grad, src_grad,
|
|
595
|
-
dst_strides_v, src_strides_v,
|
|
596
|
-
dst_indices_v, src_indices_v,
|
|
597
|
-
shape_v, elem_size));
|
|
598
|
-
}
|
|
599
873
|
break;
|
|
600
874
|
}
|
|
601
875
|
case 4:
|
|
@@ -610,17 +884,10 @@ WP_API size_t array_copy_device(void* context, void* dst, void* src, int dst_typ
|
|
|
610
884
|
dst_strides_v, src_strides_v,
|
|
611
885
|
dst_indices_v, src_indices_v,
|
|
612
886
|
shape_v, elem_size));
|
|
613
|
-
if (has_grad)
|
|
614
|
-
{
|
|
615
|
-
wp_launch_device(WP_CURRENT_CONTEXT, array_copy_4d_kernel, n, (dst_grad, src_grad,
|
|
616
|
-
dst_strides_v, src_strides_v,
|
|
617
|
-
dst_indices_v, src_indices_v,
|
|
618
|
-
shape_v, elem_size));
|
|
619
|
-
}
|
|
620
887
|
break;
|
|
621
888
|
}
|
|
622
889
|
default:
|
|
623
|
-
fprintf(stderr, "Warp error: invalid array dimensionality (%d)\n", src_ndim);
|
|
890
|
+
fprintf(stderr, "Warp copy error: invalid array dimensionality (%d)\n", src_ndim);
|
|
624
891
|
return 0;
|
|
625
892
|
}
|
|
626
893
|
|
|
@@ -717,6 +984,32 @@ static __global__ void array_fill_4d_kernel(void* data,
|
|
|
717
984
|
}
|
|
718
985
|
|
|
719
986
|
|
|
987
|
+
// Fill every element of a fabric array with the same value.
// One thread per element: thread `tid` writes element `tid`; threads with
// tid >= fa.size do nothing.
//   fa         - fabric array to write (passed by value)
//   value      - pointer to `value_size` bytes that every thread copies from;
//                array_fill_device passes a cudaMalloc'd buffer here
//   value_size - size in bytes of one element
static __global__ void array_fill_fabric_kernel(wp::fabricarray_t<void> fa, const void* value, int value_size)
{
    // Compute the global thread index in size_t. With `int`, the product
    // blockIdx.x * blockDim.x wraps/narrows past 2^31 threads, so elements of
    // very large arrays would be silently skipped.
    size_t tid = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
    if (tid < fa.size)
    {
        void* dst_ptr = fabricarray_element_ptr(fa, tid, value_size);
        memcpy(dst_ptr, value, value_size);
    }
}
|
|
996
|
+
|
|
997
|
+
|
|
998
|
+
// Fill the elements of an indexed fabric array with the same value.
// Thread `tid` looks up ifa.indices[tid] and writes that element of the
// underlying fabric array ifa.fa; threads with tid >= ifa.size do nothing.
//   ifa        - indexed fabric array to write (passed by value)
//   value      - pointer to `value_size` bytes that every thread copies from;
//                array_fill_device passes a cudaMalloc'd buffer here
//   value_size - size in bytes of one element
static __global__ void array_fill_fabric_indexed_kernel(wp::indexedfabricarray_t<void> ifa, const void* value, int value_size)
{
    // Compute the global thread index in size_t so that the index does not
    // wrap/narrow for launches with more than 2^31 threads.
    size_t tid = size_t(blockIdx.x) * blockDim.x + threadIdx.x;
    if (tid < ifa.size)
    {
        // An index outside the underlying fabric array is skipped rather than
        // written. If indices are signed, a negative index becomes a huge value
        // after the size_t cast and is therefore also rejected.
        size_t idx = size_t(ifa.indices[tid]);
        if (idx < ifa.fa.size)
        {
            void* dst_ptr = fabricarray_element_ptr(ifa.fa, idx, value_size);
            memcpy(dst_ptr, value, value_size);
        }
    }
}
|
|
1011
|
+
|
|
1012
|
+
|
|
720
1013
|
WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const void* value_ptr, int value_size)
|
|
721
1014
|
{
|
|
722
1015
|
if (!arr_ptr || !value_ptr)
|
|
@@ -728,6 +1021,9 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
|
|
|
728
1021
|
const int* strides = NULL;
|
|
729
1022
|
const int*const* indices = NULL;
|
|
730
1023
|
|
|
1024
|
+
wp::fabricarray_t<void>* fa = NULL;
|
|
1025
|
+
wp::indexedfabricarray_t<void>* ifa = NULL;
|
|
1026
|
+
|
|
731
1027
|
const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };
|
|
732
1028
|
|
|
733
1029
|
if (arr_type == wp::ARRAY_TYPE_REGULAR)
|
|
@@ -748,9 +1044,17 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
|
|
|
748
1044
|
strides = ia.arr.strides;
|
|
749
1045
|
indices = ia.indices;
|
|
750
1046
|
}
|
|
1047
|
+
else if (arr_type == wp::ARRAY_TYPE_FABRIC)
|
|
1048
|
+
{
|
|
1049
|
+
fa = static_cast<wp::fabricarray_t<void>*>(arr_ptr);
|
|
1050
|
+
}
|
|
1051
|
+
else if (arr_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
|
|
1052
|
+
{
|
|
1053
|
+
ifa = static_cast<wp::indexedfabricarray_t<void>*>(arr_ptr);
|
|
1054
|
+
}
|
|
751
1055
|
else
|
|
752
1056
|
{
|
|
753
|
-
fprintf(stderr, "Warp error: Invalid array type id %d\n", arr_type);
|
|
1057
|
+
fprintf(stderr, "Warp fill error: Invalid array type id %d\n", arr_type);
|
|
754
1058
|
return;
|
|
755
1059
|
}
|
|
756
1060
|
|
|
@@ -765,6 +1069,21 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
|
|
|
765
1069
|
check_cuda(cudaMalloc(&value_devptr, value_size));
|
|
766
1070
|
check_cuda(cudaMemcpyAsync(value_devptr, value_ptr, value_size, cudaMemcpyHostToDevice, get_current_stream()));
|
|
767
1071
|
|
|
1072
|
+
// handle fabric arrays
|
|
1073
|
+
if (fa)
|
|
1074
|
+
{
|
|
1075
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_kernel, n,
|
|
1076
|
+
(*fa, value_devptr, value_size));
|
|
1077
|
+
return;
|
|
1078
|
+
}
|
|
1079
|
+
else if (ifa)
|
|
1080
|
+
{
|
|
1081
|
+
wp_launch_device(WP_CURRENT_CONTEXT, array_fill_fabric_indexed_kernel, n,
|
|
1082
|
+
(*ifa, value_devptr, value_size));
|
|
1083
|
+
return;
|
|
1084
|
+
}
|
|
1085
|
+
|
|
1086
|
+
// handle regular or indexed arrays
|
|
768
1087
|
switch (ndim)
|
|
769
1088
|
{
|
|
770
1089
|
case 1:
|
|
@@ -801,7 +1120,7 @@ WP_API void array_fill_device(void* context, void* arr_ptr, int arr_type, const
|
|
|
801
1120
|
break;
|
|
802
1121
|
}
|
|
803
1122
|
default:
|
|
804
|
-
fprintf(stderr, "Warp error: invalid array dimensionality (%d)\n", ndim);
|
|
1123
|
+
fprintf(stderr, "Warp fill error: invalid array dimensionality (%d)\n", ndim);
|
|
805
1124
|
return;
|
|
806
1125
|
}
|
|
807
1126
|
}
|
|
@@ -830,6 +1149,11 @@ int cuda_toolkit_version()
|
|
|
830
1149
|
return CUDA_VERSION;
|
|
831
1150
|
}
|
|
832
1151
|
|
|
1152
|
+
// Thin wrapper: report the result of is_cuda_driver_initialized().
bool cuda_driver_is_initialized()
{
    const bool driver_ready = is_cuda_driver_initialized();
    return driver_ready;
}
|
|
1156
|
+
|
|
833
1157
|
int nvrtc_supported_arch_count()
|
|
834
1158
|
{
|
|
835
1159
|
int count;
|
|
@@ -884,6 +1208,32 @@ int cuda_device_get_arch(int ordinal)
|
|
|
884
1208
|
return 0;
|
|
885
1209
|
}
|
|
886
1210
|
|
|
1211
|
+
// Copy the 16-byte UUID recorded for device `ordinal` in g_devices into `uuid`.
// The ordinal is range-checked like the other cuda_device_get_* accessors.
// For an out-of-range ordinal, `uuid` is zero-filled instead of reading past
// the end of g_devices.
void cuda_device_get_uuid(int ordinal, char uuid[16])
{
    if (ordinal >= 0 && ordinal < int(g_devices.size()))
        memcpy(uuid, g_devices[ordinal].uuid.bytes, sizeof(char)*16);
    else
        memset(uuid, 0, sizeof(char)*16);
}
|
|
1215
|
+
|
|
1216
|
+
// PCI domain id recorded in g_devices for device `ordinal`, or -1 when the
// ordinal does not name a known device.
int cuda_device_get_pci_domain_id(int ordinal)
{
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
        return -1;

    return g_devices[ordinal].pci_domain_id;
}
|
|
1222
|
+
|
|
1223
|
+
// PCI bus id recorded in g_devices for device `ordinal`, or -1 when the
// ordinal does not name a known device.
int cuda_device_get_pci_bus_id(int ordinal)
{
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
        return -1;

    return g_devices[ordinal].pci_bus_id;
}
|
|
1229
|
+
|
|
1230
|
+
// PCI device id recorded in g_devices for device `ordinal`, or -1 when the
// ordinal does not name a known device.
int cuda_device_get_pci_device_id(int ordinal)
{
    if (ordinal < 0 || ordinal >= int(g_devices.size()))
        return -1;

    return g_devices[ordinal].pci_device_id;
}
|
|
1236
|
+
|
|
887
1237
|
int cuda_device_is_uva(int ordinal)
|
|
888
1238
|
{
|
|
889
1239
|
if (ordinal >= 0 && ordinal < int(g_devices.size()))
|
|
@@ -891,6 +1241,13 @@ int cuda_device_is_uva(int ordinal)
|
|
|
891
1241
|
return 0;
|
|
892
1242
|
}
|
|
893
1243
|
|
|
1244
|
+
// Return the is_memory_pool_supported flag recorded in g_devices for device
// `ordinal` (as an int); out-of-range ordinals return 0.
int cuda_device_is_memory_pool_supported(int ordinal)
{
    if (ordinal >= 0 && ordinal < int(g_devices.size()))
        return g_devices[ordinal].is_memory_pool_supported;

    // Return an int (not `false`) to match the declared return type and the
    // other cuda_device_is_* queries.
    return 0;
}
|
|
1250
|
+
|
|
894
1251
|
void* cuda_context_get_current()
|
|
895
1252
|
{
|
|
896
1253
|
return get_current_context();
|
|
@@ -999,6 +1356,16 @@ int cuda_context_is_primary(void* context)
|
|
|
999
1356
|
return 0;
|
|
1000
1357
|
}
|
|
1001
1358
|
|
|
1359
|
+
// Memory-pool support for the device that owns `context`. The context is
// resolved to a device ordinal first; if that lookup yields -1, this returns 0.
int cuda_context_is_memory_pool_supported(void* context)
{
    const int device_ordinal = cuda_context_get_device_ordinal(context);
    if (device_ordinal == -1)
        return 0;

    return cuda_device_is_memory_pool_supported(device_ordinal);
}
|
|
1368
|
+
|
|
1002
1369
|
void* cuda_context_get_stream(void* context)
|
|
1003
1370
|
{
|
|
1004
1371
|
ContextInfo* info = get_context_info(static_cast<CUcontext>(context));
|
|
@@ -1208,10 +1575,10 @@ void* cuda_graph_end_capture(void* context)
|
|
|
1208
1575
|
//cudaGraphDebugDotPrint(graph, "graph.dot", cudaGraphDebugDotFlagsVerbose);
|
|
1209
1576
|
|
|
1210
1577
|
cudaGraphExec_t graph_exec = NULL;
|
|
1211
|
-
check_cuda(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0));
|
|
1578
|
+
//check_cuda(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0));
|
|
1212
1579
|
|
|
1213
1580
|
// can use after CUDA 11.4 to permit graphs to capture cudaMallocAsync() operations
|
|
1214
|
-
|
|
1581
|
+
check_cuda(cudaGraphInstantiateWithFlags(&graph_exec, graph, cudaGraphInstantiateFlagAutoFreeOnLaunch));
|
|
1215
1582
|
|
|
1216
1583
|
// free source graph
|
|
1217
1584
|
check_cuda(cudaGraphDestroy(graph));
|
|
@@ -1513,14 +1880,34 @@ void* cuda_get_kernel(void* context, void* module, const char* name)
|
|
|
1513
1880
|
return kernel;
|
|
1514
1881
|
}
|
|
1515
1882
|
|
|
1516
|
-
size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, void** args)
|
|
1883
|
+
size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args)
|
|
1517
1884
|
{
|
|
1518
1885
|
ContextGuard guard(context);
|
|
1519
1886
|
|
|
1520
1887
|
const int block_dim = 256;
|
|
1521
1888
|
// CUDA specs up to compute capability 9.0 says the max x-dim grid is 2**31-1, so
|
|
1522
1889
|
// grid_dim is fine as an int for the near future
|
|
1523
|
-
|
|
1890
|
+
int grid_dim = (dim + block_dim - 1)/block_dim;
|
|
1891
|
+
|
|
1892
|
+
if (max_blocks <= 0) {
|
|
1893
|
+
max_blocks = 2147483647;
|
|
1894
|
+
}
|
|
1895
|
+
|
|
1896
|
+
if (grid_dim < 0)
|
|
1897
|
+
{
|
|
1898
|
+
#if defined(_DEBUG)
|
|
1899
|
+
fprintf(stderr, "Warp warning: Overflow in grid dimensions detected for %zu total elements and 256 threads "
|
|
1900
|
+
"per block.\n Setting block count to %d.\n", dim, max_blocks);
|
|
1901
|
+
#endif
|
|
1902
|
+
grid_dim = max_blocks;
|
|
1903
|
+
}
|
|
1904
|
+
else
|
|
1905
|
+
{
|
|
1906
|
+
if (grid_dim > max_blocks)
|
|
1907
|
+
{
|
|
1908
|
+
grid_dim = max_blocks;
|
|
1909
|
+
}
|
|
1910
|
+
}
|
|
1524
1911
|
|
|
1525
1912
|
CUresult res = cuLaunchKernel_f(
|
|
1526
1913
|
(CUfunction)kernel,
|