warp-lang 1.9.1__py3-none-win_amd64.whl → 1.10.0rc2__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +301 -287
- warp/__init__.pyi +794 -305
- warp/_src/__init__.py +14 -0
- warp/_src/autograd.py +1075 -0
- warp/_src/build.py +618 -0
- warp/_src/build_dll.py +640 -0
- warp/{builtins.py → _src/builtins.py} +1382 -377
- warp/_src/codegen.py +4359 -0
- warp/{config.py → _src/config.py} +178 -169
- warp/_src/constants.py +57 -0
- warp/_src/context.py +8294 -0
- warp/_src/dlpack.py +462 -0
- warp/_src/fabric.py +355 -0
- warp/_src/fem/__init__.py +14 -0
- warp/_src/fem/adaptivity.py +508 -0
- warp/_src/fem/cache.py +687 -0
- warp/_src/fem/dirichlet.py +188 -0
- warp/{fem → _src/fem}/domain.py +40 -30
- warp/_src/fem/field/__init__.py +131 -0
- warp/_src/fem/field/field.py +701 -0
- warp/{fem → _src/fem}/field/nodal_field.py +30 -15
- warp/{fem → _src/fem}/field/restriction.py +1 -1
- warp/{fem → _src/fem}/field/virtual.py +53 -27
- warp/_src/fem/geometry/__init__.py +32 -0
- warp/{fem → _src/fem}/geometry/adaptive_nanogrid.py +77 -163
- warp/_src/fem/geometry/closest_point.py +97 -0
- warp/{fem → _src/fem}/geometry/deformed_geometry.py +14 -22
- warp/{fem → _src/fem}/geometry/element.py +32 -10
- warp/{fem → _src/fem}/geometry/geometry.py +48 -20
- warp/{fem → _src/fem}/geometry/grid_2d.py +12 -23
- warp/{fem → _src/fem}/geometry/grid_3d.py +12 -23
- warp/{fem → _src/fem}/geometry/hexmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/nanogrid.py +255 -248
- warp/{fem → _src/fem}/geometry/partition.py +121 -63
- warp/{fem → _src/fem}/geometry/quadmesh.py +26 -45
- warp/{fem → _src/fem}/geometry/tetmesh.py +40 -63
- warp/{fem → _src/fem}/geometry/trimesh.py +26 -45
- warp/{fem → _src/fem}/integrate.py +164 -158
- warp/_src/fem/linalg.py +383 -0
- warp/_src/fem/operator.py +396 -0
- warp/_src/fem/polynomial.py +229 -0
- warp/{fem → _src/fem}/quadrature/pic_quadrature.py +15 -20
- warp/{fem → _src/fem}/quadrature/quadrature.py +95 -47
- warp/_src/fem/space/__init__.py +248 -0
- warp/{fem → _src/fem}/space/basis_function_space.py +20 -11
- warp/_src/fem/space/basis_space.py +679 -0
- warp/{fem → _src/fem}/space/dof_mapper.py +3 -3
- warp/{fem → _src/fem}/space/function_space.py +14 -13
- warp/{fem → _src/fem}/space/grid_2d_function_space.py +4 -7
- warp/{fem → _src/fem}/space/grid_3d_function_space.py +4 -4
- warp/{fem → _src/fem}/space/hexmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/nanogrid_function_space.py +3 -9
- warp/{fem → _src/fem}/space/partition.py +117 -60
- warp/{fem → _src/fem}/space/quadmesh_function_space.py +4 -10
- warp/{fem → _src/fem}/space/restriction.py +66 -33
- warp/_src/fem/space/shape/__init__.py +152 -0
- warp/{fem → _src/fem}/space/shape/cube_shape_function.py +9 -9
- warp/{fem → _src/fem}/space/shape/shape_function.py +8 -9
- warp/{fem → _src/fem}/space/shape/square_shape_function.py +6 -6
- warp/{fem → _src/fem}/space/shape/tet_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/shape/triangle_shape_function.py +3 -3
- warp/{fem → _src/fem}/space/tetmesh_function_space.py +3 -9
- warp/_src/fem/space/topology.py +459 -0
- warp/{fem → _src/fem}/space/trimesh_function_space.py +3 -9
- warp/_src/fem/types.py +112 -0
- warp/_src/fem/utils.py +486 -0
- warp/_src/jax.py +186 -0
- warp/_src/jax_experimental/__init__.py +14 -0
- warp/_src/jax_experimental/custom_call.py +387 -0
- warp/_src/jax_experimental/ffi.py +1284 -0
- warp/_src/jax_experimental/xla_ffi.py +656 -0
- warp/_src/marching_cubes.py +708 -0
- warp/_src/math.py +414 -0
- warp/_src/optim/__init__.py +14 -0
- warp/_src/optim/adam.py +163 -0
- warp/_src/optim/linear.py +1606 -0
- warp/_src/optim/sgd.py +112 -0
- warp/_src/paddle.py +406 -0
- warp/_src/render/__init__.py +14 -0
- warp/_src/render/imgui_manager.py +289 -0
- warp/_src/render/render_opengl.py +3636 -0
- warp/_src/render/render_usd.py +937 -0
- warp/_src/render/utils.py +160 -0
- warp/_src/sparse.py +2716 -0
- warp/_src/tape.py +1206 -0
- warp/{thirdparty → _src/thirdparty}/unittest_parallel.py +9 -2
- warp/_src/torch.py +391 -0
- warp/_src/types.py +5870 -0
- warp/_src/utils.py +1693 -0
- warp/autograd.py +12 -1054
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +8 -588
- warp/build_dll.py +6 -721
- warp/codegen.py +6 -4251
- warp/constants.py +6 -39
- warp/context.py +12 -8062
- warp/dlpack.py +6 -444
- warp/examples/distributed/example_jacobi_mpi.py +4 -5
- warp/examples/fem/example_adaptive_grid.py +1 -1
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +8 -8
- warp/examples/fem/example_diffusion.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_mixed_elasticity.py +2 -2
- warp/examples/fem/example_navier_stokes.py +1 -1
- warp/examples/fem/example_nonconforming_contact.py +7 -7
- warp/examples/fem/example_stokes.py +1 -1
- warp/examples/fem/example_stokes_transfer.py +1 -1
- warp/examples/fem/utils.py +2 -2
- warp/examples/interop/example_jax_callable.py +1 -1
- warp/examples/interop/example_jax_ffi_callback.py +1 -1
- warp/examples/interop/example_jax_kernel.py +1 -1
- warp/examples/tile/example_tile_mcgp.py +191 -0
- warp/fabric.py +6 -337
- warp/fem/__init__.py +159 -97
- warp/fem/adaptivity.py +7 -489
- warp/fem/cache.py +9 -648
- warp/fem/dirichlet.py +6 -184
- warp/fem/field/__init__.py +8 -109
- warp/fem/field/field.py +7 -652
- warp/fem/geometry/__init__.py +7 -18
- warp/fem/geometry/closest_point.py +11 -77
- warp/fem/linalg.py +18 -366
- warp/fem/operator.py +11 -369
- warp/fem/polynomial.py +9 -209
- warp/fem/space/__init__.py +5 -211
- warp/fem/space/basis_space.py +6 -662
- warp/fem/space/shape/__init__.py +41 -118
- warp/fem/space/topology.py +6 -437
- warp/fem/types.py +6 -81
- warp/fem/utils.py +11 -444
- warp/jax.py +8 -165
- warp/jax_experimental/__init__.py +14 -1
- warp/jax_experimental/custom_call.py +8 -365
- warp/jax_experimental/ffi.py +17 -873
- warp/jax_experimental/xla_ffi.py +5 -605
- warp/marching_cubes.py +5 -689
- warp/math.py +16 -393
- warp/native/array.h +385 -37
- warp/native/builtin.h +314 -37
- warp/native/bvh.cpp +43 -9
- warp/native/bvh.cu +62 -27
- warp/native/bvh.h +310 -309
- warp/native/clang/clang.cpp +102 -97
- warp/native/coloring.cpp +0 -1
- warp/native/crt.h +208 -0
- warp/native/exports.h +156 -0
- warp/native/hashgrid.cu +2 -0
- warp/native/intersect.h +24 -1
- warp/native/intersect_tri.h +44 -35
- warp/native/mat.h +1456 -276
- warp/native/mesh.cpp +4 -4
- warp/native/mesh.cu +4 -2
- warp/native/mesh.h +176 -61
- warp/native/quat.h +0 -52
- warp/native/scan.cu +2 -0
- warp/native/sparse.cu +7 -3
- warp/native/spatial.h +12 -0
- warp/native/tile.h +681 -89
- warp/native/tile_radix_sort.h +1 -1
- warp/native/tile_reduce.h +394 -46
- warp/native/tile_scan.h +4 -4
- warp/native/vec.h +469 -0
- warp/native/version.h +23 -0
- warp/native/volume.cpp +1 -1
- warp/native/volume.cu +1 -0
- warp/native/volume.h +1 -1
- warp/native/volume_builder.cu +2 -0
- warp/native/warp.cpp +57 -29
- warp/native/warp.cu +253 -171
- warp/native/warp.h +11 -8
- warp/optim/__init__.py +6 -3
- warp/optim/adam.py +6 -145
- warp/optim/linear.py +14 -1585
- warp/optim/sgd.py +6 -94
- warp/paddle.py +6 -388
- warp/render/__init__.py +8 -4
- warp/render/imgui_manager.py +7 -267
- warp/render/render_opengl.py +6 -3618
- warp/render/render_usd.py +6 -919
- warp/render/utils.py +6 -142
- warp/sparse.py +37 -2563
- warp/tape.py +6 -1188
- warp/tests/__main__.py +1 -1
- warp/tests/cuda/test_async.py +4 -4
- warp/tests/cuda/test_conditional_captures.py +1 -1
- warp/tests/cuda/test_multigpu.py +1 -1
- warp/tests/cuda/test_streams.py +58 -1
- warp/tests/geometry/test_bvh.py +157 -22
- warp/tests/geometry/test_marching_cubes.py +0 -1
- warp/tests/geometry/test_mesh.py +5 -3
- warp/tests/geometry/test_mesh_query_aabb.py +5 -12
- warp/tests/geometry/test_mesh_query_point.py +5 -2
- warp/tests/geometry/test_mesh_query_ray.py +15 -3
- warp/tests/geometry/test_volume_write.py +5 -5
- warp/tests/interop/test_dlpack.py +14 -14
- warp/tests/interop/test_jax.py +772 -49
- warp/tests/interop/test_paddle.py +1 -1
- warp/tests/test_adam.py +0 -1
- warp/tests/test_arithmetic.py +9 -9
- warp/tests/test_array.py +527 -100
- warp/tests/test_array_reduce.py +3 -3
- warp/tests/test_atomic.py +12 -8
- warp/tests/test_atomic_bitwise.py +209 -0
- warp/tests/test_atomic_cas.py +4 -4
- warp/tests/test_bool.py +2 -2
- warp/tests/test_builtins_resolution.py +5 -571
- warp/tests/test_codegen.py +33 -14
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_context.py +6 -6
- warp/tests/test_copy.py +242 -161
- warp/tests/test_ctypes.py +3 -3
- warp/tests/test_devices.py +24 -2
- warp/tests/test_examples.py +16 -84
- warp/tests/test_fabricarray.py +35 -35
- warp/tests/test_fast_math.py +0 -2
- warp/tests/test_fem.py +56 -10
- warp/tests/test_fixedarray.py +3 -3
- warp/tests/test_func.py +8 -5
- warp/tests/test_generics.py +1 -1
- warp/tests/test_indexedarray.py +24 -24
- warp/tests/test_intersect.py +39 -9
- warp/tests/test_large.py +1 -1
- warp/tests/test_lerp.py +3 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_map.py +35 -4
- warp/tests/test_mat.py +52 -62
- warp/tests/test_mat_constructors.py +4 -5
- warp/tests/test_mat_lite.py +1 -1
- warp/tests/test_mat_scalar_ops.py +121 -121
- warp/tests/test_math.py +34 -0
- warp/tests/test_module_aot.py +4 -4
- warp/tests/test_modules_lite.py +28 -2
- warp/tests/test_print.py +11 -11
- warp/tests/test_quat.py +93 -58
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +38 -10
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +126 -15
- warp/tests/test_spatial.py +105 -87
- warp/tests/test_special_values.py +6 -6
- warp/tests/test_static.py +7 -7
- warp/tests/test_struct.py +13 -2
- warp/tests/test_triangle_closest_point.py +48 -1
- warp/tests/test_types.py +27 -15
- warp/tests/test_utils.py +52 -52
- warp/tests/test_vec.py +29 -29
- warp/tests/test_vec_constructors.py +5 -5
- warp/tests/test_vec_scalar_ops.py +97 -97
- warp/tests/test_version.py +75 -0
- warp/tests/tile/test_tile.py +178 -0
- warp/tests/tile/test_tile_atomic_bitwise.py +403 -0
- warp/tests/tile/test_tile_cholesky.py +7 -4
- warp/tests/tile/test_tile_load.py +26 -2
- warp/tests/tile/test_tile_mathdx.py +3 -3
- warp/tests/tile/test_tile_matmul.py +1 -1
- warp/tests/tile/test_tile_mlp.py +2 -4
- warp/tests/tile/test_tile_reduce.py +214 -13
- warp/tests/unittest_suites.py +6 -14
- warp/tests/unittest_utils.py +10 -9
- warp/tests/walkthrough_debug.py +3 -1
- warp/torch.py +6 -373
- warp/types.py +29 -5764
- warp/utils.py +10 -1659
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/METADATA +46 -99
- warp_lang-1.10.0rc2.dist-info/RECORD +468 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/Gaia-LICENSE.txt +6 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/appdirs-LICENSE.txt +22 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/asset_pixel_jpg-LICENSE.txt +3 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/cuda-LICENSE.txt +1582 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/dlpack-LICENSE.txt +201 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/fp16-LICENSE.txt +28 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/libmathdx-LICENSE.txt +220 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/llvm-LICENSE.txt +279 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/moller-LICENSE.txt +16 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nanovdb-LICENSE.txt +2 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/nvrtc-LICENSE.txt +1592 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/svd-LICENSE.txt +23 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/unittest_parallel-LICENSE.txt +21 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/usd-LICENSE.txt +213 -0
- warp_lang-1.10.0rc2.dist-info/licenses/licenses/windingnumber-LICENSE.txt +21 -0
- warp/examples/assets/cartpole.urdf +0 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/nv_ant.xml +0 -92
- warp/examples/assets/nv_humanoid.xml +0 -183
- warp/examples/assets/quadruped.urdf +0 -268
- warp/examples/optim/example_bounce.py +0 -266
- warp/examples/optim/example_cloth_throw.py +0 -228
- warp/examples/optim/example_drone.py +0 -870
- warp/examples/optim/example_inverse_kinematics.py +0 -182
- warp/examples/optim/example_inverse_kinematics_torch.py +0 -191
- warp/examples/optim/example_softbody_properties.py +0 -400
- warp/examples/optim/example_spring_cage.py +0 -245
- warp/examples/optim/example_trajectory.py +0 -227
- warp/examples/sim/example_cartpole.py +0 -143
- warp/examples/sim/example_cloth.py +0 -225
- warp/examples/sim/example_cloth_self_contact.py +0 -316
- warp/examples/sim/example_granular.py +0 -130
- warp/examples/sim/example_granular_collision_sdf.py +0 -202
- warp/examples/sim/example_jacobian_ik.py +0 -244
- warp/examples/sim/example_particle_chain.py +0 -124
- warp/examples/sim/example_quadruped.py +0 -203
- warp/examples/sim/example_rigid_chain.py +0 -203
- warp/examples/sim/example_rigid_contact.py +0 -195
- warp/examples/sim/example_rigid_force.py +0 -133
- warp/examples/sim/example_rigid_gyroscopic.py +0 -115
- warp/examples/sim/example_rigid_soft_contact.py +0 -140
- warp/examples/sim/example_soft_body.py +0 -196
- warp/examples/tile/example_tile_walker.py +0 -327
- warp/sim/__init__.py +0 -74
- warp/sim/articulation.py +0 -793
- warp/sim/collide.py +0 -2570
- warp/sim/graph_coloring.py +0 -307
- warp/sim/import_mjcf.py +0 -791
- warp/sim/import_snu.py +0 -227
- warp/sim/import_urdf.py +0 -579
- warp/sim/import_usd.py +0 -898
- warp/sim/inertia.py +0 -357
- warp/sim/integrator.py +0 -245
- warp/sim/integrator_euler.py +0 -2000
- warp/sim/integrator_featherstone.py +0 -2101
- warp/sim/integrator_vbd.py +0 -2487
- warp/sim/integrator_xpbd.py +0 -3295
- warp/sim/model.py +0 -4821
- warp/sim/particles.py +0 -121
- warp/sim/render.py +0 -431
- warp/sim/utils.py +0 -431
- warp/tests/sim/disabled_kinematics.py +0 -244
- warp/tests/sim/test_cloth.py +0 -863
- warp/tests/sim/test_collision.py +0 -743
- warp/tests/sim/test_coloring.py +0 -347
- warp/tests/sim/test_inertia.py +0 -161
- warp/tests/sim/test_model.py +0 -226
- warp/tests/sim/test_sim_grad.py +0 -287
- warp/tests/sim/test_sim_grad_bounce_linear.py +0 -212
- warp/tests/sim/test_sim_kinematics.py +0 -98
- warp/thirdparty/__init__.py +0 -0
- warp_lang-1.9.1.dist-info/RECORD +0 -456
- /warp/{fem → _src/fem}/quadrature/__init__.py +0 -0
- /warp/{tests/sim → _src/thirdparty}/__init__.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/appdirs.py +0 -0
- /warp/{thirdparty → _src/thirdparty}/dlpack.py +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.1.dist-info → warp_lang-1.10.0rc2.dist-info}/top_level.txt +0 -0
warp/native/array.h
CHANGED
|
@@ -118,6 +118,23 @@ namespace wp
|
|
|
118
118
|
|
|
119
119
|
#endif // WP_FP_CHECK
|
|
120
120
|
|
|
121
|
+
|
|
122
|
+
template<size_t... Is>
|
|
123
|
+
struct index_sequence {};
|
|
124
|
+
|
|
125
|
+
template<size_t N, size_t... Is>
|
|
126
|
+
struct make_index_sequence_impl : make_index_sequence_impl<N-1, N-1, Is...> {};
|
|
127
|
+
|
|
128
|
+
template<size_t... Is>
|
|
129
|
+
struct make_index_sequence_impl<0, Is...>
|
|
130
|
+
{
|
|
131
|
+
using type = index_sequence<Is...>;
|
|
132
|
+
};
|
|
133
|
+
|
|
134
|
+
template<size_t N>
|
|
135
|
+
using make_index_sequence = typename make_index_sequence_impl<N>::type;
|
|
136
|
+
|
|
137
|
+
|
|
121
138
|
const int ARRAY_MAX_DIMS = 4; // must match constant in types.py
|
|
122
139
|
|
|
123
140
|
// must match constants in types.py
|
|
@@ -423,6 +440,13 @@ template <typename T>
|
|
|
423
440
|
CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i)
|
|
424
441
|
{
|
|
425
442
|
assert(arr.ndim == 1);
|
|
443
|
+
assert(i >= -arr.shape[0] && i < arr.shape[0]);
|
|
444
|
+
|
|
445
|
+
if (i < 0)
|
|
446
|
+
{
|
|
447
|
+
i += arr.shape[0];
|
|
448
|
+
}
|
|
449
|
+
|
|
426
450
|
T& result = *data_at_byte_offset(arr, byte_offset(arr, i));
|
|
427
451
|
FP_VERIFY_FWD_1(result)
|
|
428
452
|
|
|
@@ -433,6 +457,18 @@ template <typename T>
|
|
|
433
457
|
CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j)
|
|
434
458
|
{
|
|
435
459
|
assert(arr.ndim == 2);
|
|
460
|
+
assert(i >= -arr.shape[0] && i < arr.shape[0]);
|
|
461
|
+
assert(j >= -arr.shape[1] && j < arr.shape[1]);
|
|
462
|
+
|
|
463
|
+
if (i < 0)
|
|
464
|
+
{
|
|
465
|
+
i += arr.shape[0];
|
|
466
|
+
}
|
|
467
|
+
if (j < 0)
|
|
468
|
+
{
|
|
469
|
+
j += arr.shape[1];
|
|
470
|
+
}
|
|
471
|
+
|
|
436
472
|
T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j));
|
|
437
473
|
FP_VERIFY_FWD_2(result)
|
|
438
474
|
|
|
@@ -443,6 +479,23 @@ template <typename T>
|
|
|
443
479
|
CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k)
|
|
444
480
|
{
|
|
445
481
|
assert(arr.ndim == 3);
|
|
482
|
+
assert(i >= -arr.shape[0] && i < arr.shape[0]);
|
|
483
|
+
assert(j >= -arr.shape[1] && j < arr.shape[1]);
|
|
484
|
+
assert(k >= -arr.shape[2] && k < arr.shape[2]);
|
|
485
|
+
|
|
486
|
+
if (i < 0)
|
|
487
|
+
{
|
|
488
|
+
i += arr.shape[0];
|
|
489
|
+
}
|
|
490
|
+
if (j < 0)
|
|
491
|
+
{
|
|
492
|
+
j += arr.shape[1];
|
|
493
|
+
}
|
|
494
|
+
if (k < 0)
|
|
495
|
+
{
|
|
496
|
+
k += arr.shape[2];
|
|
497
|
+
}
|
|
498
|
+
|
|
446
499
|
T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k));
|
|
447
500
|
FP_VERIFY_FWD_3(result)
|
|
448
501
|
|
|
@@ -453,6 +506,28 @@ template <typename T>
|
|
|
453
506
|
CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k, int l)
|
|
454
507
|
{
|
|
455
508
|
assert(arr.ndim == 4);
|
|
509
|
+
assert(i >= -arr.shape[0] && i < arr.shape[0]);
|
|
510
|
+
assert(j >= -arr.shape[1] && j < arr.shape[1]);
|
|
511
|
+
assert(k >= -arr.shape[2] && k < arr.shape[2]);
|
|
512
|
+
assert(l >= -arr.shape[3] && l < arr.shape[3]);
|
|
513
|
+
|
|
514
|
+
if (i < 0)
|
|
515
|
+
{
|
|
516
|
+
i += arr.shape[0];
|
|
517
|
+
}
|
|
518
|
+
if (j < 0)
|
|
519
|
+
{
|
|
520
|
+
j += arr.shape[1];
|
|
521
|
+
}
|
|
522
|
+
if (k < 0)
|
|
523
|
+
{
|
|
524
|
+
k += arr.shape[2];
|
|
525
|
+
}
|
|
526
|
+
if (l < 0)
|
|
527
|
+
{
|
|
528
|
+
l += arr.shape[3];
|
|
529
|
+
}
|
|
530
|
+
|
|
456
531
|
T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
|
|
457
532
|
FP_VERIFY_FWD_4(result)
|
|
458
533
|
|
|
@@ -462,6 +537,14 @@ CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k, int l)
|
|
|
462
537
|
template <typename T>
|
|
463
538
|
CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i)
|
|
464
539
|
{
|
|
540
|
+
assert(arr.ndim == 1);
|
|
541
|
+
assert(i >= -arr.shape[0] && i < arr.shape[0]);
|
|
542
|
+
|
|
543
|
+
if (i < 0)
|
|
544
|
+
{
|
|
545
|
+
i += arr.shape[0];
|
|
546
|
+
}
|
|
547
|
+
|
|
465
548
|
T& result = *grad_at_byte_offset(arr, byte_offset(arr, i));
|
|
466
549
|
FP_VERIFY_FWD_1(result)
|
|
467
550
|
|
|
@@ -471,6 +554,19 @@ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i)
|
|
|
471
554
|
template <typename T>
|
|
472
555
|
CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j)
|
|
473
556
|
{
|
|
557
|
+
assert(arr.ndim == 2);
|
|
558
|
+
assert(i >= -arr.shape[0] && i < arr.shape[0]);
|
|
559
|
+
assert(j >= -arr.shape[1] && j < arr.shape[1]);
|
|
560
|
+
|
|
561
|
+
if (i < 0)
|
|
562
|
+
{
|
|
563
|
+
i += arr.shape[0];
|
|
564
|
+
}
|
|
565
|
+
if (j < 0)
|
|
566
|
+
{
|
|
567
|
+
j += arr.shape[1];
|
|
568
|
+
}
|
|
569
|
+
|
|
474
570
|
T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j));
|
|
475
571
|
FP_VERIFY_FWD_2(result)
|
|
476
572
|
|
|
@@ -480,6 +576,24 @@ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j)
|
|
|
480
576
|
template <typename T>
|
|
481
577
|
CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k)
|
|
482
578
|
{
|
|
579
|
+
assert(arr.ndim == 3);
|
|
580
|
+
assert(i >= -arr.shape[0] && i < arr.shape[0]);
|
|
581
|
+
assert(j >= -arr.shape[1] && j < arr.shape[1]);
|
|
582
|
+
assert(k >= -arr.shape[2] && k < arr.shape[2]);
|
|
583
|
+
|
|
584
|
+
if (i < 0)
|
|
585
|
+
{
|
|
586
|
+
i += arr.shape[0];
|
|
587
|
+
}
|
|
588
|
+
if (j < 0)
|
|
589
|
+
{
|
|
590
|
+
j += arr.shape[1];
|
|
591
|
+
}
|
|
592
|
+
if (k < 0)
|
|
593
|
+
{
|
|
594
|
+
k += arr.shape[2];
|
|
595
|
+
}
|
|
596
|
+
|
|
483
597
|
T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k));
|
|
484
598
|
FP_VERIFY_FWD_3(result)
|
|
485
599
|
|
|
@@ -489,6 +603,29 @@ CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k)
|
|
|
489
603
|
template <typename T>
|
|
490
604
|
CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k, int l)
|
|
491
605
|
{
|
|
606
|
+
assert(arr.ndim == 4);
|
|
607
|
+
assert(i >= -arr.shape[0] && i < arr.shape[0]);
|
|
608
|
+
assert(j >= -arr.shape[1] && j < arr.shape[1]);
|
|
609
|
+
assert(k >= -arr.shape[2] && k < arr.shape[2]);
|
|
610
|
+
assert(l >= -arr.shape[3] && l < arr.shape[3]);
|
|
611
|
+
|
|
612
|
+
if (i < 0)
|
|
613
|
+
{
|
|
614
|
+
i += arr.shape[0];
|
|
615
|
+
}
|
|
616
|
+
if (j < 0)
|
|
617
|
+
{
|
|
618
|
+
j += arr.shape[1];
|
|
619
|
+
}
|
|
620
|
+
if (k < 0)
|
|
621
|
+
{
|
|
622
|
+
k += arr.shape[2];
|
|
623
|
+
}
|
|
624
|
+
if (l < 0)
|
|
625
|
+
{
|
|
626
|
+
l += arr.shape[3];
|
|
627
|
+
}
|
|
628
|
+
|
|
492
629
|
T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
|
|
493
630
|
FP_VERIFY_FWD_4(result)
|
|
494
631
|
|
|
@@ -500,7 +637,12 @@ template <typename T>
|
|
|
500
637
|
CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i)
|
|
501
638
|
{
|
|
502
639
|
assert(iarr.arr.ndim == 1);
|
|
503
|
-
assert(i >= 0 && i < iarr.shape[0]);
|
|
640
|
+
assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
|
|
641
|
+
|
|
642
|
+
if (i < 0)
|
|
643
|
+
{
|
|
644
|
+
i += iarr.shape[0];
|
|
645
|
+
}
|
|
504
646
|
|
|
505
647
|
if (iarr.indices[0])
|
|
506
648
|
{
|
|
@@ -518,8 +660,17 @@ template <typename T>
|
|
|
518
660
|
CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j)
|
|
519
661
|
{
|
|
520
662
|
assert(iarr.arr.ndim == 2);
|
|
521
|
-
assert(i >= 0 && i < iarr.shape[0]);
|
|
522
|
-
assert(j >= 0 && j < iarr.shape[1]);
|
|
663
|
+
assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
|
|
664
|
+
assert(j >= -iarr.shape[1] && j < iarr.shape[1]);
|
|
665
|
+
|
|
666
|
+
if (i < 0)
|
|
667
|
+
{
|
|
668
|
+
i += iarr.shape[0];
|
|
669
|
+
}
|
|
670
|
+
if (j < 0)
|
|
671
|
+
{
|
|
672
|
+
j += iarr.shape[1];
|
|
673
|
+
}
|
|
523
674
|
|
|
524
675
|
if (iarr.indices[0])
|
|
525
676
|
{
|
|
@@ -542,9 +693,22 @@ template <typename T>
|
|
|
542
693
|
CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k)
|
|
543
694
|
{
|
|
544
695
|
assert(iarr.arr.ndim == 3);
|
|
545
|
-
assert(i >= 0 && i < iarr.shape[0]);
|
|
546
|
-
assert(j >= 0 && j < iarr.shape[1]);
|
|
547
|
-
assert(k >= 0 && k < iarr.shape[2]);
|
|
696
|
+
assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
|
|
697
|
+
assert(j >= -iarr.shape[1] && j < iarr.shape[1]);
|
|
698
|
+
assert(k >= -iarr.shape[2] && k < iarr.shape[2]);
|
|
699
|
+
|
|
700
|
+
if (i < 0)
|
|
701
|
+
{
|
|
702
|
+
i += iarr.shape[0];
|
|
703
|
+
}
|
|
704
|
+
if (j < 0)
|
|
705
|
+
{
|
|
706
|
+
j += iarr.shape[1];
|
|
707
|
+
}
|
|
708
|
+
if (k < 0)
|
|
709
|
+
{
|
|
710
|
+
k += iarr.shape[2];
|
|
711
|
+
}
|
|
548
712
|
|
|
549
713
|
if (iarr.indices[0])
|
|
550
714
|
{
|
|
@@ -572,10 +736,27 @@ template <typename T>
|
|
|
572
736
|
CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k, int l)
|
|
573
737
|
{
|
|
574
738
|
assert(iarr.arr.ndim == 4);
|
|
575
|
-
assert(i >= 0 && i < iarr.shape[0]);
|
|
576
|
-
assert(j >= 0 && j < iarr.shape[1]);
|
|
577
|
-
assert(k >= 0 && k < iarr.shape[2]);
|
|
578
|
-
assert(l >= 0 && l < iarr.shape[3]);
|
|
739
|
+
assert(i >= -iarr.shape[0] && i < iarr.shape[0]);
|
|
740
|
+
assert(j >= -iarr.shape[1] && j < iarr.shape[1]);
|
|
741
|
+
assert(k >= -iarr.shape[2] && k < iarr.shape[2]);
|
|
742
|
+
assert(l >= -iarr.shape[3] && l < iarr.shape[3]);
|
|
743
|
+
|
|
744
|
+
if (i < 0)
|
|
745
|
+
{
|
|
746
|
+
i += iarr.shape[0];
|
|
747
|
+
}
|
|
748
|
+
if (j < 0)
|
|
749
|
+
{
|
|
750
|
+
j += iarr.shape[1];
|
|
751
|
+
}
|
|
752
|
+
if (k < 0)
|
|
753
|
+
{
|
|
754
|
+
k += iarr.shape[2];
|
|
755
|
+
}
|
|
756
|
+
if (l < 0)
|
|
757
|
+
{
|
|
758
|
+
l += iarr.shape[3];
|
|
759
|
+
}
|
|
579
760
|
|
|
580
761
|
if (iarr.indices[0])
|
|
581
762
|
{
|
|
@@ -609,7 +790,12 @@ template <typename T>
|
|
|
609
790
|
CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i)
|
|
610
791
|
{
|
|
611
792
|
assert(src.ndim > 1);
|
|
612
|
-
assert(i >= 0 && i < src.shape[0]);
|
|
793
|
+
assert(i >= -src.shape[0] && i < src.shape[0]);
|
|
794
|
+
|
|
795
|
+
if (i < 0)
|
|
796
|
+
{
|
|
797
|
+
i += src.shape[0];
|
|
798
|
+
}
|
|
613
799
|
|
|
614
800
|
array_t<T> a;
|
|
615
801
|
size_t offset = byte_offset(src, i);
|
|
@@ -631,8 +817,17 @@ template <typename T>
|
|
|
631
817
|
CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j)
|
|
632
818
|
{
|
|
633
819
|
assert(src.ndim > 2);
|
|
634
|
-
assert(i >= 0 && i < src.shape[0]);
|
|
635
|
-
assert(j >= 0 && j < src.shape[1]);
|
|
820
|
+
assert(i >= -src.shape[0] && i < src.shape[0]);
|
|
821
|
+
assert(j >= -src.shape[1] && j < src.shape[1]);
|
|
822
|
+
|
|
823
|
+
if (i < 0)
|
|
824
|
+
{
|
|
825
|
+
i += src.shape[0];
|
|
826
|
+
}
|
|
827
|
+
if (j < 0)
|
|
828
|
+
{
|
|
829
|
+
j += src.shape[1];
|
|
830
|
+
}
|
|
636
831
|
|
|
637
832
|
array_t<T> a;
|
|
638
833
|
size_t offset = byte_offset(src, i, j);
|
|
@@ -652,9 +847,22 @@ template <typename T>
|
|
|
652
847
|
CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j, int k)
|
|
653
848
|
{
|
|
654
849
|
assert(src.ndim > 3);
|
|
655
|
-
assert(i >= 0 && i < src.shape[0]);
|
|
656
|
-
assert(j >= 0 && j < src.shape[1]);
|
|
657
|
-
assert(k >= 0 && k < src.shape[2]);
|
|
850
|
+
assert(i >= -src.shape[0] && i < src.shape[0]);
|
|
851
|
+
assert(j >= -src.shape[1] && j < src.shape[1]);
|
|
852
|
+
assert(k >= -src.shape[2] && k < src.shape[2]);
|
|
853
|
+
|
|
854
|
+
if (i < 0)
|
|
855
|
+
{
|
|
856
|
+
i += src.shape[0];
|
|
857
|
+
}
|
|
858
|
+
if (j < 0)
|
|
859
|
+
{
|
|
860
|
+
j += src.shape[1];
|
|
861
|
+
}
|
|
862
|
+
if (k < 0)
|
|
863
|
+
{
|
|
864
|
+
k += src.shape[2];
|
|
865
|
+
}
|
|
658
866
|
|
|
659
867
|
array_t<T> a;
|
|
660
868
|
size_t offset = byte_offset(src, i, j, k);
|
|
@@ -669,6 +877,78 @@ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j, int k)
|
|
|
669
877
|
}
|
|
670
878
|
|
|
671
879
|
|
|
880
|
+
template <typename T, size_t... Idxs>
|
|
881
|
+
size_t byte_offset_helper(
|
|
882
|
+
array_t<T>& src,
|
|
883
|
+
const slice_t (&slices)[sizeof...(Idxs)],
|
|
884
|
+
index_sequence<Idxs...>
|
|
885
|
+
)
|
|
886
|
+
{
|
|
887
|
+
return byte_offset(src, slices[Idxs].start...);
|
|
888
|
+
}
|
|
889
|
+
|
|
890
|
+
|
|
891
|
+
template <typename T, typename... Slices>
|
|
892
|
+
CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, const Slices&... slice_args)
|
|
893
|
+
{
|
|
894
|
+
constexpr int N = sizeof...(Slices);
|
|
895
|
+
static_assert(N >= 1 && N <= 4, "view supports 1 to 4 slices");
|
|
896
|
+
assert(src.ndim >= N);
|
|
897
|
+
|
|
898
|
+
slice_t slices[N] = { slice_args... };
|
|
899
|
+
int slice_idxs[N];
|
|
900
|
+
int slice_count = 0;
|
|
901
|
+
|
|
902
|
+
for (int i = 0; i < N; ++i)
|
|
903
|
+
{
|
|
904
|
+
if (slices[i].step == 0)
|
|
905
|
+
{
|
|
906
|
+
// We have a slice representing an integer index.
|
|
907
|
+
if (slices[i].start < 0)
|
|
908
|
+
{
|
|
909
|
+
slices[i].start += src.shape[i];
|
|
910
|
+
}
|
|
911
|
+
}
|
|
912
|
+
else
|
|
913
|
+
{
|
|
914
|
+
slices[i] = slice_adjust_indices(slices[i], src.shape[i]);
|
|
915
|
+
slice_idxs[slice_count] = i;
|
|
916
|
+
++slice_count;
|
|
917
|
+
}
|
|
918
|
+
}
|
|
919
|
+
|
|
920
|
+
size_t offset = byte_offset_helper(src, slices, make_index_sequence<N>{});
|
|
921
|
+
|
|
922
|
+
array_t<T> out;
|
|
923
|
+
|
|
924
|
+
out.data = data_at_byte_offset(src, offset);
|
|
925
|
+
if (src.grad)
|
|
926
|
+
{
|
|
927
|
+
out.grad = grad_at_byte_offset(src, offset);
|
|
928
|
+
}
|
|
929
|
+
|
|
930
|
+
int dim = 0;
|
|
931
|
+
for (; dim < slice_count; ++dim)
|
|
932
|
+
{
|
|
933
|
+
int idx = slice_idxs[dim];
|
|
934
|
+
out.shape[dim] = slice_get_length(slices[idx]);
|
|
935
|
+
out.strides[dim] = src.strides[idx] * slices[idx].step;
|
|
936
|
+
}
|
|
937
|
+
for (; dim < slice_count + 4 - N; ++dim)
|
|
938
|
+
{
|
|
939
|
+
out.shape[dim] = src.shape[dim - slice_count + N];
|
|
940
|
+
out.strides[dim] = src.strides[dim - slice_count + N];
|
|
941
|
+
}
|
|
942
|
+
for (; dim < 4; ++dim)
|
|
943
|
+
{
|
|
944
|
+
out.shape[dim] = 0;
|
|
945
|
+
out.strides[dim] = 0;
|
|
946
|
+
}
|
|
947
|
+
|
|
948
|
+
out.ndim = src.ndim + slice_count - N;
|
|
949
|
+
return out;
|
|
950
|
+
}
|
|
951
|
+
|
|
672
952
|
template <typename T>
|
|
673
953
|
CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i)
|
|
674
954
|
{
|
|
@@ -676,7 +956,11 @@ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i)
|
|
|
676
956
|
|
|
677
957
|
if (src.indices[0])
|
|
678
958
|
{
|
|
679
|
-
assert(i >= 0 && i < src.shape[0]);
|
|
959
|
+
assert(i >= -src.shape[0] && i < src.shape[0]);
|
|
960
|
+
if (i < 0)
|
|
961
|
+
{
|
|
962
|
+
i += src.shape[0];
|
|
963
|
+
}
|
|
680
964
|
i = src.indices[0][i];
|
|
681
965
|
}
|
|
682
966
|
|
|
@@ -699,12 +983,20 @@ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j
|
|
|
699
983
|
|
|
700
984
|
if (src.indices[0])
|
|
701
985
|
{
|
|
702
|
-
assert(i >= 0 && i < src.shape[0]);
|
|
986
|
+
assert(i >= -src.shape[0] && i < src.shape[0]);
|
|
987
|
+
if (i < 0)
|
|
988
|
+
{
|
|
989
|
+
i += src.shape[0];
|
|
990
|
+
}
|
|
703
991
|
i = src.indices[0][i];
|
|
704
992
|
}
|
|
705
993
|
if (src.indices[1])
|
|
706
994
|
{
|
|
707
|
-
assert(j >= 0 && j < src.shape[1]);
|
|
995
|
+
assert(j >= -src.shape[1] && j < src.shape[1]);
|
|
996
|
+
if (j < 0)
|
|
997
|
+
{
|
|
998
|
+
j += src.shape[1];
|
|
999
|
+
}
|
|
708
1000
|
j = src.indices[1][j];
|
|
709
1001
|
}
|
|
710
1002
|
|
|
@@ -725,17 +1017,29 @@ CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j
|
|
|
725
1017
|
|
|
726
1018
|
if (src.indices[0])
|
|
727
1019
|
{
|
|
728
|
-
assert(i >= 0 && i < src.shape[0]);
|
|
1020
|
+
assert(i >= -src.shape[0] && i < src.shape[0]);
|
|
1021
|
+
if (i < 0)
|
|
1022
|
+
{
|
|
1023
|
+
i += src.shape[0];
|
|
1024
|
+
}
|
|
729
1025
|
i = src.indices[0][i];
|
|
730
1026
|
}
|
|
731
1027
|
if (src.indices[1])
|
|
732
1028
|
{
|
|
733
|
-
assert(j >= 0 && j < src.shape[1]);
|
|
1029
|
+
assert(j >= -src.shape[1] && j < src.shape[1]);
|
|
1030
|
+
if (j < 0)
|
|
1031
|
+
{
|
|
1032
|
+
j += src.shape[1];
|
|
1033
|
+
}
|
|
734
1034
|
j = src.indices[1][j];
|
|
735
1035
|
}
|
|
736
1036
|
if (src.indices[2])
|
|
737
1037
|
{
|
|
738
|
-
assert(k >= 0 && k < src.shape[2]);
|
|
1038
|
+
assert(k >= -src.shape[2] && k < src.shape[2]);
|
|
1039
|
+
if (k < 0)
|
|
1040
|
+
{
|
|
1041
|
+
k += src.shape[2];
|
|
1042
|
+
}
|
|
739
1043
|
k = src.indices[2][k];
|
|
740
1044
|
}
|
|
741
1045
|
|
|
@@ -754,6 +1058,9 @@ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int
|
|
|
754
1058
|
template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
|
|
755
1059
|
inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T>& adj_ret) {}
|
|
756
1060
|
|
|
1061
|
+
template <typename... Args>
|
|
1062
|
+
CUDA_CALLABLE inline void adj_view(Args&&...) { }
|
|
1063
|
+
|
|
757
1064
|
// TODO: lower_bound() for indexed arrays?
|
|
758
1065
|
|
|
759
1066
|
template <typename T>
|
|
@@ -844,6 +1151,33 @@ inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, T value
|
|
|
844
1151
|
template<template<typename> class A, typename T>
|
|
845
1152
|
inline CUDA_CALLABLE T atomic_exch(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_exch(&index(buf, i, j, k, l), value); }
|
|
846
1153
|
|
|
1154
|
+
template<template<typename> class A, typename T>
|
|
1155
|
+
inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, T value) { return atomic_and(&index(buf, i), value); }
|
|
1156
|
+
template<template<typename> class A, typename T>
|
|
1157
|
+
inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, int j, T value) { return atomic_and(&index(buf, i, j), value); }
|
|
1158
|
+
template<template<typename> class A, typename T>
|
|
1159
|
+
inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, int j, int k, T value) { return atomic_and(&index(buf, i, j, k), value); }
|
|
1160
|
+
template<template<typename> class A, typename T>
|
|
1161
|
+
inline CUDA_CALLABLE T atomic_and(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_and(&index(buf, i, j, k, l), value); }
|
|
1162
|
+
|
|
1163
|
+
template<template<typename> class A, typename T>
|
|
1164
|
+
inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, T value) { return atomic_or(&index(buf, i), value); }
|
|
1165
|
+
template<template<typename> class A, typename T>
|
|
1166
|
+
inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, int j, T value) { return atomic_or(&index(buf, i, j), value); }
|
|
1167
|
+
template<template<typename> class A, typename T>
|
|
1168
|
+
inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, int j, int k, T value) { return atomic_or(&index(buf, i, j, k), value); }
|
|
1169
|
+
template<template<typename> class A, typename T>
|
|
1170
|
+
inline CUDA_CALLABLE T atomic_or(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_or(&index(buf, i, j, k, l), value); }
|
|
1171
|
+
|
|
1172
|
+
template<template<typename> class A, typename T>
|
|
1173
|
+
inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, T value) { return atomic_xor(&index(buf, i), value); }
|
|
1174
|
+
template<template<typename> class A, typename T>
|
|
1175
|
+
inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, int j, T value) { return atomic_xor(&index(buf, i, j), value); }
|
|
1176
|
+
template<template<typename> class A, typename T>
|
|
1177
|
+
inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, int j, int k, T value) { return atomic_xor(&index(buf, i, j, k), value); }
|
|
1178
|
+
template<template<typename> class A, typename T>
|
|
1179
|
+
inline CUDA_CALLABLE T atomic_xor(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_xor(&index(buf, i, j, k, l), value); }
|
|
1180
|
+
|
|
847
1181
|
template<template<typename> class A, typename T>
|
|
848
1182
|
inline CUDA_CALLABLE T* address(const A<T>& buf, int i)
|
|
849
1183
|
{
|
|
@@ -911,20 +1245,7 @@ inline CUDA_CALLABLE T load(T* address)
|
|
|
911
1245
|
return value;
|
|
912
1246
|
}
|
|
913
1247
|
|
|
914
|
-
//
|
|
915
|
-
template <typename T1, typename T2>
|
|
916
|
-
CUDA_CALLABLE inline T2 select(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?b:a; }
|
|
917
|
-
|
|
918
|
-
template <typename T1, typename T2>
|
|
919
|
-
CUDA_CALLABLE inline void adj_select(const array_t<T1>& arr, const T2& a, const T2& b, const array_t<T1>& adj_cond, T2& adj_a, T2& adj_b, const T2& adj_ret)
|
|
920
|
-
{
|
|
921
|
-
if (arr.data)
|
|
922
|
-
adj_b += adj_ret;
|
|
923
|
-
else
|
|
924
|
-
adj_a += adj_ret;
|
|
925
|
-
}
|
|
926
|
-
|
|
927
|
-
// where operator to check for array being null, opposite convention compared to select
|
|
1248
|
+
// where() overload for array condition - returns a if array.data is non-null, otherwise returns b
|
|
928
1249
|
template <typename T1, typename T2>
|
|
929
1250
|
CUDA_CALLABLE inline T2 where(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?a:b; }
|
|
930
1251
|
|
|
@@ -1321,6 +1642,34 @@ inline CUDA_CALLABLE void adj_atomic_exch(const A1<T>& buf, int i, int j, int k,
|
|
|
1321
1642
|
FP_VERIFY_ADJ_4(value, adj_value)
|
|
1322
1643
|
}
|
|
1323
1644
|
|
|
1645
|
+
// for bitwise operations we do not accumulate gradients
|
|
1646
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1647
|
+
inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
|
|
1648
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1649
|
+
inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
|
|
1650
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1651
|
+
inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
|
|
1652
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1653
|
+
inline CUDA_CALLABLE void adj_atomic_and(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
|
|
1654
|
+
|
|
1655
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1656
|
+
inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
|
|
1657
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1658
|
+
inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
|
|
1659
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1660
|
+
inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
|
|
1661
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1662
|
+
inline CUDA_CALLABLE void adj_atomic_or(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
|
|
1663
|
+
|
|
1664
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1665
|
+
inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int adj_i, T& adj_value, const T& adj_ret) {}
|
|
1666
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1667
|
+
inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int adj_i, int adj_j, T& adj_value, const T& adj_ret) {}
|
|
1668
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1669
|
+
inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, T& adj_value, const T& adj_ret) {}
|
|
1670
|
+
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
1671
|
+
inline CUDA_CALLABLE void adj_atomic_xor(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int adj_i, int adj_j, int adj_k, int adj_l, T& adj_value, const T& adj_ret) {}
|
|
1672
|
+
|
|
1324
1673
|
|
|
1325
1674
|
template<template<typename> class A, typename T>
|
|
1326
1675
|
CUDA_CALLABLE inline int len(const A<T>& a)
|
|
@@ -1333,7 +1682,6 @@ CUDA_CALLABLE inline void adj_len(const A<T>& a, A<T>& adj_a, int& adj_ret)
|
|
|
1333
1682
|
{
|
|
1334
1683
|
}
|
|
1335
1684
|
|
|
1336
|
-
|
|
1337
1685
|
} // namespace wp
|
|
1338
1686
|
|
|
1339
1687
|
#include "fabric.h"
|