warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +10 -4
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +5 -3
- warp/build_dll.py +29 -9
- warp/builtins.py +868 -507
- warp/codegen.py +1074 -638
- warp/config.py +3 -3
- warp/constants.py +6 -0
- warp/context.py +715 -222
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +147 -44
- warp/native/builtin.h +122 -149
- warp/native/bvh.cpp +73 -325
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +34 -43
- warp/native/clang/clang.cpp +13 -8
- warp/native/crt.h +2 -0
- warp/native/cuda_crt.h +5 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -952
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +1 -1
- warp/native/marching.cu +157 -161
- warp/native/mat.h +80 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +33 -108
- warp/native/mesh.cu +114 -23
- warp/native/mesh.h +446 -46
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +1 -1
- warp/native/reduce.cu +10 -12
- warp/native/runlength_encode.cu +6 -10
- warp/native/scan.cu +8 -11
- warp/native/sparse.cpp +4 -4
- warp/native/sparse.cu +164 -154
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +14 -30
- warp/native/vec.h +107 -23
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +560 -30
- warp/native/warp.cu +431 -44
- warp/native/warp.h +13 -4
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +335 -119
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +8 -0
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +158 -16
- warp/sim/model.py +795 -291
- warp/sim/render.py +3 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +640 -150
- warp/stubs.py +606 -267
- warp/tape.py +61 -10
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +212 -97
- warp/tests/test_array_reduce.py +57 -23
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +42 -18
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +208 -130
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +75 -75
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -86
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +178 -109
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +32 -31
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +14 -41
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +517 -2898
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +140 -22
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +168 -20
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +261 -63
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +268 -63
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -2050
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +9 -6
- warp/types.py +1089 -366
- warp/utils.py +93 -387
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/tests/test_all.py +0 -219
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.10.1.dist-info/METADATA +0 -21
- warp_lang-0.10.1.dist-info/RECORD +0 -188
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/warp.cpp
CHANGED
|
@@ -10,33 +10,99 @@
|
|
|
10
10
|
#include "scan.h"
|
|
11
11
|
#include "array.h"
|
|
12
12
|
|
|
13
|
+
#include "exports.h"
|
|
14
|
+
|
|
13
15
|
#include "stdlib.h"
|
|
14
16
|
#include "string.h"
|
|
15
17
|
|
|
18
|
+
int cuda_init();
|
|
16
19
|
|
|
17
|
-
namespace wp
|
|
18
|
-
{
|
|
19
20
|
|
|
20
|
-
|
|
21
|
+
// Convert a 32-bit float to the bit pattern of an IEEE 754 binary16 (half) value.
//
// Adapted from Fabian Giesen's float_to_half_fast3:
// https://gist.github.com/rygorous/2156668
//
//  - NaN inputs map to the quiet NaN 0x7e00, +/-Inf maps to +/-Inf (0x7c00 | sign).
//  - Finite values too large for half precision are clamped to +/-Inf.
//  - The sign bit of the input is preserved (so -0.0f -> 0x8000).
//
// Bits are reinterpreted with memcpy rather than through a union: reading an
// inactive union member is undefined behaviour in ISO C++, and the anonymous
// bit-field struct the union previously carried was a non-standard extension
// that was never used.
//
// NOTE all the integer compares in this function can be safely compiled into
// signed compares since all operands are below 0x80000000 (useful for SSE2,
// which has no unsigned PCMPGTD).
uint16_t float_to_half_bits(float x)
{
    const uint32_t f32infty = 255u << 23;   // float +Inf bit pattern
    const uint32_t f16infty = 31u << 23;    // half +Inf, expressed in float exponent position
    const uint32_t magic_bits = 15u << 23;  // float 2^-112: rebiases the exponent from 127 to 15
    const uint32_t sign_mask = 0x80000000u;
    const uint32_t round_mask = ~0xfffu;

    float magic;
    memcpy(&magic, &magic_bits, sizeof(magic));

    uint32_t bits;
    memcpy(&bits, &x, sizeof(bits));

    // strip the sign here and re-apply it at the end
    const uint32_t sign = bits & sign_mask;
    bits ^= sign;

    uint16_t u;
    if (bits >= f32infty) // Inf or NaN (all exponent bits set)
    {
        u = (bits > f32infty) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
    }
    else // (de)normalized number or zero
    {
        // drop the low mantissa bits, then rescale the exponent with a float
        // multiply (which also yields correctly shifted half denormals)
        bits &= round_mask;
        float f;
        memcpy(&f, &bits, sizeof(f));
        f *= magic;
        memcpy(&bits, &f, sizeof(bits));

        // subtracting ~0xfff adds 0x1000 (mod 2^32): half of one half-precision
        // ULP once shifted down by 13, used as a rounding bias
        bits -= round_mask;

        if (bits > f16infty)
            bits = f16infty; // clamp to signed infinity if overflowed

        u = static_cast<uint16_t>(bits >> 13); // take the bits!
    }

    u |= static_cast<uint16_t>(sign >> 16);
    return u;
}
|
|
34
70
|
|
|
35
71
|
// Convert the bit pattern of an IEEE 754 binary16 (half) value to a 32-bit float.
//
// Adapted from Fabian Giesen's half_to_float:
// https://gist.github.com/rygorous/2156668
//
// Normal values are rebiased directly; Inf/NaN receive an extra exponent
// adjustment so they land on the float all-ones exponent; zero/denormal
// values are renormalized by subtracting the float 2^-14 after a bias tweak.
//
// Bits are reinterpreted with memcpy rather than through a union (reading an
// inactive union member is undefined behaviour in ISO C++). The sign bit is
// widened to uint32_t before the 16-bit shift so the shift is never performed
// on a signed int.
float half_bits_to_float(uint16_t u)
{
    const uint32_t shifted_exp = 0x7c00u << 13; // half exponent mask after shift
    const uint32_t magic_bits = 113u << 23;     // float 2^-14

    uint32_t bits = static_cast<uint32_t>(u & 0x7fff) << 13; // exponent/mantissa bits
    const uint32_t exp = shifted_exp & bits;                  // just the exponent
    bits += (127u - 15u) << 23;                               // exponent adjust

    // handle exponent special cases
    if (exp == shifted_exp) // Inf/NaN
    {
        bits += (128u - 16u) << 23; // extra exp adjust
    }
    else if (exp == 0) // zero/denormal
    {
        bits += 1u << 23; // extra exp adjust

        float f;
        float magic;
        memcpy(&f, &bits, sizeof(f));
        memcpy(&magic, &magic_bits, sizeof(magic));
        f -= magic; // renormalize
        memcpy(&bits, &f, sizeof(bits));
    }

    bits |= static_cast<uint32_t>(u & 0x8000) << 16; // sign bit

    float result;
    memcpy(&result, &bits, sizeof(result));
    return result;
}
|
|
41
107
|
|
|
42
108
|
int init()
|
|
@@ -179,6 +245,312 @@ static void array_copy_nd(void* dst, const void* src,
|
|
|
179
245
|
}
|
|
180
246
|
|
|
181
247
|
|
|
248
|
+
// Copy a host array into a fabric array, filling the destination bucket by
// bucket in order.
//
//   dst         - destination fabric array (its buckets are written in order)
//   src_data    - base pointer of the source array
//   src_stride  - byte distance between consecutive source elements; only used
//                 when src_indices is NULL
//   src_indices - optional index list; when non-NULL, the k-th destination
//                 element is read from source element src_indices[k]
//   elem_size   - size of one element in bytes
static void array_copy_to_fabric(wp::fabricarray_t<void>& dst, const void* src_data,
                                 int src_stride, const int* src_indices, int elem_size)
{
    const int8_t* src_ptr = static_cast<const int8_t*>(src_data);

    if (src_indices)
    {
        // copy from indexed array
        for (size_t i = 0; i < dst.nbuckets; i++)
        {
            const wp::fabricbucket_t& bucket = dst.buckets[i];
            int8_t* dst_ptr = static_cast<int8_t*>(bucket.ptr);
            size_t bucket_size = bucket.index_end - bucket.index_start;
            for (size_t j = 0; j < bucket_size; j++)
            {
                // widen before multiplying: int * int overflows for byte offsets >= 2 GiB
                size_t src_offset = static_cast<size_t>(*src_indices) * static_cast<size_t>(elem_size);
                memcpy(dst_ptr, src_ptr + src_offset, elem_size);
                dst_ptr += elem_size;
                ++src_indices;
            }
        }
    }
    else if (src_stride == elem_size)
    {
        // copy from contiguous array: one memcpy per bucket
        for (size_t i = 0; i < dst.nbuckets; i++)
        {
            const wp::fabricbucket_t& bucket = dst.buckets[i];
            size_t num_bytes = (bucket.index_end - bucket.index_start) * elem_size;
            memcpy(bucket.ptr, src_ptr, num_bytes);
            src_ptr += num_bytes;
        }
    }
    else
    {
        // copy from strided array: one memcpy per element
        for (size_t i = 0; i < dst.nbuckets; i++)
        {
            const wp::fabricbucket_t& bucket = dst.buckets[i];
            int8_t* dst_ptr = static_cast<int8_t*>(bucket.ptr);
            size_t bucket_size = bucket.index_end - bucket.index_start;
            for (size_t j = 0; j < bucket_size; j++)
            {
                memcpy(dst_ptr, src_ptr, elem_size);
                src_ptr += src_stride;
                dst_ptr += elem_size;
            }
        }
    }
}
|
|
301
|
+
|
|
302
|
+
// Copy a fabric array into a host array, reading the source bucket by bucket
// in order.
//
//   src         - source fabric array (its buckets are read in order)
//   dst_data    - base pointer of the destination array
//   dst_stride  - byte distance between consecutive destination elements; only
//                 used when dst_indices is NULL
//   dst_indices - optional index list; when non-NULL, the k-th source element
//                 is written to destination element dst_indices[k]
//   elem_size   - size of one element in bytes
static void array_copy_from_fabric(const wp::fabricarray_t<void>& src, void* dst_data,
                                   int dst_stride, const int* dst_indices, int elem_size)
{
    int8_t* dst_ptr = static_cast<int8_t*>(dst_data);

    if (dst_indices)
    {
        // copy to indexed array
        for (size_t i = 0; i < src.nbuckets; i++)
        {
            const wp::fabricbucket_t& bucket = src.buckets[i];
            const int8_t* src_ptr = static_cast<const int8_t*>(bucket.ptr);
            size_t bucket_size = bucket.index_end - bucket.index_start;
            for (size_t j = 0; j < bucket_size; j++)
            {
                // widen before multiplying: int * int overflows for byte offsets >= 2 GiB
                size_t dst_offset = static_cast<size_t>(*dst_indices) * static_cast<size_t>(elem_size);
                memcpy(dst_ptr + dst_offset, src_ptr, elem_size);
                src_ptr += elem_size;
                ++dst_indices;
            }
        }
    }
    else if (dst_stride == elem_size)
    {
        // copy to contiguous array: one memcpy per bucket
        for (size_t i = 0; i < src.nbuckets; i++)
        {
            const wp::fabricbucket_t& bucket = src.buckets[i];
            size_t num_bytes = (bucket.index_end - bucket.index_start) * elem_size;
            memcpy(dst_ptr, bucket.ptr, num_bytes);
            dst_ptr += num_bytes;
        }
    }
    else
    {
        // copy to strided array: one memcpy per element
        for (size_t i = 0; i < src.nbuckets; i++)
        {
            const wp::fabricbucket_t& bucket = src.buckets[i];
            const int8_t* src_ptr = static_cast<const int8_t*>(bucket.ptr);
            size_t bucket_size = bucket.index_end - bucket.index_start;
            for (size_t j = 0; j < bucket_size; j++)
            {
                memcpy(dst_ptr, src_ptr, elem_size);
                dst_ptr += dst_stride;
                src_ptr += elem_size;
            }
        }
    }
}
|
|
355
|
+
|
|
356
|
+
// Copy dst.size elements from one fabric array to another whose bucket
// layouts may differ.
//
// Both bucket lists are walked in lockstep; each step copies the largest run
// that fits in both the current source and current destination bucket.
// Bucket indices are bounded by nbuckets, so the walk never dereferences a
// bucket past the end of either list (the final copy no longer "advances" into
// a nonexistent bucket), empty buckets are skipped, and arrays with no buckets
// are handled without touching buckets[0].
static void array_copy_fabric_to_fabric(wp::fabricarray_t<void>& dst, const wp::fabricarray_t<void>& src, int elem_size)
{
    size_t dst_bucket = 0;
    size_t src_bucket = 0;
    int8_t* dst_ptr = NULL;
    const int8_t* src_ptr = NULL;
    size_t dst_remaining = 0;
    size_t src_remaining = 0;
    size_t total_copied = 0;

    while (total_copied < dst.size)
    {
        // move to the next non-empty destination bucket
        while (dst_remaining == 0 && dst_bucket < dst.nbuckets)
        {
            const wp::fabricbucket_t& bucket = dst.buckets[dst_bucket++];
            dst_ptr = static_cast<int8_t*>(bucket.ptr);
            dst_remaining = bucket.index_end - bucket.index_start;
        }

        // move to the next non-empty source bucket
        while (src_remaining == 0 && src_bucket < src.nbuckets)
        {
            const wp::fabricbucket_t& bucket = src.buckets[src_bucket++];
            src_ptr = static_cast<const int8_t*>(bucket.ptr);
            src_remaining = bucket.index_end - bucket.index_start;
        }

        // a bucket list ran out before dst.size elements were copied;
        // stop rather than read or write past the end of it
        if (dst_remaining == 0 || src_remaining == 0)
            break;

        size_t num_elems = (dst_remaining < src_remaining) ? dst_remaining : src_remaining;
        size_t num_bytes = num_elems * elem_size;
        memcpy(dst_ptr, src_ptr, num_bytes);

        dst_ptr += num_bytes;
        src_ptr += num_bytes;
        dst_remaining -= num_elems;
        src_remaining -= num_elems;
        total_copied += num_elems;
    }
}
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
// Copy a host array into an indexed fabric array.
//
// Destination element i is the fabric element dst.indices[i]. The matching
// source element is src_indices[i] when src_indices is non-NULL, otherwise the
// i-th element of a contiguous/strided array with byte stride src_stride.
//
// Destination elements whose pointer cannot be resolved (fabricarray_element_ptr
// returns NULL) are skipped. In the strided path the source pointer is still
// advanced for a skipped element, so every later element stays paired with
// its own destination index.
static void array_copy_to_fabric_indexed(wp::indexedfabricarray_t<void>& dst, const void* src_data,
                                         int src_stride, const int* src_indices, int elem_size)
{
    const int8_t* src_ptr = static_cast<const int8_t*>(src_data);

    if (src_indices)
    {
        // copy from indexed array
        for (size_t i = 0; i < dst.size; i++)
        {
            size_t src_idx = src_indices[i];
            size_t dst_idx = dst.indices[i];
            void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
            if (dst_ptr)
            {
                // read from the *source* index (previously dst_idx was used here by mistake)
                memcpy(dst_ptr, src_ptr + src_idx * elem_size, elem_size);
            }
        }
    }
    else
    {
        // copy from contiguous/strided array
        for (size_t i = 0; i < dst.size; i++)
        {
            size_t dst_idx = dst.indices[i];
            void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
            if (dst_ptr)
            {
                memcpy(dst_ptr, src_ptr, elem_size);
            }
            // advance unconditionally so source element i always maps to dst.indices[i]
            src_ptr += src_stride;
        }
    }
}
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
// Copy an indexed fabric array into a (non-indexed) fabric array.
//
// Source element i is the fabric element src.indices[i]; destination elements
// are filled sequentially through dst's buckets.
//
// The destination bucket cursor is advanced with a bounded loop, so runs of
// empty buckets are skipped and the walk never steps past dst.nbuckets.
// Source elements whose pointer cannot be resolved (NULL) are skipped.
// The destination slot is still consumed so later elements keep their
// positions.
static void array_copy_fabric_indexed_to_fabric(wp::fabricarray_t<void>& dst, const wp::indexedfabricarray_t<void>& src, int elem_size)
{
    size_t dst_bucket = 0;
    int8_t* dst_ptr = NULL;
    int8_t* dst_end = NULL;

    for (size_t i = 0; i < src.size; i++)
    {
        // advance to the next non-empty destination bucket once the current one is full
        while (dst_ptr >= dst_end && dst_bucket < dst.nbuckets)
        {
            const wp::fabricbucket_t& bucket = dst.buckets[dst_bucket++];
            dst_ptr = static_cast<int8_t*>(bucket.ptr);
            dst_end = dst_ptr + elem_size * (bucket.index_end - bucket.index_start);
        }

        // destination holds fewer elements than the source; stop instead of overrunning it
        if (dst_ptr >= dst_end)
            break;

        size_t src_idx = src.indices[i];
        const void* src_ptr = fabricarray_element_ptr(src.fa, src_idx, elem_size);
        if (src_ptr)
        {
            memcpy(dst_ptr, src_ptr, elem_size);
        }

        dst_ptr += elem_size;
    }
}
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
// Copy between two indexed fabric arrays, element by element:
// fabric element dst.indices[i] <- fabric element src.indices[i].
//
// Elements whose source or destination pointer cannot be resolved
// (fabricarray_element_ptr returns NULL) are skipped instead of being passed
// to memcpy, which must not receive NULL pointers.
static void array_copy_fabric_indexed_to_fabric_indexed(wp::indexedfabricarray_t<void>& dst, const wp::indexedfabricarray_t<void>& src, int elem_size)
{
    for (size_t i = 0; i < src.size; i++)
    {
        size_t src_idx = src.indices[i];
        size_t dst_idx = dst.indices[i];

        const void* src_ptr = fabricarray_element_ptr(src.fa, src_idx, elem_size);
        void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);

        if (src_ptr && dst_ptr)
        {
            memcpy(dst_ptr, src_ptr, elem_size);
        }
    }
}
|
|
480
|
+
|
|
481
|
+
|
|
482
|
+
// Copy a (non-indexed) fabric array into an indexed fabric array.
//
// Source elements are read sequentially through src's buckets; source element
// i is written to fabric element dst.indices[i].
//
// The source bucket cursor is advanced with a bounded loop, so runs of empty
// buckets are skipped and the walk never steps past src.nbuckets.
// Destination elements whose pointer cannot be resolved (NULL) are skipped.
// The source element is still consumed so later elements stay aligned.
static void array_copy_fabric_to_fabric_indexed(wp::indexedfabricarray_t<void>& dst, const wp::fabricarray_t<void>& src, int elem_size)
{
    size_t src_bucket = 0;
    const int8_t* src_ptr = NULL;
    const int8_t* src_end = NULL;

    for (size_t i = 0; i < dst.size; i++)
    {
        // advance to the next non-empty source bucket once the current one is exhausted
        while (src_ptr >= src_end && src_bucket < src.nbuckets)
        {
            const wp::fabricbucket_t& bucket = src.buckets[src_bucket++];
            src_ptr = static_cast<const int8_t*>(bucket.ptr);
            src_end = src_ptr + elem_size * (bucket.index_end - bucket.index_start);
        }

        // source holds fewer elements than the destination; stop instead of over-reading it
        if (src_ptr >= src_end)
            break;

        size_t dst_idx = dst.indices[i];
        void* dst_ptr = fabricarray_element_ptr(dst.fa, dst_idx, elem_size);
        if (dst_ptr)
        {
            memcpy(dst_ptr, src_ptr, elem_size);
        }

        src_ptr += elem_size;
    }
}
|
|
506
|
+
|
|
507
|
+
|
|
508
|
+
// Copy an indexed fabric array into a host array.
//
// Source element i is the fabric element src.indices[i]. It is written to
// destination element dst_indices[i] when dst_indices is non-NULL, otherwise
// to the i-th element of a contiguous/strided array with byte stride dst_stride.
//
// Source indices outside [0, src.fa.size) are reported on stderr and skipped.
// In the strided path the destination pointer is still advanced past a
// skipped element, so later elements keep their positions.
static void array_copy_from_fabric_indexed(const wp::indexedfabricarray_t<void>& src, void* dst_data,
                                           int dst_stride, const int* dst_indices, int elem_size)
{
    int8_t* dst_ptr = static_cast<int8_t*>(dst_data);

    if (dst_indices)
    {
        // copy to indexed array
        for (size_t i = 0; i < src.size; i++)
        {
            size_t idx = src.indices[i];
            if (idx < src.fa.size)
            {
                const void* src_ptr = fabricarray_element_ptr(src.fa, idx, elem_size);
                // widen before multiplying: int * int overflows for byte offsets >= 2 GiB
                size_t dst_offset = static_cast<size_t>(dst_indices[i]) * static_cast<size_t>(elem_size);
                memcpy(dst_ptr + dst_offset, src_ptr, elem_size);
            }
            else
            {
                fprintf(stderr, "Warp copy error: Source index %llu is out of bounds for fabric array of size %llu\n",
                        (unsigned long long)idx, (unsigned long long)src.fa.size);
            }
        }
    }
    else
    {
        // copy to contiguous/strided array
        for (size_t i = 0; i < src.size; i++)
        {
            size_t idx = src.indices[i];
            if (idx < src.fa.size)
            {
                const void* src_ptr = fabricarray_element_ptr(src.fa, idx, elem_size);
                memcpy(dst_ptr, src_ptr, elem_size);
            }
            else
            {
                fprintf(stderr, "Warp copy error: Source index %llu is out of bounds for fabric array of size %llu\n",
                        (unsigned long long)idx, (unsigned long long)src.fa.size);
            }
            // advance unconditionally so source element i always lands in destination slot i
            dst_ptr += dst_stride;
        }
    }
}
|
|
552
|
+
|
|
553
|
+
|
|
182
554
|
WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type, int elem_size)
|
|
183
555
|
{
|
|
184
556
|
if (!src || !dst)
|
|
@@ -197,6 +569,12 @@ WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type,
|
|
|
197
569
|
const int*const* src_indices = NULL;
|
|
198
570
|
const int*const* dst_indices = NULL;
|
|
199
571
|
|
|
572
|
+
const wp::fabricarray_t<void>* src_fabricarray = NULL;
|
|
573
|
+
wp::fabricarray_t<void>* dst_fabricarray = NULL;
|
|
574
|
+
|
|
575
|
+
const wp::indexedfabricarray_t<void>* src_indexedfabricarray = NULL;
|
|
576
|
+
wp::indexedfabricarray_t<void>* dst_indexedfabricarray = NULL;
|
|
577
|
+
|
|
200
578
|
const int* null_indices[wp::ARRAY_MAX_DIMS] = { NULL };
|
|
201
579
|
|
|
202
580
|
if (src_type == wp::ARRAY_TYPE_REGULAR)
|
|
@@ -218,9 +596,19 @@ WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type,
|
|
|
218
596
|
src_strides = src_arr.arr.strides;
|
|
219
597
|
src_indices = src_arr.indices;
|
|
220
598
|
}
|
|
599
|
+
else if (src_type == wp::ARRAY_TYPE_FABRIC)
|
|
600
|
+
{
|
|
601
|
+
src_fabricarray = static_cast<const wp::fabricarray_t<void>*>(src);
|
|
602
|
+
src_ndim = 1;
|
|
603
|
+
}
|
|
604
|
+
else if (src_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
|
|
605
|
+
{
|
|
606
|
+
src_indexedfabricarray = static_cast<const wp::indexedfabricarray_t<void>*>(src);
|
|
607
|
+
src_ndim = 1;
|
|
608
|
+
}
|
|
221
609
|
else
|
|
222
610
|
{
|
|
223
|
-
fprintf(stderr, "Warp error: Invalid array type (%d)\n", src_type);
|
|
611
|
+
fprintf(stderr, "Warp copy error: Invalid source array type (%d)\n", src_type);
|
|
224
612
|
return 0;
|
|
225
613
|
}
|
|
226
614
|
|
|
@@ -243,26 +631,134 @@ WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type,
|
|
|
243
631
|
dst_strides = dst_arr.arr.strides;
|
|
244
632
|
dst_indices = dst_arr.indices;
|
|
245
633
|
}
|
|
634
|
+
else if (dst_type == wp::ARRAY_TYPE_FABRIC)
|
|
635
|
+
{
|
|
636
|
+
dst_fabricarray = static_cast<wp::fabricarray_t<void>*>(dst);
|
|
637
|
+
dst_ndim = 1;
|
|
638
|
+
}
|
|
639
|
+
else if (dst_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
|
|
640
|
+
{
|
|
641
|
+
dst_indexedfabricarray = static_cast<wp::indexedfabricarray_t<void>*>(dst);
|
|
642
|
+
dst_ndim = 1;
|
|
643
|
+
}
|
|
246
644
|
else
|
|
247
645
|
{
|
|
248
|
-
fprintf(stderr, "Warp error: Invalid array type (%d)\n", dst_type);
|
|
646
|
+
fprintf(stderr, "Warp copy error: Invalid destination array type (%d)\n", dst_type);
|
|
249
647
|
return 0;
|
|
250
648
|
}
|
|
251
649
|
|
|
252
650
|
if (src_ndim != dst_ndim)
|
|
253
651
|
{
|
|
254
|
-
fprintf(stderr, "Warp error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
|
|
652
|
+
fprintf(stderr, "Warp copy error: Incompatible array dimensionalities (%d and %d)\n", src_ndim, dst_ndim);
|
|
255
653
|
return 0;
|
|
256
654
|
}
|
|
257
655
|
|
|
258
|
-
|
|
259
|
-
|
|
656
|
+
// handle fabric arrays
|
|
657
|
+
if (dst_fabricarray)
|
|
658
|
+
{
|
|
659
|
+
size_t n = dst_fabricarray->size;
|
|
660
|
+
if (src_fabricarray)
|
|
661
|
+
{
|
|
662
|
+
// copy from fabric to fabric
|
|
663
|
+
if (src_fabricarray->size != n)
|
|
664
|
+
{
|
|
665
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
666
|
+
return 0;
|
|
667
|
+
}
|
|
668
|
+
array_copy_fabric_to_fabric(*dst_fabricarray, *src_fabricarray, elem_size);
|
|
669
|
+
return n;
|
|
670
|
+
}
|
|
671
|
+
else if (src_indexedfabricarray)
|
|
672
|
+
{
|
|
673
|
+
// copy from fabric indexed to fabric
|
|
674
|
+
if (src_indexedfabricarray->size != n)
|
|
675
|
+
{
|
|
676
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
677
|
+
return 0;
|
|
678
|
+
}
|
|
679
|
+
array_copy_fabric_indexed_to_fabric(*dst_fabricarray, *src_indexedfabricarray, elem_size);
|
|
680
|
+
return n;
|
|
681
|
+
}
|
|
682
|
+
else
|
|
683
|
+
{
|
|
684
|
+
// copy to fabric
|
|
685
|
+
if (size_t(src_shape[0]) != n)
|
|
686
|
+
{
|
|
687
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
688
|
+
return 0;
|
|
689
|
+
}
|
|
690
|
+
array_copy_to_fabric(*dst_fabricarray, src_data, src_strides[0], src_indices[0], elem_size);
|
|
691
|
+
return n;
|
|
692
|
+
}
|
|
693
|
+
}
|
|
694
|
+
else if (dst_indexedfabricarray)
|
|
695
|
+
{
|
|
696
|
+
size_t n = dst_indexedfabricarray->size;
|
|
697
|
+
if (src_fabricarray)
|
|
698
|
+
{
|
|
699
|
+
// copy from fabric to fabric indexed
|
|
700
|
+
if (src_fabricarray->size != n)
|
|
701
|
+
{
|
|
702
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
703
|
+
return 0;
|
|
704
|
+
}
|
|
705
|
+
array_copy_fabric_to_fabric_indexed(*dst_indexedfabricarray, *src_fabricarray, elem_size);
|
|
706
|
+
return n;
|
|
707
|
+
}
|
|
708
|
+
else if (src_indexedfabricarray)
|
|
709
|
+
{
|
|
710
|
+
// copy from fabric indexed to fabric indexed
|
|
711
|
+
if (src_indexedfabricarray->size != n)
|
|
712
|
+
{
|
|
713
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
714
|
+
return 0;
|
|
715
|
+
}
|
|
716
|
+
array_copy_fabric_indexed_to_fabric_indexed(*dst_indexedfabricarray, *src_indexedfabricarray, elem_size);
|
|
717
|
+
return n;
|
|
718
|
+
}
|
|
719
|
+
else
|
|
720
|
+
{
|
|
721
|
+
// copy to fabric indexed
|
|
722
|
+
if (size_t(src_shape[0]) != n)
|
|
723
|
+
{
|
|
724
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
725
|
+
return 0;
|
|
726
|
+
}
|
|
727
|
+
array_copy_to_fabric_indexed(*dst_indexedfabricarray, src_data, src_strides[0], src_indices[0], elem_size);
|
|
728
|
+
return n;
|
|
729
|
+
}
|
|
730
|
+
}
|
|
731
|
+
else if (src_fabricarray)
|
|
732
|
+
{
|
|
733
|
+
// copy from fabric
|
|
734
|
+
size_t n = src_fabricarray->size;
|
|
735
|
+
if (size_t(dst_shape[0]) != n)
|
|
736
|
+
{
|
|
737
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
738
|
+
return 0;
|
|
739
|
+
}
|
|
740
|
+
array_copy_from_fabric(*src_fabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size);
|
|
741
|
+
return n;
|
|
742
|
+
}
|
|
743
|
+
else if (src_indexedfabricarray)
|
|
744
|
+
{
|
|
745
|
+
// copy from fabric indexed
|
|
746
|
+
size_t n = src_indexedfabricarray->size;
|
|
747
|
+
if (size_t(dst_shape[0]) != n)
|
|
748
|
+
{
|
|
749
|
+
fprintf(stderr, "Warp copy error: Incompatible array sizes\n");
|
|
750
|
+
return 0;
|
|
751
|
+
}
|
|
752
|
+
array_copy_from_fabric_indexed(*src_indexedfabricarray, dst_data, dst_strides[0], dst_indices[0], elem_size);
|
|
753
|
+
return n;
|
|
754
|
+
}
|
|
260
755
|
|
|
756
|
+
size_t n = 1;
|
|
261
757
|
for (int i = 0; i < src_ndim; i++)
|
|
262
758
|
{
|
|
263
759
|
if (src_shape[i] != dst_shape[i])
|
|
264
760
|
{
|
|
265
|
-
fprintf(stderr, "Warp error: Incompatible array shapes\n");
|
|
761
|
+
fprintf(stderr, "Warp copy error: Incompatible array shapes\n");
|
|
266
762
|
return 0;
|
|
267
763
|
}
|
|
268
764
|
n *= src_shape[i];
|
|
@@ -273,14 +769,6 @@ WP_API size_t array_copy_host(void* dst, void* src, int dst_type, int src_type,
|
|
|
273
769
|
dst_indices, src_indices,
|
|
274
770
|
src_shape, src_ndim, elem_size);
|
|
275
771
|
|
|
276
|
-
if (has_grad)
|
|
277
|
-
{
|
|
278
|
-
array_copy_nd(dst_grad, src_grad,
|
|
279
|
-
dst_strides, src_strides,
|
|
280
|
-
dst_indices, src_indices,
|
|
281
|
-
src_shape, src_ndim, elem_size);
|
|
282
|
-
}
|
|
283
|
-
|
|
284
772
|
return n;
|
|
285
773
|
}
|
|
286
774
|
|
|
@@ -332,6 +820,31 @@ static void array_fill_indexed(void* data, const int* shape, const int* strides,
|
|
|
332
820
|
}
|
|
333
821
|
|
|
334
822
|
|
|
823
|
+
static void array_fill_fabric(wp::fabricarray_t<void>& fa, const void* value_ptr, int value_size)
|
|
824
|
+
{
|
|
825
|
+
for (size_t i = 0; i < fa.nbuckets; i++)
|
|
826
|
+
{
|
|
827
|
+
const wp::fabricbucket_t& bucket = fa.buckets[i];
|
|
828
|
+
size_t bucket_size = bucket.index_end - bucket.index_start;
|
|
829
|
+
memtile_host(bucket.ptr, value_ptr, value_size, bucket_size);
|
|
830
|
+
}
|
|
831
|
+
}
|
|
832
|
+
|
|
833
|
+
|
|
834
|
+
static void array_fill_fabric_indexed(wp::indexedfabricarray_t<void>& ifa, const void* value_ptr, int value_size)
|
|
835
|
+
{
|
|
836
|
+
for (size_t i = 0; i < ifa.size; i++)
|
|
837
|
+
{
|
|
838
|
+
size_t idx = size_t(ifa.indices[i]);
|
|
839
|
+
if (idx < ifa.fa.size)
|
|
840
|
+
{
|
|
841
|
+
void* p = fabricarray_element_ptr(ifa.fa, idx, value_size);
|
|
842
|
+
memcpy(p, value_ptr, value_size);
|
|
843
|
+
}
|
|
844
|
+
}
|
|
845
|
+
}
|
|
846
|
+
|
|
847
|
+
|
|
335
848
|
WP_API void array_fill_host(void* arr_ptr, int arr_type, const void* value_ptr, int value_size)
|
|
336
849
|
{
|
|
337
850
|
if (!arr_ptr || !value_ptr)
|
|
@@ -347,9 +860,19 @@ WP_API void array_fill_host(void* arr_ptr, int arr_type, const void* value_ptr,
|
|
|
347
860
|
wp::indexedarray_t<void>& ia = *static_cast<wp::indexedarray_t<void>*>(arr_ptr);
|
|
348
861
|
array_fill_indexed(ia.arr.data, ia.shape.dims, ia.arr.strides, ia.indices, ia.arr.ndim, value_ptr, value_size);
|
|
349
862
|
}
|
|
863
|
+
else if (arr_type == wp::ARRAY_TYPE_FABRIC)
|
|
864
|
+
{
|
|
865
|
+
wp::fabricarray_t<void>& fa = *static_cast<wp::fabricarray_t<void>*>(arr_ptr);
|
|
866
|
+
array_fill_fabric(fa, value_ptr, value_size);
|
|
867
|
+
}
|
|
868
|
+
else if (arr_type == wp::ARRAY_TYPE_FABRIC_INDEXED)
|
|
869
|
+
{
|
|
870
|
+
wp::indexedfabricarray_t<void>& ifa = *static_cast<wp::indexedfabricarray_t<void>*>(arr_ptr);
|
|
871
|
+
array_fill_fabric_indexed(ifa, value_ptr, value_size);
|
|
872
|
+
}
|
|
350
873
|
else
|
|
351
874
|
{
|
|
352
|
-
fprintf(stderr, "Warp error: Invalid array type id %d\n", arr_type);
|
|
875
|
+
fprintf(stderr, "Warp fill error: Invalid array type id %d\n", arr_type);
|
|
353
876
|
}
|
|
354
877
|
}
|
|
355
878
|
|
|
@@ -422,6 +945,7 @@ void array_fill_device(void* context, void* arr, int arr_type, const void* value
|
|
|
422
945
|
|
|
423
946
|
WP_API int cuda_driver_version() { return 0; }
|
|
424
947
|
WP_API int cuda_toolkit_version() { return 0; }
|
|
948
|
+
WP_API bool cuda_driver_is_initialized() { return false; }
|
|
425
949
|
|
|
426
950
|
WP_API int nvrtc_supported_arch_count() { return 0; }
|
|
427
951
|
WP_API void nvrtc_supported_archs(int* archs) {}
|
|
@@ -431,7 +955,12 @@ WP_API void* cuda_device_primary_context_retain(int ordinal) { return NULL; }
|
|
|
431
955
|
WP_API void cuda_device_primary_context_release(int ordinal) {}
|
|
432
956
|
WP_API const char* cuda_device_get_name(int ordinal) { return NULL; }
|
|
433
957
|
WP_API int cuda_device_get_arch(int ordinal) { return 0; }
|
|
958
|
+
WP_API void cuda_device_get_uuid(int ordinal, char uuid[16]) {}
|
|
959
|
+
WP_API int cuda_device_get_pci_domain_id(int ordinal) { return -1; }
|
|
960
|
+
WP_API int cuda_device_get_pci_bus_id(int ordinal) { return -1; }
|
|
961
|
+
WP_API int cuda_device_get_pci_device_id(int ordinal) { return -1; }
|
|
434
962
|
WP_API int cuda_device_is_uva(int ordinal) { return 0; }
|
|
963
|
+
WP_API int cuda_device_is_memory_pool_supported() { return 0; }
|
|
435
964
|
|
|
436
965
|
WP_API void* cuda_context_get_current() { return NULL; }
|
|
437
966
|
WP_API void cuda_context_set_current(void* ctx) {}
|
|
@@ -443,6 +972,7 @@ WP_API void cuda_context_synchronize(void* context) {}
|
|
|
443
972
|
WP_API uint64_t cuda_context_check(void* context) { return 0; }
|
|
444
973
|
WP_API int cuda_context_get_device_ordinal(void* context) { return -1; }
|
|
445
974
|
WP_API int cuda_context_is_primary(void* context) { return 0; }
|
|
975
|
+
WP_API int cuda_context_is_memory_pool_supported(void* context) { return 0; }
|
|
446
976
|
WP_API void* cuda_context_get_stream(void* context) { return NULL; }
|
|
447
977
|
WP_API void cuda_context_set_stream(void* context, void* stream) {}
|
|
448
978
|
WP_API int cuda_context_can_access_peer(void* context, void* peer_context) { return 0; }
|
|
@@ -469,7 +999,7 @@ WP_API size_t cuda_compile_program(const char* cuda_src, int arch, const char* i
|
|
|
469
999
|
WP_API void* cuda_load_module(void* context, const char* ptx) { return NULL; }
|
|
470
1000
|
WP_API void cuda_unload_module(void* context, void* module) {}
|
|
471
1001
|
WP_API void* cuda_get_kernel(void* context, void* module, const char* name) { return NULL; }
|
|
472
|
-
WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, void** args) { return 0;}
|
|
1002
|
+
WP_API size_t cuda_launch_kernel(void* context, void* kernel, size_t dim, int max_blocks, void** args) { return 0;}
|
|
473
1003
|
|
|
474
1004
|
WP_API void cuda_set_context_restore_policy(bool always_restore) {}
|
|
475
1005
|
WP_API int cuda_get_context_restore_policy() { return false; }
|