warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +15 -7
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +22 -443
- warp/build_dll.py +384 -0
- warp/builtins.py +998 -488
- warp/codegen.py +1307 -739
- warp/config.py +5 -3
- warp/constants.py +6 -0
- warp/context.py +1291 -548
- warp/dlpack.py +31 -31
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +164 -55
- warp/native/builtin.h +150 -174
- warp/native/bvh.cpp +75 -328
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +37 -45
- warp/native/clang/clang.cpp +136 -24
- warp/native/crt.cpp +1 -76
- warp/native/crt.h +111 -104
- warp/native/cuda_crt.h +1049 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -949
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/initializer_array.h +2 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +13 -16
- warp/native/marching.cu +157 -161
- warp/native/mat.h +119 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +108 -83
- warp/native/mesh.cu +243 -6
- warp/native/mesh.h +1547 -458
- warp/native/nanovdb/NanoVDB.h +1 -1
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +157 -0
- warp/native/reduce.cu +348 -0
- warp/native/runlength_encode.cpp +62 -0
- warp/native/runlength_encode.cu +46 -0
- warp/native/scan.cu +11 -13
- warp/native/scan.h +1 -0
- warp/native/solid_angle.h +442 -0
- warp/native/sort.cpp +13 -0
- warp/native/sort.cu +9 -1
- warp/native/sparse.cpp +338 -0
- warp/native/sparse.cu +545 -0
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +30 -0
- warp/native/vec.h +126 -24
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +658 -53
- warp/native/warp.cu +660 -68
- warp/native/warp.h +112 -12
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +392 -152
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +21 -8
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +161 -19
- warp/sim/model.py +795 -291
- warp/sim/optimizer.py +2 -6
- warp/sim/render.py +65 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +1227 -0
- warp/stubs.py +665 -223
- warp/tape.py +66 -15
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +1497 -211
- warp/tests/test_array_reduce.py +150 -0
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +75 -43
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +233 -128
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +136 -108
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -74
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +180 -116
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +577 -24
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +251 -15
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +508 -2778
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +325 -34
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +190 -0
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +460 -0
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +331 -85
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -1987
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +72 -30
- warp/types.py +1744 -713
- warp/utils.py +360 -350
- warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/bin/warp-clang.exp +0 -0
- warp/bin/warp-clang.lib +0 -0
- warp/bin/warp.exp +0 -0
- warp/bin/warp.lib +0 -0
- warp/tests/test_all.py +0 -215
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.9.0.dist-info/METADATA +0 -20
- warp_lang-0.9.0.dist-info/RECORD +0 -177
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cpp
CHANGED
|
@@ -10,33 +10,99 @@
|
|
|
10
10
|
#include "scan.h"
|
|
11
11
|
#include "array.h"
|
|
12
12
|
|
|
13
|
+
#include "exports.h"
|
|
14
|
+
|
|
13
15
|
#include "stdlib.h"
|
|
14
16
|
#include "string.h"
|
|
15
17
|
|
|
18
|
+
int cuda_init();
|
|
16
19
|
|
|
17
|
-
namespace wp
|
|
18
|
-
{
|
|
19
20
|
|
|
20
|
-
|
|
21
|
+
uint16_t float_to_half_bits(float x)
|
|
21
22
|
{
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
23
|
+
// adapted from Fabien Giesen's post: https://gist.github.com/rygorous/2156668
|
|
24
|
+
union fp32
|
|
25
|
+
{
|
|
26
|
+
uint32_t u;
|
|
27
|
+
float f;
|
|
26
28
|
|
|
27
|
-
|
|
29
|
+
struct
|
|
30
|
+
{
|
|
31
|
+
unsigned int mantissa : 23;
|
|
32
|
+
unsigned int exponent : 8;
|
|
33
|
+
unsigned int sign : 1;
|
|
34
|
+
};
|
|
35
|
+
};
|
|
36
|
+
|
|
37
|
+
fp32 f;
|
|
38
|
+
f.f = x;
|
|
39
|
+
|
|
40
|
+
fp32 f32infty = { 255 << 23 };
|
|
41
|
+
fp32 f16infty = { 31 << 23 };
|
|
42
|
+
fp32 magic = { 15 << 23 };
|
|
43
|
+
uint32_t sign_mask = 0x80000000u;
|
|
44
|
+
uint32_t round_mask = ~0xfffu;
|
|
45
|
+
uint16_t u;
|
|
46
|
+
|
|
47
|
+
uint32_t sign = f.u & sign_mask;
|
|
48
|
+
f.u ^= sign;
|
|
49
|
+
|
|
50
|
+
// NOTE all the integer compares in this function can be safely
|
|
51
|
+
// compiled into signed compares since all operands are below
|
|
52
|
+
// 0x80000000. Important if you want fast straight SSE2 code
|
|
53
|
+
// (since there's no unsigned PCMPGTD).
|
|
54
|
+
|
|
55
|
+
if (f.u >= f32infty.u) // Inf or NaN (all exponent bits set)
|
|
56
|
+
u = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
|
|
57
|
+
else // (De)normalized number or zero
|
|
58
|
+
{
|
|
59
|
+
f.u &= round_mask;
|
|
60
|
+
f.f *= magic.f;
|
|
61
|
+
f.u -= round_mask;
|
|
62
|
+
if (f.u > f16infty.u) f.u = f16infty.u; // Clamp to signed infinity if overflowed
|
|
28
63
|
|
|
64
|
+
u = f.u >> 13; // Take the bits!
|
|
65
|
+
}
|
|
29
66
|
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
return wp::half(x).u;
|
|
67
|
+
u |= sign >> 16;
|
|
68
|
+
return u;
|
|
33
69
|
}
|
|
34
70
|
|
|
35
71
|
float half_bits_to_float(uint16_t u)
|
|
36
72
|
{
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
73
|
+
// adapted from Fabien Giesen's post: https://gist.github.com/rygorous/2156668
|
|
74
|
+
union fp32
|
|
75
|
+
{
|
|
76
|
+
uint32_t u;
|
|
77
|
+
float f;
|
|
78
|
+
|
|
79
|
+
struct
|
|
80
|
+
{
|
|
81
|
+
unsigned int mantissa : 23;
|
|
82
|
+
unsigned int exponent : 8;
|
|
83
|
+
unsigned int sign : 1;
|
|
84
|
+
};
|
|
85
|
+
};
|
|
86
|
+
|
|
87
|
+
static const fp32 magic = { 113 << 23 };
|
|
88
|
+
static const uint32_t shifted_exp = 0x7c00 << 13; // exponent mask after shift
|
|
89
|
+
fp32 o;
|
|
90
|
+
|
|
91
|
+
o.u = (u & 0x7fff) << 13; // exponent/mantissa bits
|
|
92
|
+
uint32_t exp = shifted_exp & o.u; // just the exponent
|
|
93
|
+
o.u += (127 - 15) << 23; // exponent adjust
|
|
94
|
+
|
|
95
|
+
// handle exponent special cases
|
|
96
|
+
if (exp == shifted_exp) // Inf/NaN?
|
|
97
|
+
o.u += (128 - 16) << 23; // extra exp adjust
|
|
98
|
+
else if (exp == 0) // Zero/Denormal?
|
|
99
|
+
{
|
|
100
|
+
o.u += 1 << 23; // extra exp adjust
|
|
101
|
+
o.f -= magic.f; // renormalize
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
o.u |= (u & 0x8000) << 16; // sign bit
|
|
105
|
+
return o.f;
|
|
40
106
|
}
|
|
41
107
|
|
|
42
108
|
int init()
|
|
@@ -102,34 +168,38 @@ void memset_host(void* dest, int value, size_t n)
|
|
|
102
168
|
}
|
|
103
169
|
}
|
|
104
170
|
|
|
105
|
-
|
|
171
|
+
// fill memory buffer with a value: this is a faster memtile variant
|
|
172
|
+
// for types bigger than one byte, but requires proper alignment of dst
|
|
173
|
+
template <typename T>
|
|
174
|
+
void memtile_value_host(T* dst, T value, size_t n)
|
|
106
175
|
{
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
memcpy(dest,src,srcsize);
|
|
110
|
-
dest = (char*)dest + srcsize;
|
|
111
|
-
}
|
|
176
|
+
while (n--)
|
|
177
|
+
*dst++ = value;
|
|
112
178
|
}
|
|
113
179
|
|
|
114
|
-
void
|
|
180
|
+
void memtile_host(void* dst, const void* src, size_t srcsize, size_t n)
|
|
115
181
|
{
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
float* ptr_out = (float*)(out);
|
|
119
|
-
|
|
120
|
-
*ptr_out = 0.0f;
|
|
121
|
-
for (int i=0; i < len; ++i)
|
|
122
|
-
*ptr_out += ptr_a[i]*ptr_b[i];
|
|
123
|
-
}
|
|
182
|
+
size_t dst_addr = reinterpret_cast<size_t>(dst);
|
|
183
|
+
size_t src_addr = reinterpret_cast<size_t>(src);
|
|
124
184
|
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
185
|
+
// try memtile_value first because it should be faster, but we need to ensure proper alignment
|
|
186
|
+
if (srcsize == 8 && (dst_addr & 7) == 0 && (src_addr & 7) == 0)
|
|
187
|
+
memtile_value_host(reinterpret_cast<int64_t*>(dst), *reinterpret_cast<const int64_t*>(src), n);
|
|
188
|
+
else if (srcsize == 4 && (dst_addr & 3) == 0 && (src_addr & 3) == 0)
|
|
189
|
+
memtile_value_host(reinterpret_cast<int32_t*>(dst), *reinterpret_cast<const int32_t*>(src), n);
|
|
190
|
+
else if (srcsize == 2 && (dst_addr & 1) == 0 && (src_addr & 1) == 0)
|
|
191
|
+
memtile_value_host(reinterpret_cast<int16_t*>(dst), *reinterpret_cast<const int16_t*>(src), n);
|
|
192
|
+
else if (srcsize == 1)
|
|
193
|
+
memset(dst, *reinterpret_cast<const int8_t*>(src), n);
|
|
194
|
+
else
|
|
195
|
+
{
|
|
196
|
+
// generic version
|
|
197
|
+
while (n--)
|
|
198
|
+
{
|
|
199
|
+
memcpy(dst, src, srcsize);
|
|
200
|
+
dst = (int8_t*)dst + srcsize;
|
|
201
|
+
}
|
|
202
|
+
}
|
|
133
203
|
}
|
|
134
204
|
|
|
135
205
|
void array_scan_int_host(uint64_t in, uint64_t out, int len, bool inclusive)
|
|
@@ -175,6 +245,312 @@ static void array_copy_nd(void* dst, const void* src,
|
|
|
175
245
|
}
|
|
176
246
|
|
|
177
247
|
|
|
248
|
+
static void array_copy_to_fabric(wp::fabricarray_t<void>& dst, const void* src_data,
|
|
249
|
+
int src_stride, const int* src_indices, int elem_size)
|
|
250
|
+
{
|
|
251
|
+
const int8_t* src_ptr = static_cast<const int8_t*>(src_data);
|
|
252
|
+
|
|
253
|
+
if (src_indices)
|
|
254
|
+
{
|
|
255
|
+
// copy from indexed array
|
|
256
|
+
for (size_t i = 0; i < dst.nbuckets; i++)
|
|
257
|
+
{
|
|
258
|
+
const wp::fabricbucket_t& bucket = dst.buckets[i];
|
|
259
|
+
int8_t* dst_ptr = static_cast<int8_t*>(bucket.ptr);
|
|
260
|
+
size_t bucket_size = bucket.index_end - bucket.index_start;
|
|
261
|
+
for (size_t j = 0; j < bucket_size; j++)
|
|
262
|
+
{
|
|
263
|
+
int idx = *src_indices;
|
|
264
|
+
memcpy(dst_ptr, src_ptr + idx * elem_size, elem_size);
|
|
265
|
+
dst_ptr += elem_size;
|
|
266
|
+
++src_indices;
|
|
267
|
+
}
|
|
268
|
+
}
|
|
269
|
+
}
|
|
270
|
+
else
|
|
271
|
+
{
|
|
272
|
+
if (src_stride == elem_size)
|
|
273
|
+
{
|
|
274
|
+
// copy from contiguous array
|
|
275
|
+
for (size_t i = 0; i < dst.nbuckets; i++)
|
|
276
|
+
{
|
|
277
|
+
const wp::fabricbucket_t& bucket = dst.buckets[i];
|
|
278
|
+
size_t num_bytes = (bucket.index_end - bucket.index_start) * elem_size;
|
|
279
|
+
memcpy(bucket.ptr, src_ptr, num_bytes);
|
|
280
|
+
src_ptr += num_bytes;
|
|
281
|
+
}
|
|
282
|
+
}
|
|
283
|
+
else
|
|
284
|
+
{
|
|
285
|
+
// copy from strided array
|
|
286
|
+
for (size_t i = 0; i < dst.nbuckets; i++)
|
|
287
|
+
{
|
|
288
|
+
const wp::fabricbucket_t& bucket = dst.buckets[i];
|
|
289
|
+
int8_t* dst_ptr = static_cast<int8_t*>(bucket.ptr);
|
|
290
|
+
size_t bucket_size = bucket.index_end - bucket.index_start;
|
|
291
|
+
for (size_t j = 0; j < bucket_size; j++)
|
|
292
|
+
{
|
|
293
|
+
memcpy(dst_ptr, src_ptr, elem_size);
|
|
294
|
+
src_ptr += src_stride;
|
|
295
|
+
dst_ptr += elem_size;
|
|
296
|
+
}
|
|
297
|
+
}
|
|
298
|
+
}
|
|
299
|
+
}
|
|
300
|
+
}
|
|
301
|
+
|
|
302
|
+
static void array_copy_from_fabric(const wp::fabricarray_t<void>& src, void* dst_data,
|
|
303
|
+
int dst_stride, const int* dst_indices, int elem_size)
|
|
304
|
+
{
|
|
305
|
+
int8_t* dst_ptr = static_cast<int8_t*>(dst_data);
|
|
306
|
+
|
|
307
|
+
if (dst_indices)
|
|
308
|
+
{
|
|
309
|
+
// copy to indexed array
|
|
310
|
+
for (size_t i = 0; i < src.nbuckets; i++)
|
|
311
|
+
{
|
|
312
|
+
const wp::fabricbucket_t& bucket = src.buckets[i];
|
|
313
|
+
const int8_t* src_ptr = static_cast<const int8_t*>(bucket.ptr);
|
|
314
|
+
size_t bucket_size = bucket.index_end - bucket.index_start;
|
|
315
|
+
for (size_t j = 0; j < bucket_size; j++)
|
|
316
|
+
{
|
|
317
|
+
int idx = *dst_indices;
|
|
318
|
+
memcpy(dst_ptr + idx * elem_size, src_ptr, elem_size);
|
|
319
|
+
src_ptr += elem_size;
|
|
320
|
+
++dst_indices;
|
|
321
|
+
}
|
|
322
|
+
}
|
|
323
|
+
}
|
|
324
|
+
else
|
|
325
|
+
{
|
|
326
|
+
if (dst_stride == elem_size)
|
|
327
|
+
{
|
|
328
|
+
// copy to contiguous array
|
|
329
|
+
for (size_t i = 0; i < src.nbuckets; i++)
|
|
330
|
+
{
|
|
331
|
+
const wp::fabricbucket_t& bucket = src.buckets[i];
|
|
332
|
+
size_t num_bytes = (bucket.index_end - bucket.index_start) * elem_size;
|
|
333
|
+
memcpy(dst_ptr, bucket.ptr, num_bytes);
|
|
334
|
+
dst_ptr += num_bytes;
|
|
335
|
+
}
|
|
336
|
+
}
|
|
337
|
+
else
|
|
338
|
+
{
|
|
339
|
+
// copy to strided array
|
|
340
|
+
for (size_t i = 0; i < src.nbuckets; i++)
|
|
341
|
+
{
|
|
342
|
+
const wp::fabricbucket_t& bucket = src.buckets[i];
|
|
343
|
+
const int8_t* src_ptr = static_cast<const int8_t*>(bucket.ptr);
|
|
344
|
+
size_t bucket_size = bucket.index_end - bucket.index_start;
|
|
345
|
+
for (size_t j = 0; j < bucket_size; j++)
|
|
346
|
+
{
|
|
347
|
+
memcpy(dst_ptr, src_ptr, elem_size);
|
|
348
|
+
dst_ptr += dst_stride;
|
|
349
|
+
src_ptr += elem_size;
|
|
350
|
+
}
|
|
351
|
+
}
|
|
352
|
+
}
|
|
353
|
+
}
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
static void array_copy_fabric_to_fabric(wp::fabricarray_t<void>& dst, const wp::fabricarray_t<void>& src, int elem_size)
|
|
357
|
+
{
|
|
358
|
+
wp::fabricbucket_t* dst_bucket = dst.buckets;
|
|
359
|
+
const wp::fabricbucket_t* src_bucket = src.buckets;
|
|
360
|
+
int8_t* dst_ptr = static_cast<int8_t*>(dst_bucket->ptr);
|
|
361
|
+
const int8_t* src_ptr = static_cast<const int8_t*>(src_bucket->ptr);
|
|
362
|
+
size_t dst_remaining = dst_bucket->index_end - dst_bucket->index_start;
|
|
363
|
+
size_t src_remaining = src_bucket->index_end - src_bucket->index_start;
|
|
364
|
+
size_t total_copied = 0;
|
|
365
|
+
|
|
366
|
+
while (total_copied < dst.size)
|
|
367
|
+
{
|
|
368
|
+
if (dst_remaining <= src_remaining)
|
|
369
|
+
{
|
|
370
|
+
// copy to destination bucket
|
|
371
|
+
size_t num_elems = dst_remaining;
|
|
372
|
+
size_t num_bytes = num_elems * elem_size;
|
|
373
|
+
memcpy(dst_ptr, src_ptr, num_bytes);
|
|
374
|
+
|
|
375
|
+
// advance to next destination bucket
|
|
376
|
+
++dst_bucket;
|
|
377
|
+
dst_ptr = static_cast<int8_t*>(dst_bucket->ptr);
|
|
378
|
+
dst_remaining = dst_bucket->index_end - dst_bucket->index_start;
|
|
379
|
+
|
|
380
|
+
// advance source offset
|
|
381
|
+
src_ptr += num_bytes;
|
|
382
|
+
src_remaining -= num_elems;
|
|
383
|
+
|
|
384
|
+
total_copied += num_elems;
|
|
385
|
+
}
|
|
386
|
+
else
|
|
387
|
+
{
|
|
388
|
+
// copy to destination bucket
|
|
389
|
+
size_t num_elems = src_remaining;
|
|
390
|
+
size_t num_bytes = num_elems * elem_size;
|
|
391
|
+
memcpy(dst_ptr, src_ptr, num_bytes);
|
|
392
|
+
|
|
393
|
+
// advance to next source bucket
|
|
394
|
+
++src_bucket;
|
|
395
|
+
src_ptr = static_cast<const int8_t*>(src_bucket->ptr);
|
|
396
|
+
src_remaining = src_bucket->index_end - src_bucket->index_start;
|
|
397
|
+
|
|
398
|
+
// advance destination offset
|
|
399
|
+
dst_ptr += num_bytes;
|
|
400
|
+
dst_remaining -= num_elems;
|
|
401
|
+
|
|
402
|
+
total_copied += num_elems;
|
|
403
|
+
}
|
|
404
|
+
}
|
|
405
|
+
}
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
static void array_copy_to_fabric_indexed(wp::indexedfabricarray_t<void>& dst, const void* src_data,
|
|
409
|
+
int src_stride, const int* src_indices, int elem_size)
|
|
410
|
+
{
|
|
411
|
+
const int8_t* src_ptr = static_cast<const int8_t*>(src_data);
|
|
412
|
+
|
|
413
|
+
if (src_indices)
|
|
414
|
+
{
|
|
415
|
+
// copy from indexed array
|
|
416
|
+
for (size_t i = 0; i < dst.size; i++)
|
|
417
|
+
{
|
|
418
|
+
size_t src_idx = src_indices[i];
|
|
419
|
+
size_t dst_idx = dst.indices[i];
|
|
420
|
+
void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
|
|
421
|
+
memcpy(dst_ptr, src_ptr + dst_idx * elem_size, elem_size);
|
|
422
|
+
}
|
|
423
|
+
}
|
|
424
|
+
else
|
|
425
|
+
{
|
|
426
|
+
// copy from contiguous/strided array
|
|
427
|
+
for (size_t i = 0; i < dst.size; i++)
|
|
428
|
+
{
|
|
429
|
+
size_t dst_idx = dst.indices[i];
|
|
430
|
+
void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
|
|
431
|
+
if (dst_ptr)
|
|
432
|
+
{
|
|
433
|
+
memcpy(dst_ptr, src_ptr, elem_size);
|
|
434
|
+
src_ptr += src_stride;
|
|
435
|
+
}
|
|
436
|
+
}
|
|
437
|
+
}
|
|
438
|
+
}
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
static void array_copy_fabric_indexed_to_fabric(wp::fabricarray_t<void>& dst, const wp::indexedfabricarray_t<void>& src, int elem_size)
|
|
442
|
+
{
|
|
443
|
+
wp::fabricbucket_t* dst_bucket = dst.buckets;
|
|
444
|
+
int8_t* dst_ptr = static_cast<int8_t*>(dst_bucket->ptr);
|
|
445
|
+
int8_t* dst_end = dst_ptr + elem_size * (dst_bucket->index_end - dst_bucket->index_start);
|
|
446
|
+
|
|
447
|
+
for (size_t i = 0; i < src.size; i++)
|
|
448
|
+
{
|
|
449
|
+
size_t src_idx = src.indices[i];
|
|
450
|
+
const void* src_ptr = fabricarray_element_ptr(src.fa, src_idx, elem_size);
|
|
451
|
+
|
|
452
|
+
if (dst_ptr >= dst_end)
|
|
453
|
+
{
|
|
454
|
+
// advance to next destination bucket
|
|
455
|
+
++dst_bucket;
|
|
456
|
+
dst_ptr = static_cast<int8_t*>(dst_bucket->ptr);
|
|
457
|
+
dst_end = dst_ptr + elem_size * (dst_bucket->index_end - dst_bucket->index_start);
|
|
458
|
+
}
|
|
459
|
+
|
|
460
|
+
memcpy(dst_ptr, src_ptr, elem_size);
|
|
461
|
+
|
|
462
|
+
dst_ptr += elem_size;
|
|
463
|
+
}
|
|
464
|
+
}
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
static void array_copy_fabric_indexed_to_fabric_indexed(wp::indexedfabricarray_t<void>& dst, const wp::indexedfabricarray_t<void>& src, int elem_size)
|
|
468
|
+
{
|
|
469
|
+
for (size_t i = 0; i < src.size; i++)
|
|
470
|
+
{
|
|
471
|
+
size_t src_idx = src.indices[i];
|
|
472
|
+
size_t dst_idx = dst.indices[i];
|
|
473
|
+
|
|
474
|
+
const void* src_ptr = fabricarray_element_ptr(src.fa, src_idx, elem_size);
|
|
475
|
+
void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
|
|
476
|
+
|
|
477
|
+
memcpy(dst_ptr, src_ptr, elem_size);
|
|
478
|
+
}
|
|
479
|
+
}
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
static void array_copy_fabric_to_fabric_indexed(wp::indexedfabricarray_t<void>& dst, const wp::fabricarray_t<void>& src, int elem_size)
|
|
483
|
+
{
|
|
484
|
+
wp::fabricbucket_t* src_bucket = src.buckets;
|
|
485
|
+
const int8_t* src_ptr = static_cast<const int8_t*>(src_bucket->ptr);
|
|
486
|
+
const int8_t* src_end = src_ptr + elem_size * (src_bucket->index_end - src_bucket->index_start);
|
|
487
|
+
|
|
488
|
+
for (size_t i = 0; i < dst.size; i++)
|
|
489
|
+
{
|
|
490
|
+
size_t dst_idx = dst.indices[i];
|
|
491
|
+
void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
|
|
492
|
+
|
|
493
|
+
if (src_ptr >= src_end)
|
|
494
|
+
{
|
|
495
|
+
// advance to next source bucket
|
|
496
|
+
++src_bucket;
|
|
497
|
+
src_ptr = static_cast<int8_t*>(src_bucket->ptr);
|
|
498
|
+
src_end = src_ptr + elem_size * (src_bucket->index_end - src_bucket->index_start);
|
|
499
|
+
}
|
|
500
|
+
|
|
501
|
+
memcpy(dst_ptr, src_ptr, elem_size);
|
|
502
|
+
|
|
503
|
+
src_ptr += elem_size;
|
|
504
|
+
}
|
|
505
|
+
}
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
// Copy an indexed fabric array into a regular (strided) or indexed host array.
//   src         - source; element i is src.fa[src.indices[i]]
//   dst_data    - base pointer of the destination storage
//   dst_stride  - byte stride of the destination's (underlying) first dimension
//   dst_indices - optional index map for an indexed destination (NULL for strided/contiguous)
//   elem_size   - size in bytes of one element
// Out-of-range source indices are reported on stderr and skipped; the matching destination
// slot is left untouched and the remaining elements still land in their own slots.
static void array_copy_from_fabric_indexed(const wp::indexedfabricarray_t<void>& src, void* dst_data,
                                           int dst_stride, const int* dst_indices, int elem_size)
{
    int8_t* dst_base = static_cast<int8_t*>(dst_data);

    for (size_t i = 0; i < src.size; i++)
    {
        // Byte offset of destination element i. An indexed destination is addressed through its
        // index map using the underlying array's stride (which may differ from elem_size for
        // non-contiguous storage). The offset is computed per element in 64 bits rather than
        // accumulated, so a skipped element cannot shift the ones after it.
        int64_t dst_offset = dst_indices ? int64_t(dst_indices[i]) * dst_stride
                                         : int64_t(i) * dst_stride;

        size_t idx = src.indices[i];
        if (idx < src.fa.size)
        {
            const void* src_ptr = fabricarray_element_ptr(src.fa, idx, elem_size);
            memcpy(dst_base + dst_offset, src_ptr, elem_size);
        }
        else
        {
            fprintf(stderr, "Warp copy error: Source index %llu is out of bounds for fabric array of size %llu\n",
                    (unsigned long long)idx, (unsigned long long)src.fa.size);
        }
    }
}
|
|
552
|
+
|
|
553
|
+
|
|
178
554
|
WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type, int elem_size)
|
|
179
555
|
{
|
|
180
556
|
if (!src || !dst)
|
|
@@ -193,6 +569,12 @@ WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type,
|
|
|
193
569
|
const int*const* src_indices = NULL;
|
|
194
570
|
const int*const* dst_indices = NULL;
|
|
195
571
|
|
|
572
|
+
const wp::fabricarray_t<void>* src_fabricarray = NULL;
|
|
573
|
+
wp::fabricarray_t<void>* dst_fabricarray = NULL;
|
|
574
|
+
|
|
575
|
+
const wp::indexedfabricarray_t<void>* src_indexedfabricarray = NULL;
|
|
576
|
+
wp::indexedfabricarray_t<void>* dst_indexedfabricarray = NULL;
|
|
577
|
+
|
|
196
578
|
const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };
|
|
197
579
|
|
|
198
580
|
if (src_type == wp::ARRAY_TYPE_REGULAR)
|
|
@@ -214,9 +596,19 @@ WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type,
|
|
|
214
596
|
src_strides = src_arr.arr.strides;
|
|
215
597
|
src_indices = src_arr.indices;
|
|
216
598
|
}
|
|
599
|
+
else if (src_type == wp::ARRAY_TYPE_FABRIC)
|
|
600
|
+
{
|
|
601
|
+
src_fabricarray = static_cast<const wp::fabricarray_t<void>*>(src);
|
|
602
|
+
src_ndim = 1;
|
|
603
|
+
}
|
|
604
|
+
else if (src_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
|
|
605
|
+
{
|
|
606
|
+
src_indexedfabricarray = static_cast<const wp::indexedfabricarray_t<void>*>(src);
|
|
607
|
+
src_ndim = 1;
|
|
608
|
+
}
|
|
217
609
|
else
|
|
218
610
|
{
|
|
219
|
-
fprintf(stderr, "Warp error: Invalid array type (%d)\n", src_type);
|
|
611
|
+
fprintf(stderr, "Warp copy error: Invalid source array type (%d)\n", src_type);
|
|
220
612
|
return 0;
|
|
221
613
|
}
|
|
222
614
|
|
|
@@ -239,26 +631,134 @@ WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type,
|
|
|
239
631
|
dst_strides = dst_arr.arr.strides;
|
|
240
632
|
dst_indices = dst_arr.indices;
|
|
241
633
|
}
|
|
634
|
+
else if (dst_type == wp::ARRAY_TYPE_FABRIC)
|
|
635
|
+
{
|
|
636
|
+
dst_fabricarray = static_cast<wp::fabricarray_t<void>*>(dst);
|
|
637
|
+
dst_ndim = 1;
|
|
638
|
+
}
|
|
639
|
+
else if (dst_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
|
|
640
|
+
{
|
|
641
|
+
dst_indexedfabricarray = static_cast<wp::indexedfabricarray_t<void>*>(dst);
|
|
642
|
+
dst_ndim = 1;
|
|
643
|
+
}
|
|
242
644
|
else
|
|
243
645
|
{
|
|
244
|
-
fprintf(stderr, "Warp error: Invalid array type (%d)\n", dst_type);
|
|
646
|
+
fprintf(stderr, "Warp copy error: Invalid destination array type (%d)\n", dst_type);
|
|
245
647
|
return 0;
|
|
246
648
|
}
|
|
247
649
|
|
|
248
650
|
if (src_ndim != dst_ndim)
|
|
249
651
|
{
|
|
250
|
-
fprintf(stderr, "Warp error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
|
|
652
|
+
fprintf(stderr, "Warp copy error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
|
|
251
653
|
return 0;
|
|
252
654
|
}
|
|
253
655
|
|
|
254
|
-
|
|
255
|
-
|
|
656
|
+
// handle fabric arrays
|
|
657
|
+
if (dst_fabricarray)
|
|
658
|
+
{
|
|
659
|
+
size_t n = dst_fabricarray->size;
|
|
660
|
+
if (src_fabricarray)
|
|
661
|
+
{
|
|
662
|
+
// copy from fabric to fabric
|
|
663
|
+
if (src_fabricarray->size != n)
|
|
664
|
+
{
|
|
665
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
666
|
+
return 0;
|
|
667
|
+
}
|
|
668
|
+
array_copy_fabric_to_fabric(*dst_fabricarray, *src_fabricarray, elem_size);
|
|
669
|
+
return n;
|
|
670
|
+
}
|
|
671
|
+
else if (src_indexedfabricarray)
|
|
672
|
+
{
|
|
673
|
+
// copy from fabric indexed to fabric
|
|
674
|
+
if (src_indexedfabricarray->size != n)
|
|
675
|
+
{
|
|
676
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
677
|
+
return 0;
|
|
678
|
+
}
|
|
679
|
+
array_copy_fabric_indexed_to_fabric(*dst_fabricarray, *src_indexedfabricarray, elem_size);
|
|
680
|
+
return n;
|
|
681
|
+
}
|
|
682
|
+
else
|
|
683
|
+
{
|
|
684
|
+
// copy to fabric
|
|
685
|
+
if (size_t(src_shape[0]) != n)
|
|
686
|
+
{
|
|
687
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
688
|
+
return 0;
|
|
689
|
+
}
|
|
690
|
+
array_copy_to_fabric(*dst_fabricarray, src_data, src_strides[0], src_indices[0], elem_size);
|
|
691
|
+
return n;
|
|
692
|
+
}
|
|
693
|
+
}
|
|
694
|
+
else if (dst_indexedfabricarray)
|
|
695
|
+
{
|
|
696
|
+
size_t n = dst_indexedfabricarray->size;
|
|
697
|
+
if (src_fabricarray)
|
|
698
|
+
{
|
|
699
|
+
// copy from fabric to fabric indexed
|
|
700
|
+
if (src_fabricarray->size != n)
|
|
701
|
+
{
|
|
702
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
703
|
+
return 0;
|
|
704
|
+
}
|
|
705
|
+
array_copy_fabric_to_fabric_indexed(*dst_indexedfabricarray, *src_fabricarray, elem_size);
|
|
706
|
+
return n;
|
|
707
|
+
}
|
|
708
|
+
else if (src_indexedfabricarray)
|
|
709
|
+
{
|
|
710
|
+
// copy from fabric indexed to fabric indexed
|
|
711
|
+
if (src_indexedfabricarray->size != n)
|
|
712
|
+
{
|
|
713
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
714
|
+
return 0;
|
|
715
|
+
}
|
|
716
|
+
array_copy_fabric_indexed_to_fabric_indexed(*dst_indexedfabricarray, *src_indexedfabricarray, elem_size);
|
|
717
|
+
return n;
|
|
718
|
+
}
|
|
719
|
+
else
|
|
720
|
+
{
|
|
721
|
+
// copy to fabric indexed
|
|
722
|
+
if (size_t(src_shape[0]) != n)
|
|
723
|
+
{
|
|
724
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
725
|
+
return 0;
|
|
726
|
+
}
|
|
727
|
+
array_copy_to_fabric_indexed(*dst_indexedfabricarray, src_data, src_strides[0], src_indices[0], elem_size);
|
|
728
|
+
return n;
|
|
729
|
+
}
|
|
730
|
+
}
|
|
731
|
+
else if (src_fabricarray)
|
|
732
|
+
{
|
|
733
|
+
// copy from fabric
|
|
734
|
+
size_t n = src_fabricarray->size;
|
|
735
|
+
if (size_t(dst_shape[0]) != n)
|
|
736
|
+
{
|
|
737
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
738
|
+
return 0;
|
|
739
|
+
}
|
|
740
|
+
array_copy_from_fabric(*src_fabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size);
|
|
741
|
+
return n;
|
|
742
|
+
}
|
|
743
|
+
else if (src_indexedfabricarray)
|
|
744
|
+
{
|
|
745
|
+
// copy from fabric indexed
|
|
746
|
+
size_t n = src_indexedfabricarray->size;
|
|
747
|
+
if (size_t(dst_shape[0]) != n)
|
|
748
|
+
{
|
|
749
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
750
|
+
return 0;
|
|
751
|
+
}
|
|
752
|
+
array_copy_from_fabric_indexed(*src_indexedfabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size);
|
|
753
|
+
return n;
|
|
754
|
+
}
|
|
256
755
|
|
|
756
|
+
size_t n = 1;
|
|
257
757
|
for (int i = 0; i < src_ndim; i++)
|
|
258
758
|
{
|
|
259
759
|
if (src_shape[i] != dst_shape[i])
|
|
260
760
|
{
|
|
261
|
-
fprintf(stderr, "Warp error: Incompatible array shapes\n");
|
|
761
|
+
fprintf(stderr, "Warp copy error: Incompatible array shapes\n");
|
|
262
762
|
return 0;
|
|
263
763
|
}
|
|
264
764
|
n *= src_shape[i];
|
|
@@ -269,15 +769,111 @@ WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type,
|
|
|
269
769
|
dst_indices, src_indices,
|
|
270
770
|
src_shape, src_ndim, elem_size);
|
|
271
771
|
|
|
272
|
-
|
|
772
|
+
return n;
|
|
773
|
+
}
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
// Fill every element of a strided host array with the value_size bytes at `value`.
//   data    - base pointer of the array
//   shape   - extent of each of the ndim dimensions
//   strides - byte stride of each dimension
// Recurses one dimension at a time; a non-positive ndim is a no-op.
static void array_fill_strided(void* data, const int* shape, const int* strides, int ndim, const void* value, int value_size)
{
    if (ndim <= 0)
        return;

    char* base = (char*)data;
    for (int i = 0; i < shape[0]; i++)
    {
        // 64-bit offset: i * strides[0] can exceed INT_MAX for arrays larger than 2 GiB.
        char* p = base + int64_t(i) * strides[0];
        if (ndim == 1)
        {
            memcpy(p, value, value_size);
        }
        else
        {
            // recurse on next inner dimension
            array_fill_strided(p, shape + 1, strides + 1, ndim - 1, value, value_size);
        }
    }
}
|
|
797
|
+
|
|
798
|
+
|
|
799
|
+
// Fill every element of an indexed host array with the value_size bytes at `value`.
//   data    - base pointer of the underlying array
//   shape   - extent of each of the ndim (indexed) dimensions
//   strides - byte stride of each dimension of the underlying array
//   indices - per-dimension index maps; a NULL entry means the identity mapping
// Recurses one dimension at a time; a non-positive ndim is a no-op.
static void array_fill_indexed(void* data, const int* shape, const int* strides, const int*const* indices, int ndim, const void* value, int value_size)
{
    if (ndim <= 0)
        return;

    char* base = (char*)data;
    for (int i = 0; i < shape[0]; i++)
    {
        int idx = indices[0] ? indices[0][i] : i;
        // 64-bit offset: idx * strides[0] can exceed INT_MAX for arrays larger than 2 GiB.
        char* p = base + int64_t(idx) * strides[0];
        if (ndim == 1)
        {
            memcpy(p, value, value_size);
        }
        else
        {
            // recurse on next inner dimension
            array_fill_indexed(p, shape + 1, strides + 1, indices + 1, ndim - 1, value, value_size);
        }
    }
}
|
|
279
821
|
|
|
280
|
-
|
|
822
|
+
|
|
823
|
+
// Fill every element of a fabric array on the host.
// Each bucket holds (index_end - index_start) contiguous elements starting at bucket.ptr,
// so the value is tiled across a whole bucket with a single memtile_host call.
static void array_fill_fabric(wp::fabricarray_t<void>& fa, const void* value_ptr, int value_size)
{
    const wp::fabricbucket_t* bucket = fa.buckets;
    const wp::fabricbucket_t* const last = fa.buckets + fa.nbuckets;

    for (; bucket != last; ++bucket)
    {
        const size_t count = bucket->index_end - bucket->index_start;
        memtile_host(bucket->ptr, value_ptr, value_size, count);
    }
}
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
// Fill the elements of an indexed fabric array on the host.
// For each position j, the element ifa.fa[ifa.indices[j]] receives a copy of the value.
// Indices outside the underlying fabric array are silently skipped.
static void array_fill_fabric_indexed(wp::indexedfabricarray_t<void>& ifa, const void* value_ptr, int value_size)
{
    for (size_t j = 0; j < ifa.size; ++j)
    {
        const size_t element = size_t(ifa.indices[j]);
        if (element >= ifa.fa.size)
            continue;

        memcpy(fabricarray_element_ptr(ifa.fa, element, value_size), value_ptr, value_size);
    }
}
|
|
846
|
+
|
|
847
|
+
|
|
848
|
+
// Fill every element of a host array with the value_size bytes at value_ptr.
// arr_type selects how arr_ptr is interpreted: a regular, indexed, fabric,
// or indexed fabric array descriptor. Null pointers make this a no-op;
// an unrecognized arr_type is reported on stderr.
WP_API void array_fill_host(void* arr_ptr, int arr_type, const void* value_ptr, int value_size)
{
    if (arr_ptr == NULL || value_ptr == NULL)
        return;

    if (arr_type == wp::ARRAY_TYPE_REGULAR)
    {
        wp::array_t<void>* arr = static_cast<wp::array_t<void>*>(arr_ptr);
        array_fill_strided(arr->data, arr->shape.dims, arr->strides, arr->ndim, value_ptr, value_size);
        return;
    }

    if (arr_type == wp::ARRAY_TYPE_INDEXED)
    {
        // the indexed shape drives iteration; strides come from the underlying array
        wp::indexedarray_t<void>* ia = static_cast<wp::indexedarray_t<void>*>(arr_ptr);
        array_fill_indexed(ia->arr.data, ia->shape.dims, ia->arr.strides, ia->indices, ia->arr.ndim, value_ptr, value_size);
        return;
    }

    if (arr_type == wp::ARRAY_TYPE_FABRIC)
    {
        array_fill_fabric(*static_cast<wp::fabricarray_t<void>*>(arr_ptr), value_ptr, value_size);
        return;
    }

    if (arr_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
    {
        array_fill_fabric_indexed(*static_cast<wp::indexedfabricarray_t<void>*>(arr_ptr), value_ptr, value_size);
        return;
    }

    fprintf(stderr, "Warp fill error: Invalid array type id %d\n", arr_type);
}
|
|
282
878
|
|
|
283
879
|
|
|
@@ -334,7 +930,7 @@ void memset_device(void* context, void* dest, int value, size_t n)
|
|
|
334
930
|
{
|
|
335
931
|
}
|
|
336
932
|
|
|
337
|
-
void memtile_device(void* context, void* dest, void
|
|
933
|
+
// Device-side tiling stub: intentionally does nothing.
// NOTE(review): presumably compiled only when CUDA support is unavailable; confirm against the enclosing #if.
void memtile_device(void* context, void* dest, const void* src, size_t srcsize, size_t n)
{
    // no-op
}
|
|
340
936
|
|
|
@@ -343,8 +939,13 @@ size_t array_copy_device(void* context, void* dst, void* src, int dst_type, int
|
|
|
343
939
|
return 0;
|
|
344
940
|
}
|
|
345
941
|
|
|
942
|
+
// Device-side array fill stub: intentionally does nothing.
// NOTE(review): presumably compiled only when CUDA support is unavailable; confirm against the enclosing #if.
void array_fill_device(void* context, void* arr, int arr_type, const void* value, int value_size)
{
    // no-op
}
|
|
945
|
+
|
|
346
946
|
// Driver / toolkit / NVRTC query stubs: report "no version", "not initialized", and no supported archs.
WP_API int cuda_driver_version()
{
    return 0;
}

WP_API int cuda_toolkit_version()
{
    return 0;
}

WP_API bool cuda_driver_is_initialized()
{
    return false;
}

WP_API int nvrtc_supported_arch_count()
{
    return 0;
}

WP_API void nvrtc_supported_archs(int* archs)
{
    // no archs to write
}
|
|
@@ -354,7 +955,12 @@ WP_API void* cuda_device_primary_context_retain(int ordinal) { return NULL; }
|
|
|
354
955
|
// Device query stubs: every query reports "no device"
// (NULL name, zero arch, -1 PCI ids, no UVA, no memory pools).
WP_API void cuda_device_primary_context_release(int ordinal)
{
}

WP_API const char* cuda_device_get_name(int ordinal)
{
    return NULL;
}

WP_API int cuda_device_get_arch(int ordinal)
{
    return 0;
}

WP_API void cuda_device_get_uuid(int ordinal, char uuid[16])
{
    // uuid is left untouched
}

WP_API int cuda_device_get_pci_domain_id(int ordinal)
{
    return -1;
}

WP_API int cuda_device_get_pci_bus_id(int ordinal)
{
    return -1;
}

WP_API int cuda_device_get_pci_device_id(int ordinal)
{
    return -1;
}

WP_API int cuda_device_is_uva(int ordinal)
{
    return 0;
}

WP_API int cuda_device_is_memory_pool_supported()
{
    return 0;
}
|
|
358
964
|
|
|
359
965
|
// Current-context stubs: there is never a current context, and setting one is ignored.
WP_API void* cuda_context_get_current()
{
    return NULL;
}

WP_API void cuda_context_set_current(void* ctx)
{
}
|
|
@@ -366,6 +972,7 @@ WP_API void cuda_context_synchronize(void* context) {}
|
|
|
366
972
|
// Context query stubs: no errors, no device ordinal, not primary,
// no memory pools, no stream, and no peer access.
WP_API uint64_t cuda_context_check(void* context)
{
    return 0;
}

WP_API int cuda_context_get_device_ordinal(void* context)
{
    return -1;
}

WP_API int cuda_context_is_primary(void* context)
{
    return 0;
}

WP_API int cuda_context_is_memory_pool_supported(void* context)
{
    return 0;
}

WP_API void* cuda_context_get_stream(void* context)
{
    return NULL;
}

WP_API void cuda_context_set_stream(void* context, void* stream)
{
}

WP_API int cuda_context_can_access_peer(void* context, void* peer_context)
{
    return 0;
}
|
|
@@ -392,13 +999,11 @@ WP_API size_t cuda_compile_program(const char* cuda_src, int arch, const char* i
|
|
|
392
999
|
// Module / kernel stubs: loading yields no module, unloading is ignored, and no kernel is found.
WP_API void* cuda_load_module(void* context, const char* ptx)
{
    return NULL;
}

WP_API void cuda_unload_module(void* context, void* module)
{
}

WP_API void* cuda_get_kernel(void* context, void* module, const char* name)
{
    return NULL;
}
|
|
395
|
-
WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, void** args) { return 0;}
|
|
1002
|
+
// Kernel launch stub (always reports 0) and context-restore-policy stubs (setting is ignored, query reports 0).
WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args)
{
    return 0;
}

WP_API void cuda_set_context_restore_policy(bool always_restore)
{
}

WP_API int cuda_get_context_restore_policy()
{
    return 0;
}
|
|
399
1006
|
|
|
400
|
-
WP_API void array_inner_device(uint64_t a, uint64_t b, uint64_t out, int len) {}
|
|
401
|
-
WP_API void array_sum_device(uint64_t a, uint64_t out, int len) {}
|
|
402
1007
|
// Device prefix-scan stubs: intentionally do nothing; `out` is left untouched.
WP_API void array_scan_int_device(uint64_t in, uint64_t out, int len, bool inclusive)
{
}

WP_API void array_scan_float_device(uint64_t in, uint64_t out, int len, bool inclusive)
{
}
|
|
404
1009
|
|