warp-lang 1.0.2__py3-none-manylinux2014_aarch64.whl → 1.1.0__py3-none-manylinux2014_aarch64.whl
This diff reflects the content of publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
Potentially problematic release.
This version of warp-lang might be problematic.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +115 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3425 -3354
- warp/codegen.py +2878 -2792
- warp/config.py +40 -36
- warp/constants.py +45 -45
- warp/context.py +5194 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +383 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -277
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
- warp/examples/benchmarks/benchmark_launches.py +295 -295
- warp/examples/browse.py +29 -29
- warp/examples/core/example_dem.py +234 -219
- warp/examples/core/example_fluid.py +293 -267
- warp/examples/core/example_graph_capture.py +144 -126
- warp/examples/core/example_marching_cubes.py +188 -174
- warp/examples/core/example_mesh.py +174 -155
- warp/examples/core/example_mesh_intersect.py +205 -193
- warp/examples/core/example_nvdb.py +176 -170
- warp/examples/core/example_raycast.py +105 -90
- warp/examples/core/example_raymarch.py +199 -178
- warp/examples/core/example_render_opengl.py +185 -141
- warp/examples/core/example_sph.py +405 -387
- warp/examples/core/example_torch.py +222 -181
- warp/examples/core/example_wave.py +263 -248
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +407 -389
- warp/examples/fem/example_convection_diffusion.py +182 -168
- warp/examples/fem/example_convection_diffusion_dg.py +219 -209
- warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
- warp/examples/fem/example_deformed_geometry.py +177 -159
- warp/examples/fem/example_diffusion.py +201 -173
- warp/examples/fem/example_diffusion_3d.py +177 -152
- warp/examples/fem/example_diffusion_mgpu.py +221 -214
- warp/examples/fem/example_mixed_elasticity.py +244 -222
- warp/examples/fem/example_navier_stokes.py +259 -243
- warp/examples/fem/example_stokes.py +220 -192
- warp/examples/fem/example_stokes_transfer.py +265 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +260 -246
- warp/examples/optim/example_cloth_throw.py +222 -209
- warp/examples/optim/example_diffray.py +566 -536
- warp/examples/optim/example_drone.py +864 -835
- warp/examples/optim/example_inverse_kinematics.py +176 -168
- warp/examples/optim/example_inverse_kinematics_torch.py +185 -169
- warp/examples/optim/example_spring_cage.py +239 -231
- warp/examples/optim/example_trajectory.py +223 -199
- warp/examples/optim/example_walker.py +306 -293
- warp/examples/sim/example_cartpole.py +139 -129
- warp/examples/sim/example_cloth.py +196 -186
- warp/examples/sim/example_granular.py +124 -111
- warp/examples/sim/example_granular_collision_sdf.py +197 -186
- warp/examples/sim/example_jacobian_ik.py +236 -214
- warp/examples/sim/example_particle_chain.py +118 -105
- warp/examples/sim/example_quadruped.py +193 -180
- warp/examples/sim/example_rigid_chain.py +197 -187
- warp/examples/sim/example_rigid_contact.py +189 -177
- warp/examples/sim/example_rigid_force.py +127 -125
- warp/examples/sim/example_rigid_gyroscopic.py +109 -95
- warp/examples/sim/example_rigid_soft_contact.py +134 -122
- warp/examples/sim/example_soft_body.py +190 -177
- warp/fabric.py +337 -335
- warp/fem/__init__.py +60 -27
- warp/fem/cache.py +401 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +15 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +744 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +441 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1630 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +294 -292
- warp/fem/space/basis_space.py +488 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +267 -267
- warp/fem/space/grid_3d_function_space.py +305 -306
- warp/fem/space/hexmesh_function_space.py +350 -352
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +368 -369
- warp/fem/space/restriction.py +158 -160
- warp/fem/space/shape/__init__.py +13 -15
- warp/fem/space/shape/cube_shape_function.py +738 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +294 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +223 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1072 -1025
- warp/native/builtin.h +1560 -1560
- warp/native/bvh.cpp +398 -398
- warp/native/bvh.cu +525 -525
- warp/native/bvh.h +429 -429
- warp/native/clang/clang.cpp +495 -464
- warp/native/crt.cpp +31 -31
- warp/native/crt.h +334 -334
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1498 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +293 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/NanoVDB.h +4782 -4782
- warp/native/nanovdb/PNanoVDB.h +2553 -2553
- warp/native/nanovdb/PNanoVDBWrite.h +294 -294
- warp/native/noise.h +850 -850
- warp/native/quat.h +1084 -1084
- warp/native/rand.h +299 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1132 -1132
- warp/native/volume.cpp +297 -297
- warp/native/volume.cu +32 -32
- warp/native/volume.h +538 -538
- warp/native/volume_builder.cu +425 -425
- warp/native/volume_builder.h +19 -19
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2943 -2828
- warp/native/warp.h +313 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3217 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1910 -1991
- warp/sim/integrator_xpbd.py +3294 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1227 -1227
- warp/stubs.py +2109 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +22 -22
- warp/tests/aux_test_grad_customs.py +23 -23
- warp/tests/aux_test_reference.py +11 -11
- warp/tests/aux_test_reference_reference.py +10 -10
- warp/tests/aux_test_square.py +17 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +239 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +157 -157
- warp/tests/test_arithmetic.py +1124 -1124
- warp/tests/test_array.py +2417 -2326
- warp/tests/test_array_reduce.py +150 -150
- warp/tests/test_async.py +668 -656
- warp/tests/test_atomic.py +141 -141
- warp/tests/test_bool.py +204 -149
- warp/tests/test_builtins_resolution.py +1292 -1292
- warp/tests/test_bvh.py +164 -171
- warp/tests/test_closest_point_edge_edge.py +228 -228
- warp/tests/test_codegen.py +566 -553
- warp/tests/test_compile_consts.py +97 -101
- warp/tests/test_conditional.py +246 -246
- warp/tests/test_copy.py +232 -215
- warp/tests/test_ctypes.py +632 -632
- warp/tests/test_dense.py +67 -67
- warp/tests/test_devices.py +91 -98
- warp/tests/test_dlpack.py +530 -529
- warp/tests/test_examples.py +400 -378
- warp/tests/test_fabricarray.py +955 -955
- warp/tests/test_fast_math.py +62 -54
- warp/tests/test_fem.py +1277 -1278
- warp/tests/test_fp16.py +130 -130
- warp/tests/test_func.py +338 -337
- warp/tests/test_generics.py +571 -571
- warp/tests/test_grad.py +746 -640
- warp/tests/test_grad_customs.py +333 -336
- warp/tests/test_hash_grid.py +210 -164
- warp/tests/test_import.py +39 -39
- warp/tests/test_indexedarray.py +1134 -1134
- warp/tests/test_intersect.py +67 -67
- warp/tests/test_jax.py +307 -307
- warp/tests/test_large.py +167 -164
- warp/tests/test_launch.py +354 -354
- warp/tests/test_lerp.py +261 -261
- warp/tests/test_linear_solvers.py +191 -171
- warp/tests/test_lvalue.py +421 -493
- warp/tests/test_marching_cubes.py +65 -65
- warp/tests/test_mat.py +1801 -1827
- warp/tests/test_mat_lite.py +115 -115
- warp/tests/test_mat_scalar_ops.py +2907 -2889
- warp/tests/test_math.py +126 -193
- warp/tests/test_matmul.py +500 -499
- warp/tests/test_matmul_lite.py +410 -410
- warp/tests/test_mempool.py +188 -190
- warp/tests/test_mesh.py +284 -324
- warp/tests/test_mesh_query_aabb.py +228 -241
- warp/tests/test_mesh_query_point.py +692 -702
- warp/tests/test_mesh_query_ray.py +292 -303
- warp/tests/test_mlp.py +276 -276
- warp/tests/test_model.py +110 -110
- warp/tests/test_modules_lite.py +39 -39
- warp/tests/test_multigpu.py +163 -163
- warp/tests/test_noise.py +248 -248
- warp/tests/test_operators.py +250 -250
- warp/tests/test_options.py +123 -125
- warp/tests/test_peer.py +133 -137
- warp/tests/test_pinned.py +78 -78
- warp/tests/test_print.py +54 -54
- warp/tests/test_quat.py +2086 -2086
- warp/tests/test_rand.py +288 -288
- warp/tests/test_reload.py +217 -217
- warp/tests/test_rounding.py +179 -179
- warp/tests/test_runlength_encode.py +190 -190
- warp/tests/test_sim_grad.py +243 -0
- warp/tests/test_sim_kinematics.py +91 -97
- warp/tests/test_smoothstep.py +168 -168
- warp/tests/test_snippet.py +305 -266
- warp/tests/test_sparse.py +468 -460
- warp/tests/test_spatial.py +2148 -2148
- warp/tests/test_streams.py +486 -473
- warp/tests/test_struct.py +710 -675
- warp/tests/test_tape.py +173 -148
- warp/tests/test_torch.py +743 -743
- warp/tests/test_transient_module.py +87 -87
- warp/tests/test_types.py +556 -659
- warp/tests/test_utils.py +490 -499
- warp/tests/test_vec.py +1264 -1268
- warp/tests/test_vec_lite.py +73 -73
- warp/tests/test_vec_scalar_ops.py +2099 -2099
- warp/tests/test_verify_fp.py +94 -94
- warp/tests/test_volume.py +737 -736
- warp/tests/test_volume_write.py +255 -265
- warp/tests/unittest_serial.py +37 -37
- warp/tests/unittest_suites.py +363 -359
- warp/tests/unittest_utils.py +603 -578
- warp/tests/unused_test_misc.py +71 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +566 -561
- warp/torch.py +321 -295
- warp/types.py +4504 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
- warp_lang-1.1.0.dist-info/RECORD +352 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp_lang-1.0.2.dist-info/RECORD +0 -352
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/native/array.h
CHANGED
|
@@ -1,1025 +1,1072 @@
|
|
|
1
|
-
#pragma once
|
|
2
|
-
|
|
3
|
-
#include "builtin.h"
|
|
4
|
-
|
|
5
|
-
namespace wp
|
|
6
|
-
{
|
|
7
|
-
|
|
8
|
-
#if FP_CHECK
|
|
9
|
-
|
|
10
|
-
#define FP_ASSERT_FWD(value) \
|
|
11
|
-
print(value); \
|
|
12
|
-
printf(")\n"); \
|
|
13
|
-
assert(0); \
|
|
14
|
-
|
|
15
|
-
#define FP_ASSERT_ADJ(value, adj_value) \
|
|
16
|
-
print(value); \
|
|
17
|
-
printf(", "); \
|
|
18
|
-
print(adj_value); \
|
|
19
|
-
printf(")\n"); \
|
|
20
|
-
assert(0); \
|
|
21
|
-
|
|
22
|
-
#define FP_VERIFY_FWD(value) \
|
|
23
|
-
if (!isfinite(value)) { \
|
|
24
|
-
printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
|
|
25
|
-
FP_ASSERT_FWD(value) \
|
|
26
|
-
} \
|
|
27
|
-
|
|
28
|
-
#define FP_VERIFY_FWD_1(value) \
|
|
29
|
-
if (!isfinite(value)) { \
|
|
30
|
-
printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
|
|
31
|
-
FP_ASSERT_FWD(value) \
|
|
32
|
-
} \
|
|
33
|
-
|
|
34
|
-
#define FP_VERIFY_FWD_2(value) \
|
|
35
|
-
if (!isfinite(value)) { \
|
|
36
|
-
printf("%s:%d - %s(arr, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j); \
|
|
37
|
-
FP_ASSERT_FWD(value) \
|
|
38
|
-
} \
|
|
39
|
-
|
|
40
|
-
#define FP_VERIFY_FWD_3(value) \
|
|
41
|
-
if (!isfinite(value)) { \
|
|
42
|
-
printf("%s:%d - %s(arr, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k); \
|
|
43
|
-
FP_ASSERT_FWD(value) \
|
|
44
|
-
} \
|
|
45
|
-
|
|
46
|
-
#define FP_VERIFY_FWD_4(value) \
|
|
47
|
-
if (!isfinite(value)) { \
|
|
48
|
-
printf("%s:%d - %s(arr, %d, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k, l); \
|
|
49
|
-
FP_ASSERT_FWD(value) \
|
|
50
|
-
} \
|
|
51
|
-
|
|
52
|
-
#define FP_VERIFY_ADJ(value, adj_value) \
|
|
53
|
-
if (!isfinite(value) || !isfinite(adj_value)) \
|
|
54
|
-
{ \
|
|
55
|
-
printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
|
|
56
|
-
FP_ASSERT_ADJ(value, adj_value); \
|
|
57
|
-
} \
|
|
58
|
-
|
|
59
|
-
#define FP_VERIFY_ADJ_1(value, adj_value) \
|
|
60
|
-
if (!isfinite(value) || !isfinite(adj_value)) \
|
|
61
|
-
{ \
|
|
62
|
-
printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
|
|
63
|
-
FP_ASSERT_ADJ(value, adj_value); \
|
|
64
|
-
} \
|
|
65
|
-
|
|
66
|
-
#define FP_VERIFY_ADJ_2(value, adj_value) \
|
|
67
|
-
if (!isfinite(value) || !isfinite(adj_value)) \
|
|
68
|
-
{ \
|
|
69
|
-
printf("%s:%d - %s(arr, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j); \
|
|
70
|
-
FP_ASSERT_ADJ(value, adj_value); \
|
|
71
|
-
} \
|
|
72
|
-
|
|
73
|
-
#define FP_VERIFY_ADJ_3(value, adj_value) \
|
|
74
|
-
if (!isfinite(value) || !isfinite(adj_value)) \
|
|
75
|
-
{ \
|
|
76
|
-
printf("%s:%d - %s(arr, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k); \
|
|
77
|
-
FP_ASSERT_ADJ(value, adj_value); \
|
|
78
|
-
} \
|
|
79
|
-
|
|
80
|
-
#define FP_VERIFY_ADJ_4(value, adj_value) \
|
|
81
|
-
if (!isfinite(value) || !isfinite(adj_value)) \
|
|
82
|
-
{ \
|
|
83
|
-
printf("%s:%d - %s(arr, %d, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k, l); \
|
|
84
|
-
FP_ASSERT_ADJ(value, adj_value); \
|
|
85
|
-
} \
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
#else
|
|
89
|
-
|
|
90
|
-
#define FP_VERIFY_FWD(value) {}
|
|
91
|
-
#define FP_VERIFY_FWD_1(value) {}
|
|
92
|
-
#define FP_VERIFY_FWD_2(value) {}
|
|
93
|
-
#define FP_VERIFY_FWD_3(value) {}
|
|
94
|
-
#define FP_VERIFY_FWD_4(value) {}
|
|
95
|
-
|
|
96
|
-
#define FP_VERIFY_ADJ(value, adj_value) {}
|
|
97
|
-
#define FP_VERIFY_ADJ_1(value, adj_value) {}
|
|
98
|
-
#define FP_VERIFY_ADJ_2(value, adj_value) {}
|
|
99
|
-
#define FP_VERIFY_ADJ_3(value, adj_value) {}
|
|
100
|
-
#define FP_VERIFY_ADJ_4(value, adj_value) {}
|
|
101
|
-
|
|
102
|
-
#endif // WP_FP_CHECK
|
|
103
|
-
|
|
104
|
-
const int ARRAY_MAX_DIMS = 4; // must match constant in types.py
|
|
105
|
-
|
|
106
|
-
// must match constants in types.py
|
|
107
|
-
const int ARRAY_TYPE_REGULAR = 0;
|
|
108
|
-
const int ARRAY_TYPE_INDEXED = 1;
|
|
109
|
-
const int ARRAY_TYPE_FABRIC = 2;
|
|
110
|
-
const int ARRAY_TYPE_FABRIC_INDEXED = 3;
|
|
111
|
-
|
|
112
|
-
struct shape_t
|
|
113
|
-
{
|
|
114
|
-
int dims[ARRAY_MAX_DIMS];
|
|
115
|
-
|
|
116
|
-
CUDA_CALLABLE inline shape_t()
|
|
117
|
-
: dims()
|
|
118
|
-
{}
|
|
119
|
-
|
|
120
|
-
CUDA_CALLABLE inline int operator[](int i) const
|
|
121
|
-
{
|
|
122
|
-
assert(i < ARRAY_MAX_DIMS);
|
|
123
|
-
return dims[i];
|
|
124
|
-
}
|
|
125
|
-
|
|
126
|
-
CUDA_CALLABLE inline int& operator[](int i)
|
|
127
|
-
{
|
|
128
|
-
assert(i < ARRAY_MAX_DIMS);
|
|
129
|
-
return dims[i];
|
|
130
|
-
}
|
|
131
|
-
};
|
|
132
|
-
|
|
133
|
-
CUDA_CALLABLE inline int extract(const shape_t& s, int i)
|
|
134
|
-
{
|
|
135
|
-
return s.dims[i];
|
|
136
|
-
}
|
|
137
|
-
|
|
138
|
-
CUDA_CALLABLE inline void adj_extract(const shape_t& s, int i, const shape_t& adj_s, int adj_i, int adj_ret) {}
|
|
139
|
-
|
|
140
|
-
inline CUDA_CALLABLE void print(shape_t s)
|
|
141
|
-
{
|
|
142
|
-
// todo: only print valid dims, currently shape has a fixed size
|
|
143
|
-
// but we don't know how many dims are valid (e.g.: 1d, 2d, etc)
|
|
144
|
-
// should probably store ndim with shape
|
|
145
|
-
printf("(%d, %d, %d, %d)\n", s.dims[0], s.dims[1], s.dims[2], s.dims[3]);
|
|
146
|
-
}
|
|
147
|
-
inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& shape_t) {}
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
template <typename T>
|
|
151
|
-
struct array_t
|
|
152
|
-
{
|
|
153
|
-
CUDA_CALLABLE inline array_t()
|
|
154
|
-
: data(nullptr),
|
|
155
|
-
grad(nullptr),
|
|
156
|
-
shape(),
|
|
157
|
-
strides(),
|
|
158
|
-
ndim(0)
|
|
159
|
-
{}
|
|
160
|
-
|
|
161
|
-
CUDA_CALLABLE array_t(T* data, int size, T* grad=nullptr) : data(data), grad(grad) {
|
|
162
|
-
// constructor for 1d array
|
|
163
|
-
shape.dims[0] = size;
|
|
164
|
-
shape.dims[1] = 0;
|
|
165
|
-
shape.dims[2] = 0;
|
|
166
|
-
shape.dims[3] = 0;
|
|
167
|
-
ndim = 1;
|
|
168
|
-
strides[0] = sizeof(T);
|
|
169
|
-
strides[1] = 0;
|
|
170
|
-
strides[2] = 0;
|
|
171
|
-
strides[3] = 0;
|
|
172
|
-
}
|
|
173
|
-
CUDA_CALLABLE array_t(T* data, int dim0, int dim1, T* grad=nullptr) : data(data), grad(grad) {
|
|
174
|
-
// constructor for 2d array
|
|
175
|
-
shape.dims[0] = dim0;
|
|
176
|
-
shape.dims[1] = dim1;
|
|
177
|
-
shape.dims[2] = 0;
|
|
178
|
-
shape.dims[3] = 0;
|
|
179
|
-
ndim = 2;
|
|
180
|
-
strides[0] = dim1 * sizeof(T);
|
|
181
|
-
strides[1] = sizeof(T);
|
|
182
|
-
strides[2] = 0;
|
|
183
|
-
strides[3] = 0;
|
|
184
|
-
}
|
|
185
|
-
CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, T* grad=nullptr) : data(data), grad(grad) {
|
|
186
|
-
// constructor for 3d array
|
|
187
|
-
shape.dims[0] = dim0;
|
|
188
|
-
shape.dims[1] = dim1;
|
|
189
|
-
shape.dims[2] = dim2;
|
|
190
|
-
shape.dims[3] = 0;
|
|
191
|
-
ndim = 3;
|
|
192
|
-
strides[0] = dim1 * dim2 * sizeof(T);
|
|
193
|
-
strides[1] = dim2 * sizeof(T);
|
|
194
|
-
strides[2] = sizeof(T);
|
|
195
|
-
strides[3] = 0;
|
|
196
|
-
}
|
|
197
|
-
CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, int dim3, T* grad=nullptr) : data(data), grad(grad) {
|
|
198
|
-
// constructor for 4d array
|
|
199
|
-
shape.dims[0] = dim0;
|
|
200
|
-
shape.dims[1] = dim1;
|
|
201
|
-
shape.dims[2] = dim2;
|
|
202
|
-
shape.dims[3] = dim3;
|
|
203
|
-
ndim = 4;
|
|
204
|
-
strides[0] = dim1 * dim2 * dim3 * sizeof(T);
|
|
205
|
-
strides[1] = dim2 * dim3 * sizeof(T);
|
|
206
|
-
strides[2] = dim3 * sizeof(T);
|
|
207
|
-
strides[3] = sizeof(T);
|
|
208
|
-
}
|
|
209
|
-
|
|
210
|
-
CUDA_CALLABLE inline bool empty() const { return !data; }
|
|
211
|
-
|
|
212
|
-
T* data;
|
|
213
|
-
T* grad;
|
|
214
|
-
shape_t shape;
|
|
215
|
-
int strides[ARRAY_MAX_DIMS];
|
|
216
|
-
int ndim;
|
|
217
|
-
|
|
218
|
-
CUDA_CALLABLE inline operator T*() const { return data; }
|
|
219
|
-
};
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
// TODO:
|
|
223
|
-
// - templated index type?
|
|
224
|
-
// - templated dimensionality? (also for array_t to save space when passing arrays to kernels)
|
|
225
|
-
template <typename T>
|
|
226
|
-
struct indexedarray_t
|
|
227
|
-
{
|
|
228
|
-
CUDA_CALLABLE inline indexedarray_t()
|
|
229
|
-
: arr(),
|
|
230
|
-
indices(),
|
|
231
|
-
shape()
|
|
232
|
-
{}
|
|
233
|
-
|
|
234
|
-
CUDA_CALLABLE inline bool empty() const { return !arr.data; }
|
|
235
|
-
|
|
236
|
-
array_t<T> arr;
|
|
237
|
-
int* indices[ARRAY_MAX_DIMS]; // index array per dimension (can be NULL)
|
|
238
|
-
shape_t shape; // element count per dimension (num. indices if indexed, array dim if not)
|
|
239
|
-
};
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
// return stride (in bytes) of the given index
|
|
243
|
-
template <typename T>
|
|
244
|
-
CUDA_CALLABLE inline size_t stride(const array_t<T>& a, int dim)
|
|
245
|
-
{
|
|
246
|
-
return size_t(a.strides[dim]);
|
|
247
|
-
}
|
|
248
|
-
|
|
249
|
-
template <typename T>
|
|
250
|
-
CUDA_CALLABLE inline T* data_at_byte_offset(const array_t<T>& a, size_t byte_offset)
|
|
251
|
-
{
|
|
252
|
-
return reinterpret_cast<T*>(reinterpret_cast<char*>(a.data) + byte_offset);
|
|
253
|
-
}
|
|
254
|
-
|
|
255
|
-
template <typename T>
|
|
256
|
-
CUDA_CALLABLE inline T* grad_at_byte_offset(const array_t<T>& a, size_t byte_offset)
|
|
257
|
-
{
|
|
258
|
-
return reinterpret_cast<T*>(reinterpret_cast<char*>(a.grad) + byte_offset);
|
|
259
|
-
}
|
|
260
|
-
|
|
261
|
-
template <typename T>
|
|
262
|
-
CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i)
|
|
263
|
-
{
|
|
264
|
-
assert(i >= 0 && i < arr.shape[0]);
|
|
265
|
-
|
|
266
|
-
return i*stride(arr, 0);
|
|
267
|
-
}
|
|
268
|
-
|
|
269
|
-
template <typename T>
|
|
270
|
-
CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j)
|
|
271
|
-
{
|
|
272
|
-
assert(i >= 0 && i < arr.shape[0]);
|
|
273
|
-
assert(j >= 0 && j < arr.shape[1]);
|
|
274
|
-
|
|
275
|
-
return i*stride(arr, 0) + j*stride(arr, 1);
|
|
276
|
-
}
|
|
277
|
-
|
|
278
|
-
template <typename T>
|
|
279
|
-
CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j, int k)
|
|
280
|
-
{
|
|
281
|
-
assert(i >= 0 && i < arr.shape[0]);
|
|
282
|
-
assert(j >= 0 && j < arr.shape[1]);
|
|
283
|
-
assert(k >= 0 && k < arr.shape[2]);
|
|
284
|
-
|
|
285
|
-
return i*stride(arr, 0) + j*stride(arr, 1) + k*stride(arr, 2);
|
|
286
|
-
}
|
|
287
|
-
|
|
288
|
-
template <typename T>
|
|
289
|
-
CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j, int k, int l)
|
|
290
|
-
{
|
|
291
|
-
assert(i >= 0 && i < arr.shape[0]);
|
|
292
|
-
assert(j >= 0 && j < arr.shape[1]);
|
|
293
|
-
assert(k >= 0 && k < arr.shape[2]);
|
|
294
|
-
assert(l >= 0 && l < arr.shape[3]);
|
|
295
|
-
|
|
296
|
-
return i*stride(arr, 0) + j*stride(arr, 1) + k*stride(arr, 2) + l*stride(arr, 3);
|
|
297
|
-
}
|
|
298
|
-
|
|
299
|
-
template <typename T>
|
|
300
|
-
CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i)
|
|
301
|
-
{
|
|
302
|
-
assert(arr.ndim == 1);
|
|
303
|
-
T& result = *data_at_byte_offset(arr, byte_offset(arr, i));
|
|
304
|
-
FP_VERIFY_FWD_1(result)
|
|
305
|
-
|
|
306
|
-
return result;
|
|
307
|
-
}
|
|
308
|
-
|
|
309
|
-
template <typename T>
|
|
310
|
-
CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j)
|
|
311
|
-
{
|
|
312
|
-
assert(arr.ndim == 2);
|
|
313
|
-
T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j));
|
|
314
|
-
FP_VERIFY_FWD_2(result)
|
|
315
|
-
|
|
316
|
-
return result;
|
|
317
|
-
}
|
|
318
|
-
|
|
319
|
-
template <typename T>
|
|
320
|
-
CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k)
|
|
321
|
-
{
|
|
322
|
-
assert(arr.ndim == 3);
|
|
323
|
-
T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k));
|
|
324
|
-
FP_VERIFY_FWD_3(result)
|
|
325
|
-
|
|
326
|
-
return result;
|
|
327
|
-
}
|
|
328
|
-
|
|
329
|
-
template <typename T>
|
|
330
|
-
CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k, int l)
|
|
331
|
-
{
|
|
332
|
-
assert(arr.ndim == 4);
|
|
333
|
-
T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
|
|
334
|
-
FP_VERIFY_FWD_4(result)
|
|
335
|
-
|
|
336
|
-
return result;
|
|
337
|
-
}
|
|
338
|
-
|
|
339
|
-
template <typename T>
|
|
340
|
-
CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i)
|
|
341
|
-
{
|
|
342
|
-
T& result = *grad_at_byte_offset(arr, byte_offset(arr, i));
|
|
343
|
-
FP_VERIFY_FWD_1(result)
|
|
344
|
-
|
|
345
|
-
return result;
|
|
346
|
-
}
|
|
347
|
-
|
|
348
|
-
template <typename T>
|
|
349
|
-
CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j)
|
|
350
|
-
{
|
|
351
|
-
T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j));
|
|
352
|
-
FP_VERIFY_FWD_2(result)
|
|
353
|
-
|
|
354
|
-
return result;
|
|
355
|
-
}
|
|
356
|
-
|
|
357
|
-
template <typename T>
|
|
358
|
-
CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k)
|
|
359
|
-
{
|
|
360
|
-
T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k));
|
|
361
|
-
FP_VERIFY_FWD_3(result)
|
|
362
|
-
|
|
363
|
-
return result;
|
|
364
|
-
}
|
|
365
|
-
|
|
366
|
-
template <typename T>
|
|
367
|
-
CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k, int l)
|
|
368
|
-
{
|
|
369
|
-
T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
|
|
370
|
-
FP_VERIFY_FWD_4(result)
|
|
371
|
-
|
|
372
|
-
return result;
|
|
373
|
-
}
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
template <typename T>
|
|
377
|
-
CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i)
|
|
378
|
-
{
|
|
379
|
-
assert(iarr.arr.ndim == 1);
|
|
380
|
-
assert(i >= 0 && i < iarr.shape[0]);
|
|
381
|
-
|
|
382
|
-
if (iarr.indices[0])
|
|
383
|
-
{
|
|
384
|
-
i = iarr.indices[0][i];
|
|
385
|
-
assert(i >= 0 && i < iarr.arr.shape[0]);
|
|
386
|
-
}
|
|
387
|
-
|
|
388
|
-
T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i));
|
|
389
|
-
FP_VERIFY_FWD_1(result)
|
|
390
|
-
|
|
391
|
-
return result;
|
|
392
|
-
}
|
|
393
|
-
|
|
394
|
-
template <typename T>
|
|
395
|
-
CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j)
|
|
396
|
-
{
|
|
397
|
-
assert(iarr.arr.ndim == 2);
|
|
398
|
-
assert(i >= 0 && i < iarr.shape[0]);
|
|
399
|
-
assert(j >= 0 && j < iarr.shape[1]);
|
|
400
|
-
|
|
401
|
-
if (iarr.indices[0])
|
|
402
|
-
{
|
|
403
|
-
i = iarr.indices[0][i];
|
|
404
|
-
assert(i >= 0 && i < iarr.arr.shape[0]);
|
|
405
|
-
}
|
|
406
|
-
if (iarr.indices[1])
|
|
407
|
-
{
|
|
408
|
-
j = iarr.indices[1][j];
|
|
409
|
-
assert(j >= 0 && j < iarr.arr.shape[1]);
|
|
410
|
-
}
|
|
411
|
-
|
|
412
|
-
T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j));
|
|
413
|
-
FP_VERIFY_FWD_1(result)
|
|
414
|
-
|
|
415
|
-
return result;
|
|
416
|
-
}
|
|
417
|
-
|
|
418
|
-
template <typename T>
|
|
419
|
-
CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k)
|
|
420
|
-
{
|
|
421
|
-
assert(iarr.arr.ndim == 3);
|
|
422
|
-
assert(i >= 0 && i < iarr.shape[0]);
|
|
423
|
-
assert(j >= 0 && j < iarr.shape[1]);
|
|
424
|
-
assert(k >= 0 && k < iarr.shape[2]);
|
|
425
|
-
|
|
426
|
-
if (iarr.indices[0])
|
|
427
|
-
{
|
|
428
|
-
i = iarr.indices[0][i];
|
|
429
|
-
assert(i >= 0 && i < iarr.arr.shape[0]);
|
|
430
|
-
}
|
|
431
|
-
if (iarr.indices[1])
|
|
432
|
-
{
|
|
433
|
-
j = iarr.indices[1][j];
|
|
434
|
-
assert(j >= 0 && j < iarr.arr.shape[1]);
|
|
435
|
-
}
|
|
436
|
-
if (iarr.indices[2])
|
|
437
|
-
{
|
|
438
|
-
k = iarr.indices[2][k];
|
|
439
|
-
assert(k >= 0 && k < iarr.arr.shape[2]);
|
|
440
|
-
}
|
|
441
|
-
|
|
442
|
-
T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j, k));
|
|
443
|
-
FP_VERIFY_FWD_1(result)
|
|
444
|
-
|
|
445
|
-
return result;
|
|
446
|
-
}
|
|
447
|
-
|
|
448
|
-
template <typename T>
|
|
449
|
-
CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k, int l)
|
|
450
|
-
{
|
|
451
|
-
assert(iarr.arr.ndim == 4);
|
|
452
|
-
assert(i >= 0 && i < iarr.shape[0]);
|
|
453
|
-
assert(j >= 0 && j < iarr.shape[1]);
|
|
454
|
-
assert(k >= 0 && k < iarr.shape[2]);
|
|
455
|
-
assert(l >= 0 && l < iarr.shape[3]);
|
|
456
|
-
|
|
457
|
-
if (iarr.indices[0])
|
|
458
|
-
{
|
|
459
|
-
i = iarr.indices[0][i];
|
|
460
|
-
assert(i >= 0 && i < iarr.arr.shape[0]);
|
|
461
|
-
}
|
|
462
|
-
if (iarr.indices[1])
|
|
463
|
-
{
|
|
464
|
-
j = iarr.indices[1][j];
|
|
465
|
-
assert(j >= 0 && j < iarr.arr.shape[1]);
|
|
466
|
-
}
|
|
467
|
-
if (iarr.indices[2])
|
|
468
|
-
{
|
|
469
|
-
k = iarr.indices[2][k];
|
|
470
|
-
assert(k >= 0 && k < iarr.arr.shape[2]);
|
|
471
|
-
}
|
|
472
|
-
if (iarr.indices[3])
|
|
473
|
-
{
|
|
474
|
-
l = iarr.indices[3][l];
|
|
475
|
-
assert(l >= 0 && l < iarr.arr.shape[3]);
|
|
476
|
-
}
|
|
477
|
-
|
|
478
|
-
T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j, k, l));
|
|
479
|
-
FP_VERIFY_FWD_1(result)
|
|
480
|
-
|
|
481
|
-
return result;
|
|
482
|
-
}
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
template <typename T>
|
|
486
|
-
CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i)
|
|
487
|
-
{
|
|
488
|
-
assert(src.ndim > 1);
|
|
489
|
-
assert(i >= 0 && i < src.shape[0]);
|
|
490
|
-
|
|
491
|
-
array_t<T> a;
|
|
492
|
-
a.data = data_at_byte_offset(src, byte_offset(src, i));
|
|
493
|
-
a.shape[0] = src.shape[1];
|
|
494
|
-
a.shape[1] = src.shape[2];
|
|
495
|
-
a.shape[2] = src.shape[3];
|
|
496
|
-
a.strides[0] = src.strides[1];
|
|
497
|
-
a.strides[1] = src.strides[2];
|
|
498
|
-
a.strides[2] = src.strides[3];
|
|
499
|
-
a.ndim = src.ndim-1;
|
|
500
|
-
|
|
501
|
-
return a;
|
|
502
|
-
}
|
|
503
|
-
|
|
504
|
-
template <typename T>
|
|
505
|
-
CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j)
|
|
506
|
-
{
|
|
507
|
-
assert(src.ndim > 2);
|
|
508
|
-
assert(i >= 0 && i < src.shape[0]);
|
|
509
|
-
assert(j >= 0 && j < src.shape[1]);
|
|
510
|
-
|
|
511
|
-
array_t<T> a;
|
|
512
|
-
a.data = data_at_byte_offset(src, byte_offset(src, i, j));
|
|
513
|
-
a.shape[0] = src.shape[2];
|
|
514
|
-
a.shape[1] = src.shape[3];
|
|
515
|
-
a.strides[0] = src.strides[2];
|
|
516
|
-
a.strides[1] = src.strides[3];
|
|
517
|
-
a.ndim = src.ndim-2;
|
|
518
|
-
|
|
519
|
-
return a;
|
|
520
|
-
}
|
|
521
|
-
|
|
522
|
-
template <typename T>
|
|
523
|
-
CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j, int k)
|
|
524
|
-
{
|
|
525
|
-
assert(src.ndim > 3);
|
|
526
|
-
assert(i >= 0 && i < src.shape[0]);
|
|
527
|
-
assert(j >= 0 && j < src.shape[1]);
|
|
528
|
-
assert(k >= 0 && k < src.shape[2]);
|
|
529
|
-
|
|
530
|
-
array_t<T> a;
|
|
531
|
-
a.data = data_at_byte_offset(src, byte_offset(src, i, j, k));
|
|
532
|
-
a.shape[0] = src.shape[3];
|
|
533
|
-
a.strides[0] = src.strides[3];
|
|
534
|
-
a.ndim = src.ndim-3;
|
|
535
|
-
|
|
536
|
-
return a;
|
|
537
|
-
}
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
template <typename T>
|
|
541
|
-
CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i)
|
|
542
|
-
{
|
|
543
|
-
assert(src.arr.ndim > 1);
|
|
544
|
-
|
|
545
|
-
if (src.indices[0])
|
|
546
|
-
{
|
|
547
|
-
assert(i >= 0 && i < src.shape[0]);
|
|
548
|
-
i = src.indices[0][i];
|
|
549
|
-
}
|
|
550
|
-
|
|
551
|
-
indexedarray_t<T> a;
|
|
552
|
-
a.arr = view(src.arr, i);
|
|
553
|
-
a.indices[0] = src.indices[1];
|
|
554
|
-
a.indices[1] = src.indices[2];
|
|
555
|
-
a.indices[2] = src.indices[3];
|
|
556
|
-
a.shape[0] = src.shape[1];
|
|
557
|
-
a.shape[1] = src.shape[2];
|
|
558
|
-
a.shape[2] = src.shape[3];
|
|
559
|
-
|
|
560
|
-
return a;
|
|
561
|
-
}
|
|
562
|
-
|
|
563
|
-
template <typename T>
|
|
564
|
-
CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j)
|
|
565
|
-
{
|
|
566
|
-
assert(src.arr.ndim > 2);
|
|
567
|
-
|
|
568
|
-
if (src.indices[0])
|
|
569
|
-
{
|
|
570
|
-
assert(i >= 0 && i < src.shape[0]);
|
|
571
|
-
i = src.indices[0][i];
|
|
572
|
-
}
|
|
573
|
-
if (src.indices[1])
|
|
574
|
-
{
|
|
575
|
-
assert(j >= 0 && j < src.shape[1]);
|
|
576
|
-
j = src.indices[1][j];
|
|
577
|
-
}
|
|
578
|
-
|
|
579
|
-
indexedarray_t<T> a;
|
|
580
|
-
a.arr = view(src.arr, i, j);
|
|
581
|
-
a.indices[0] = src.indices[2];
|
|
582
|
-
a.indices[1] = src.indices[3];
|
|
583
|
-
a.shape[0] = src.shape[2];
|
|
584
|
-
a.shape[1] = src.shape[3];
|
|
585
|
-
|
|
586
|
-
return a;
|
|
587
|
-
}
|
|
588
|
-
|
|
589
|
-
template <typename T>
|
|
590
|
-
CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j, int k)
|
|
591
|
-
{
|
|
592
|
-
assert(src.arr.ndim > 3);
|
|
593
|
-
|
|
594
|
-
if (src.indices[0])
|
|
595
|
-
{
|
|
596
|
-
assert(i >= 0 && i < src.shape[0]);
|
|
597
|
-
i = src.indices[0][i];
|
|
598
|
-
}
|
|
599
|
-
if (src.indices[1])
|
|
600
|
-
{
|
|
601
|
-
assert(j >= 0 && j < src.shape[1]);
|
|
602
|
-
j = src.indices[1][j];
|
|
603
|
-
}
|
|
604
|
-
if (src.indices[2])
|
|
605
|
-
{
|
|
606
|
-
assert(k >= 0 && k < src.shape[2]);
|
|
607
|
-
k = src.indices[2][k];
|
|
608
|
-
}
|
|
609
|
-
|
|
610
|
-
indexedarray_t<T> a;
|
|
611
|
-
a.arr = view(src.arr, i, j, k);
|
|
612
|
-
a.indices[0] = src.indices[3];
|
|
613
|
-
a.shape[0] = src.shape[3];
|
|
614
|
-
|
|
615
|
-
return a;
|
|
616
|
-
}
|
|
617
|
-
|
|
618
|
-
template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
|
|
619
|
-
inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, A2<T>& adj_src, int adj_i, A3<T> adj_ret) {}
|
|
620
|
-
template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
|
|
621
|
-
inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int adj_i, int adj_j, A3<T> adj_ret) {}
|
|
622
|
-
template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
|
|
623
|
-
inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T> adj_ret) {}
|
|
624
|
-
|
|
625
|
-
// TODO: lower_bound() for indexed arrays?
|
|
626
|
-
|
|
627
|
-
template <typename T>
|
|
628
|
-
CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value)
|
|
629
|
-
{
|
|
630
|
-
assert(arr.ndim == 1);
|
|
631
|
-
|
|
632
|
-
int lower = arr_begin;
|
|
633
|
-
int upper = arr_end - 1;
|
|
634
|
-
|
|
635
|
-
while(lower < upper)
|
|
636
|
-
{
|
|
637
|
-
int mid = lower + (upper - lower) / 2;
|
|
638
|
-
|
|
639
|
-
if (arr[mid] < value)
|
|
640
|
-
{
|
|
641
|
-
lower = mid + 1;
|
|
642
|
-
}
|
|
643
|
-
else
|
|
644
|
-
{
|
|
645
|
-
upper = mid;
|
|
646
|
-
}
|
|
647
|
-
}
|
|
648
|
-
|
|
649
|
-
return lower;
|
|
650
|
-
}
|
|
651
|
-
|
|
652
|
-
template <typename T>
|
|
653
|
-
CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, T value)
|
|
654
|
-
{
|
|
655
|
-
return lower_bound(arr, 0, arr.shape[0], value);
|
|
656
|
-
}
|
|
657
|
-
|
|
658
|
-
template <typename T> inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, T value, array_t<T> adj_arr, T adj_value, int adj_ret) {}
|
|
659
|
-
template <typename T> inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value, array_t<T> adj_arr, int adj_arr_begin, int adj_arr_end, T adj_value, int adj_ret) {}
|
|
660
|
-
|
|
661
|
-
template<template<typename> class A, typename T>
|
|
662
|
-
inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, T value) { return atomic_add(&index(buf, i), value); }
|
|
663
|
-
template<template<typename> class A, typename T>
|
|
664
|
-
inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, T value) { return atomic_add(&index(buf, i, j), value); }
|
|
665
|
-
template<template<typename> class A, typename T>
|
|
666
|
-
inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, int k, T value) { return atomic_add(&index(buf, i, j, k), value); }
|
|
667
|
-
template<template<typename> class A, typename T>
|
|
668
|
-
inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_add(&index(buf, i, j, k, l), value); }
|
|
669
|
-
|
|
670
|
-
template<template<typename> class A, typename T>
|
|
671
|
-
inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, T value) { return atomic_add(&index(buf, i), -value); }
|
|
672
|
-
template<template<typename> class A, typename T>
|
|
673
|
-
inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, T value) { return atomic_add(&index(buf, i, j), -value); }
|
|
674
|
-
template<template<typename> class A, typename T>
|
|
675
|
-
inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, int k, T value) { return atomic_add(&index(buf, i, j, k), -value); }
|
|
676
|
-
template<template<typename> class A, typename T>
|
|
677
|
-
inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_add(&index(buf, i, j, k, l), -value); }
|
|
678
|
-
|
|
679
|
-
template<template<typename> class A, typename T>
|
|
680
|
-
inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, T value) { return atomic_min(&index(buf, i), value); }
|
|
681
|
-
template<template<typename> class A, typename T>
|
|
682
|
-
inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, T value) { return atomic_min(&index(buf, i, j), value); }
|
|
683
|
-
template<template<typename> class A, typename T>
|
|
684
|
-
inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, int k, T value) { return atomic_min(&index(buf, i, j, k), value); }
|
|
685
|
-
template<template<typename> class A, typename T>
|
|
686
|
-
inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_min(&index(buf, i, j, k, l), value); }
|
|
687
|
-
|
|
688
|
-
template<template<typename> class A, typename T>
|
|
689
|
-
inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, T value) { return atomic_max(&index(buf, i), value); }
|
|
690
|
-
template<template<typename> class A, typename T>
|
|
691
|
-
inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, T value) { return atomic_max(&index(buf, i, j), value); }
|
|
692
|
-
template<template<typename> class A, typename T>
|
|
693
|
-
inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, T value) { return atomic_max(&index(buf, i, j, k), value); }
|
|
694
|
-
template<template<typename> class A, typename T>
|
|
695
|
-
inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_max(&index(buf, i, j, k, l), value); }
|
|
696
|
-
|
|
697
|
-
template<template<typename> class A, typename T>
|
|
698
|
-
inline CUDA_CALLABLE T* address(const A<T>& buf, int i) { return &index(buf, i); }
|
|
699
|
-
template<template<typename> class A, typename T>
|
|
700
|
-
inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j) { return &index(buf, i, j); }
|
|
701
|
-
template<template<typename> class A, typename T>
|
|
702
|
-
inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k) { return &index(buf, i, j, k); }
|
|
703
|
-
template<template<typename> class A, typename T>
|
|
704
|
-
inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k, int l) { return &index(buf, i, j, k, l); }
|
|
705
|
-
|
|
706
|
-
template<template<typename> class A, typename T>
|
|
707
|
-
inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, T value)
|
|
708
|
-
{
|
|
709
|
-
FP_VERIFY_FWD_1(value)
|
|
710
|
-
|
|
711
|
-
index(buf, i) = value;
|
|
712
|
-
}
|
|
713
|
-
template<template<typename> class A, typename T>
|
|
714
|
-
inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, T value)
|
|
715
|
-
{
|
|
716
|
-
FP_VERIFY_FWD_2(value)
|
|
717
|
-
|
|
718
|
-
index(buf, i, j) = value;
|
|
719
|
-
}
|
|
720
|
-
template<template<typename> class A, typename T>
|
|
721
|
-
inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, T value)
|
|
722
|
-
{
|
|
723
|
-
FP_VERIFY_FWD_3(value)
|
|
724
|
-
|
|
725
|
-
index(buf, i, j, k) = value;
|
|
726
|
-
}
|
|
727
|
-
template<template<typename> class A, typename T>
|
|
728
|
-
inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, int l, T value)
|
|
729
|
-
{
|
|
730
|
-
FP_VERIFY_FWD_4(value)
|
|
731
|
-
|
|
732
|
-
index(buf, i, j, k, l) = value;
|
|
733
|
-
}
|
|
734
|
-
|
|
735
|
-
template<typename T>
|
|
736
|
-
inline CUDA_CALLABLE void store(T* address, T value)
|
|
737
|
-
{
|
|
738
|
-
FP_VERIFY_FWD(value)
|
|
739
|
-
|
|
740
|
-
*address = value;
|
|
741
|
-
}
|
|
742
|
-
|
|
743
|
-
template<typename T>
|
|
744
|
-
inline CUDA_CALLABLE T load(T* address)
|
|
745
|
-
{
|
|
746
|
-
T value = *address;
|
|
747
|
-
FP_VERIFY_FWD(value)
|
|
748
|
-
|
|
749
|
-
return value;
|
|
750
|
-
}
|
|
751
|
-
|
|
752
|
-
// select operator to check for array being null
|
|
753
|
-
template <typename T1, typename T2>
|
|
754
|
-
CUDA_CALLABLE inline T2 select(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?b:a; }
|
|
755
|
-
|
|
756
|
-
template <typename T1, typename T2>
|
|
757
|
-
CUDA_CALLABLE inline void adj_select(const array_t<T1>& arr, const T2& a, const T2& b, const array_t<T1>& adj_cond, T2& adj_a, T2& adj_b, const T2& adj_ret)
|
|
758
|
-
{
|
|
759
|
-
if (arr.data)
|
|
760
|
-
adj_b += adj_ret;
|
|
761
|
-
else
|
|
762
|
-
adj_a += adj_ret;
|
|
763
|
-
}
|
|
764
|
-
|
|
765
|
-
// stub for the case where we have an nested array inside a struct and
|
|
766
|
-
// atomic add the whole struct onto an array (e.g.: during backwards pass)
|
|
767
|
-
template <typename T>
|
|
768
|
-
CUDA_CALLABLE inline void atomic_add(array_t<T>*, array_t<T>) {}
|
|
769
|
-
|
|
770
|
-
// for float and vector types this is just an alias for an atomic add
|
|
771
|
-
template <typename T>
|
|
772
|
-
CUDA_CALLABLE inline void adj_atomic_add(T* buf, T value) { atomic_add(buf, value); }
|
|
773
|
-
|
|
774
|
-
|
|
775
|
-
// for integral types we do not accumulate gradients
|
|
776
|
-
CUDA_CALLABLE inline void adj_atomic_add(int8* buf, int8 value) { }
|
|
777
|
-
CUDA_CALLABLE inline void adj_atomic_add(uint8* buf, uint8 value) { }
|
|
778
|
-
CUDA_CALLABLE inline void adj_atomic_add(int16* buf, int16 value) { }
|
|
779
|
-
CUDA_CALLABLE inline void adj_atomic_add(uint16* buf, uint16 value) { }
|
|
780
|
-
CUDA_CALLABLE inline void adj_atomic_add(int32* buf, int32 value) { }
|
|
781
|
-
CUDA_CALLABLE inline void adj_atomic_add(uint32* buf, uint32 value) { }
|
|
782
|
-
CUDA_CALLABLE inline void adj_atomic_add(int64* buf, int64 value) { }
|
|
783
|
-
CUDA_CALLABLE inline void adj_atomic_add(uint64* buf, uint64 value) { }
|
|
784
|
-
|
|
785
|
-
CUDA_CALLABLE inline void adj_atomic_add(bool* buf, bool value) { }
|
|
786
|
-
|
|
787
|
-
// only generate gradients for T types
|
|
788
|
-
template<typename T>
|
|
789
|
-
inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, const array_t<T>& adj_buf, int& adj_i, const T& adj_output)
|
|
790
|
-
{
|
|
791
|
-
if (
|
|
792
|
-
adj_atomic_add(&
|
|
793
|
-
|
|
794
|
-
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
|
|
800
|
-
|
|
801
|
-
|
|
802
|
-
|
|
803
|
-
|
|
804
|
-
|
|
805
|
-
|
|
806
|
-
|
|
807
|
-
|
|
808
|
-
|
|
809
|
-
if (buf.grad)
|
|
810
|
-
adj_atomic_add(&index_grad(buf, i, j, k
|
|
811
|
-
}
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
template<typename T>
|
|
822
|
-
inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i,
|
|
823
|
-
{
|
|
824
|
-
if (
|
|
825
|
-
adj_value +=
|
|
826
|
-
|
|
827
|
-
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
|
|
831
|
-
|
|
832
|
-
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
|
|
836
|
-
|
|
837
|
-
|
|
838
|
-
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
|
|
842
|
-
|
|
843
|
-
|
|
844
|
-
|
|
845
|
-
|
|
846
|
-
|
|
847
|
-
|
|
848
|
-
|
|
849
|
-
|
|
850
|
-
|
|
851
|
-
|
|
852
|
-
|
|
853
|
-
|
|
854
|
-
|
|
855
|
-
|
|
856
|
-
|
|
857
|
-
|
|
858
|
-
|
|
859
|
-
|
|
860
|
-
|
|
861
|
-
|
|
862
|
-
|
|
863
|
-
|
|
864
|
-
|
|
865
|
-
|
|
866
|
-
|
|
867
|
-
}
|
|
868
|
-
|
|
869
|
-
|
|
870
|
-
|
|
871
|
-
|
|
872
|
-
|
|
873
|
-
|
|
874
|
-
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
|
|
887
|
-
|
|
888
|
-
|
|
889
|
-
|
|
890
|
-
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
|
|
896
|
-
|
|
897
|
-
|
|
898
|
-
|
|
899
|
-
|
|
900
|
-
|
|
901
|
-
|
|
902
|
-
|
|
903
|
-
|
|
904
|
-
|
|
905
|
-
|
|
906
|
-
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
912
|
-
|
|
913
|
-
|
|
914
|
-
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
919
|
-
|
|
920
|
-
|
|
921
|
-
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
926
|
-
|
|
927
|
-
|
|
928
|
-
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
936
|
-
|
|
937
|
-
template<
|
|
938
|
-
inline CUDA_CALLABLE void
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
960
|
-
inline CUDA_CALLABLE void
|
|
961
|
-
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
962
|
-
inline CUDA_CALLABLE void
|
|
963
|
-
|
|
964
|
-
|
|
965
|
-
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
966
|
-
inline CUDA_CALLABLE void
|
|
967
|
-
|
|
968
|
-
|
|
969
|
-
|
|
970
|
-
|
|
971
|
-
}
|
|
972
|
-
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
973
|
-
inline CUDA_CALLABLE void
|
|
974
|
-
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
|
|
978
|
-
}
|
|
979
|
-
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
980
|
-
inline CUDA_CALLABLE void
|
|
981
|
-
|
|
982
|
-
|
|
983
|
-
|
|
984
|
-
|
|
985
|
-
|
|
986
|
-
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
987
|
-
inline CUDA_CALLABLE void
|
|
988
|
-
|
|
989
|
-
|
|
990
|
-
|
|
991
|
-
|
|
992
|
-
|
|
993
|
-
|
|
994
|
-
|
|
995
|
-
|
|
996
|
-
|
|
997
|
-
|
|
998
|
-
|
|
999
|
-
|
|
1000
|
-
|
|
1001
|
-
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
|
|
1005
|
-
|
|
1006
|
-
|
|
1007
|
-
|
|
1008
|
-
|
|
1009
|
-
|
|
1010
|
-
|
|
1011
|
-
|
|
1012
|
-
|
|
1013
|
-
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
|
|
1017
|
-
|
|
1018
|
-
|
|
1019
|
-
|
|
1020
|
-
|
|
1021
|
-
|
|
1022
|
-
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1
|
+
#pragma once
|
|
2
|
+
|
|
3
|
+
#include "builtin.h"
|
|
4
|
+
|
|
5
|
+
namespace wp
|
|
6
|
+
{
|
|
7
|
+
|
|
8
|
+
#if FP_CHECK
|
|
9
|
+
|
|
10
|
+
#define FP_ASSERT_FWD(value) \
|
|
11
|
+
print(value); \
|
|
12
|
+
printf(")\n"); \
|
|
13
|
+
assert(0); \
|
|
14
|
+
|
|
15
|
+
#define FP_ASSERT_ADJ(value, adj_value) \
|
|
16
|
+
print(value); \
|
|
17
|
+
printf(", "); \
|
|
18
|
+
print(adj_value); \
|
|
19
|
+
printf(")\n"); \
|
|
20
|
+
assert(0); \
|
|
21
|
+
|
|
22
|
+
#define FP_VERIFY_FWD(value) \
|
|
23
|
+
if (!isfinite(value)) { \
|
|
24
|
+
printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
|
|
25
|
+
FP_ASSERT_FWD(value) \
|
|
26
|
+
} \
|
|
27
|
+
|
|
28
|
+
#define FP_VERIFY_FWD_1(value) \
|
|
29
|
+
if (!isfinite(value)) { \
|
|
30
|
+
printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
|
|
31
|
+
FP_ASSERT_FWD(value) \
|
|
32
|
+
} \
|
|
33
|
+
|
|
34
|
+
#define FP_VERIFY_FWD_2(value) \
|
|
35
|
+
if (!isfinite(value)) { \
|
|
36
|
+
printf("%s:%d - %s(arr, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j); \
|
|
37
|
+
FP_ASSERT_FWD(value) \
|
|
38
|
+
} \
|
|
39
|
+
|
|
40
|
+
#define FP_VERIFY_FWD_3(value) \
|
|
41
|
+
if (!isfinite(value)) { \
|
|
42
|
+
printf("%s:%d - %s(arr, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k); \
|
|
43
|
+
FP_ASSERT_FWD(value) \
|
|
44
|
+
} \
|
|
45
|
+
|
|
46
|
+
#define FP_VERIFY_FWD_4(value) \
|
|
47
|
+
if (!isfinite(value)) { \
|
|
48
|
+
printf("%s:%d - %s(arr, %d, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k, l); \
|
|
49
|
+
FP_ASSERT_FWD(value) \
|
|
50
|
+
} \
|
|
51
|
+
|
|
52
|
+
#define FP_VERIFY_ADJ(value, adj_value) \
|
|
53
|
+
if (!isfinite(value) || !isfinite(adj_value)) \
|
|
54
|
+
{ \
|
|
55
|
+
printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
|
|
56
|
+
FP_ASSERT_ADJ(value, adj_value); \
|
|
57
|
+
} \
|
|
58
|
+
|
|
59
|
+
#define FP_VERIFY_ADJ_1(value, adj_value) \
|
|
60
|
+
if (!isfinite(value) || !isfinite(adj_value)) \
|
|
61
|
+
{ \
|
|
62
|
+
printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
|
|
63
|
+
FP_ASSERT_ADJ(value, adj_value); \
|
|
64
|
+
} \
|
|
65
|
+
|
|
66
|
+
#define FP_VERIFY_ADJ_2(value, adj_value) \
|
|
67
|
+
if (!isfinite(value) || !isfinite(adj_value)) \
|
|
68
|
+
{ \
|
|
69
|
+
printf("%s:%d - %s(arr, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j); \
|
|
70
|
+
FP_ASSERT_ADJ(value, adj_value); \
|
|
71
|
+
} \
|
|
72
|
+
|
|
73
|
+
#define FP_VERIFY_ADJ_3(value, adj_value) \
|
|
74
|
+
if (!isfinite(value) || !isfinite(adj_value)) \
|
|
75
|
+
{ \
|
|
76
|
+
printf("%s:%d - %s(arr, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k); \
|
|
77
|
+
FP_ASSERT_ADJ(value, adj_value); \
|
|
78
|
+
} \
|
|
79
|
+
|
|
80
|
+
#define FP_VERIFY_ADJ_4(value, adj_value) \
|
|
81
|
+
if (!isfinite(value) || !isfinite(adj_value)) \
|
|
82
|
+
{ \
|
|
83
|
+
printf("%s:%d - %s(arr, %d, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k, l); \
|
|
84
|
+
FP_ASSERT_ADJ(value, adj_value); \
|
|
85
|
+
} \
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
#else
|
|
89
|
+
|
|
90
|
+
#define FP_VERIFY_FWD(value) {}
|
|
91
|
+
#define FP_VERIFY_FWD_1(value) {}
|
|
92
|
+
#define FP_VERIFY_FWD_2(value) {}
|
|
93
|
+
#define FP_VERIFY_FWD_3(value) {}
|
|
94
|
+
#define FP_VERIFY_FWD_4(value) {}
|
|
95
|
+
|
|
96
|
+
#define FP_VERIFY_ADJ(value, adj_value) {}
|
|
97
|
+
#define FP_VERIFY_ADJ_1(value, adj_value) {}
|
|
98
|
+
#define FP_VERIFY_ADJ_2(value, adj_value) {}
|
|
99
|
+
#define FP_VERIFY_ADJ_3(value, adj_value) {}
|
|
100
|
+
#define FP_VERIFY_ADJ_4(value, adj_value) {}
|
|
101
|
+
|
|
102
|
+
#endif // WP_FP_CHECK
|
|
103
|
+
|
|
104
|
+
const int ARRAY_MAX_DIMS = 4; // must match constant in types.py
|
|
105
|
+
|
|
106
|
+
// must match constants in types.py
|
|
107
|
+
const int ARRAY_TYPE_REGULAR = 0;
|
|
108
|
+
const int ARRAY_TYPE_INDEXED = 1;
|
|
109
|
+
const int ARRAY_TYPE_FABRIC = 2;
|
|
110
|
+
const int ARRAY_TYPE_FABRIC_INDEXED = 3;
|
|
111
|
+
|
|
112
|
+
struct shape_t
|
|
113
|
+
{
|
|
114
|
+
int dims[ARRAY_MAX_DIMS];
|
|
115
|
+
|
|
116
|
+
CUDA_CALLABLE inline shape_t()
|
|
117
|
+
: dims()
|
|
118
|
+
{}
|
|
119
|
+
|
|
120
|
+
CUDA_CALLABLE inline int operator[](int i) const
|
|
121
|
+
{
|
|
122
|
+
assert(i < ARRAY_MAX_DIMS);
|
|
123
|
+
return dims[i];
|
|
124
|
+
}
|
|
125
|
+
|
|
126
|
+
CUDA_CALLABLE inline int& operator[](int i)
|
|
127
|
+
{
|
|
128
|
+
assert(i < ARRAY_MAX_DIMS);
|
|
129
|
+
return dims[i];
|
|
130
|
+
}
|
|
131
|
+
};
|
|
132
|
+
|
|
133
|
+
CUDA_CALLABLE inline int extract(const shape_t& s, int i)
|
|
134
|
+
{
|
|
135
|
+
return s.dims[i];
|
|
136
|
+
}
|
|
137
|
+
|
|
138
|
+
CUDA_CALLABLE inline void adj_extract(const shape_t& s, int i, const shape_t& adj_s, int adj_i, int adj_ret) {}
|
|
139
|
+
|
|
140
|
+
inline CUDA_CALLABLE void print(shape_t s)
|
|
141
|
+
{
|
|
142
|
+
// todo: only print valid dims, currently shape has a fixed size
|
|
143
|
+
// but we don't know how many dims are valid (e.g.: 1d, 2d, etc)
|
|
144
|
+
// should probably store ndim with shape
|
|
145
|
+
printf("(%d, %d, %d, %d)\n", s.dims[0], s.dims[1], s.dims[2], s.dims[3]);
|
|
146
|
+
}
|
|
147
|
+
inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& shape_t) {}
|
|
148
|
+
|
|
149
|
+
+template <typename T>
+struct array_t
+{
+    CUDA_CALLABLE inline array_t()
+        : data(nullptr),
+          grad(nullptr),
+          shape(),
+          strides(),
+          ndim(0)
+    {}
+
+    CUDA_CALLABLE array_t(T* data, int size, T* grad=nullptr) : data(data), grad(grad) {
+        // constructor for 1d array
+        shape.dims[0] = size;
+        shape.dims[1] = 0;
+        shape.dims[2] = 0;
+        shape.dims[3] = 0;
+        ndim = 1;
+        strides[0] = sizeof(T);
+        strides[1] = 0;
+        strides[2] = 0;
+        strides[3] = 0;
+    }
+    CUDA_CALLABLE array_t(T* data, int dim0, int dim1, T* grad=nullptr) : data(data), grad(grad) {
+        // constructor for 2d array
+        shape.dims[0] = dim0;
+        shape.dims[1] = dim1;
+        shape.dims[2] = 0;
+        shape.dims[3] = 0;
+        ndim = 2;
+        strides[0] = dim1 * sizeof(T);
+        strides[1] = sizeof(T);
+        strides[2] = 0;
+        strides[3] = 0;
+    }
+    CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, T* grad=nullptr) : data(data), grad(grad) {
+        // constructor for 3d array
+        shape.dims[0] = dim0;
+        shape.dims[1] = dim1;
+        shape.dims[2] = dim2;
+        shape.dims[3] = 0;
+        ndim = 3;
+        strides[0] = dim1 * dim2 * sizeof(T);
+        strides[1] = dim2 * sizeof(T);
+        strides[2] = sizeof(T);
+        strides[3] = 0;
+    }
+    CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, int dim3, T* grad=nullptr) : data(data), grad(grad) {
+        // constructor for 4d array
+        shape.dims[0] = dim0;
+        shape.dims[1] = dim1;
+        shape.dims[2] = dim2;
+        shape.dims[3] = dim3;
+        ndim = 4;
+        strides[0] = dim1 * dim2 * dim3 * sizeof(T);
+        strides[1] = dim2 * dim3 * sizeof(T);
+        strides[2] = dim3 * sizeof(T);
+        strides[3] = sizeof(T);
+    }
+
+    CUDA_CALLABLE inline bool empty() const { return !data; }
+
+    T* data;
+    T* grad;
+    shape_t shape;
+    int strides[ARRAY_MAX_DIMS];
+    int ndim;
+
+    CUDA_CALLABLE inline operator T*() const { return data; }
+};
+
+
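Editorial note (not part of the diff): each array_t constructor fills shape and ndim from the given extents and derives row-major strides in bytes, so the stride of a dimension is the product of all trailing extents times sizeof(T). A standalone sketch of that arithmetic for a hypothetical 4 x 3 array of float:

    #include <cassert>
    #include <cstddef>

    int main()
    {
        const int dim0 = 4, dim1 = 3;               // hypothetical extents
        std::size_t strides[2];
        strides[0] = dim1 * sizeof(float);          // bytes to advance one step in dim 0
        strides[1] = sizeof(float);                 // bytes to advance one step in dim 1
        // element (i, j) then lives at byte offset i*strides[0] + j*strides[1]
        assert(2 * strides[0] + 1 * strides[1] == (2 * dim1 + 1) * sizeof(float));
        return 0;
    }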
+// TODO:
+// - templated index type?
+// - templated dimensionality? (also for array_t to save space when passing arrays to kernels)
+template <typename T>
+struct indexedarray_t
+{
+    CUDA_CALLABLE inline indexedarray_t()
+        : arr(),
+          indices(),
+          shape()
+    {}
+
+    CUDA_CALLABLE inline bool empty() const { return !arr.data; }
+
+    array_t<T> arr;
+    int* indices[ARRAY_MAX_DIMS]; // index array per dimension (can be NULL)
+    shape_t shape; // element count per dimension (num. indices if indexed, array dim if not)
+};
+
+
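Editorial note (not part of the diff): indexedarray_t wraps a plain array_t plus an optional index array per dimension; a non-null entry remaps a logical coordinate into the underlying array, a null entry leaves that dimension untouched, and `shape` reports the indexed extents. A minimal sketch of that remapping rule (demo names are illustrative only):

    #include <cassert>

    // mirrors the `if (iarr.indices[d]) i = iarr.indices[d][i];` pattern used further below
    int remap(const int* indices, int i)
    {
        return indices ? indices[i] : i;
    }

    int main()
    {
        float data[5] = {10, 20, 30, 40, 50};
        int sel[3] = {4, 0, 2};                     // hypothetical selection of 3 of the 5 elements
        assert(data[remap(sel, 0)] == 50.0f);       // indexed dimension: gathers element 4
        assert(data[remap(nullptr, 1)] == 20.0f);   // no index array: identity mapping
        return 0;
    }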
+// return stride (in bytes) of the given index
+template <typename T>
+CUDA_CALLABLE inline size_t stride(const array_t<T>& a, int dim)
+{
+    return size_t(a.strides[dim]);
+}
+
+template <typename T>
+CUDA_CALLABLE inline T* data_at_byte_offset(const array_t<T>& a, size_t byte_offset)
+{
+    return reinterpret_cast<T*>(reinterpret_cast<char*>(a.data) + byte_offset);
+}
+
+template <typename T>
+CUDA_CALLABLE inline T* grad_at_byte_offset(const array_t<T>& a, size_t byte_offset)
+{
+    return reinterpret_cast<T*>(reinterpret_cast<char*>(a.grad) + byte_offset);
+}
+
+template <typename T>
+CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i)
+{
+    assert(i >= 0 && i < arr.shape[0]);
+
+    return i*stride(arr, 0);
+}
+
+template <typename T>
+CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j)
+{
+    assert(i >= 0 && i < arr.shape[0]);
+    assert(j >= 0 && j < arr.shape[1]);
+
+    return i*stride(arr, 0) + j*stride(arr, 1);
+}
+
+template <typename T>
+CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j, int k)
+{
+    assert(i >= 0 && i < arr.shape[0]);
+    assert(j >= 0 && j < arr.shape[1]);
+    assert(k >= 0 && k < arr.shape[2]);
+
+    return i*stride(arr, 0) + j*stride(arr, 1) + k*stride(arr, 2);
+}
+
+template <typename T>
+CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j, int k, int l)
+{
+    assert(i >= 0 && i < arr.shape[0]);
+    assert(j >= 0 && j < arr.shape[1]);
+    assert(k >= 0 && k < arr.shape[2]);
+    assert(l >= 0 && l < arr.shape[3]);
+
+    return i*stride(arr, 0) + j*stride(arr, 1) + k*stride(arr, 2) + l*stride(arr, 3);
+}
+
+template <typename T>
+CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i)
+{
+    assert(arr.ndim == 1);
+    T& result = *data_at_byte_offset(arr, byte_offset(arr, i));
+    FP_VERIFY_FWD_1(result)
+
+    return result;
+}
+
+template <typename T>
+CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j)
+{
+    assert(arr.ndim == 2);
+    T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j));
+    FP_VERIFY_FWD_2(result)
+
+    return result;
+}
+
+template <typename T>
+CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k)
+{
+    assert(arr.ndim == 3);
+    T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k));
+    FP_VERIFY_FWD_3(result)
+
+    return result;
+}
+
+template <typename T>
+CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k, int l)
+{
+    assert(arr.ndim == 4);
+    T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
+    FP_VERIFY_FWD_4(result)
+
+    return result;
+}
+
+template <typename T>
+CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i)
+{
+    T& result = *grad_at_byte_offset(arr, byte_offset(arr, i));
+    FP_VERIFY_FWD_1(result)
+
+    return result;
+}
+
+template <typename T>
+CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j)
+{
+    T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j));
+    FP_VERIFY_FWD_2(result)
+
+    return result;
+}
+
+template <typename T>
+CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k)
+{
+    T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k));
+    FP_VERIFY_FWD_3(result)
+
+    return result;
+}
+
+template <typename T>
+CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k, int l)
+{
+    T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
+    FP_VERIFY_FWD_4(result)
+
+    return result;
+}
+
+
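Editorial note (not part of the diff): index() resolves an element by summing per-dimension byte offsets (coordinate times byte stride) and reinterpreting the data pointer at that offset, so non-contiguous layouts such as those produced by view() work through the same code path. The arithmetic written out for a contiguous 2 x 3 x 4 float array (an assumed example shape):

    #include <cassert>
    #include <cstddef>

    int main()
    {
        const std::size_t s0 = 3 * 4 * sizeof(float);   // byte stride of dim 0
        const std::size_t s1 = 4 * sizeof(float);       // byte stride of dim 1
        const std::size_t s2 = sizeof(float);           // byte stride of dim 2
        // element (1, 2, 3) -> flat element index 1*12 + 2*4 + 3 = 23
        assert(1 * s0 + 2 * s1 + 3 * s2 == 23 * sizeof(float));
        return 0;
    }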
+template <typename T>
+CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i)
+{
+    assert(iarr.arr.ndim == 1);
+    assert(i >= 0 && i < iarr.shape[0]);
+
+    if (iarr.indices[0])
+    {
+        i = iarr.indices[0][i];
+        assert(i >= 0 && i < iarr.arr.shape[0]);
+    }
+
+    T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i));
+    FP_VERIFY_FWD_1(result)
+
+    return result;
+}
+
+template <typename T>
+CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j)
+{
+    assert(iarr.arr.ndim == 2);
+    assert(i >= 0 && i < iarr.shape[0]);
+    assert(j >= 0 && j < iarr.shape[1]);
+
+    if (iarr.indices[0])
+    {
+        i = iarr.indices[0][i];
+        assert(i >= 0 && i < iarr.arr.shape[0]);
+    }
+    if (iarr.indices[1])
+    {
+        j = iarr.indices[1][j];
+        assert(j >= 0 && j < iarr.arr.shape[1]);
+    }
+
+    T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j));
+    FP_VERIFY_FWD_1(result)
+
+    return result;
+}
+
+template <typename T>
+CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k)
+{
+    assert(iarr.arr.ndim == 3);
+    assert(i >= 0 && i < iarr.shape[0]);
+    assert(j >= 0 && j < iarr.shape[1]);
+    assert(k >= 0 && k < iarr.shape[2]);
+
+    if (iarr.indices[0])
+    {
+        i = iarr.indices[0][i];
+        assert(i >= 0 && i < iarr.arr.shape[0]);
+    }
+    if (iarr.indices[1])
+    {
+        j = iarr.indices[1][j];
+        assert(j >= 0 && j < iarr.arr.shape[1]);
+    }
+    if (iarr.indices[2])
+    {
+        k = iarr.indices[2][k];
+        assert(k >= 0 && k < iarr.arr.shape[2]);
+    }
+
+    T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j, k));
+    FP_VERIFY_FWD_1(result)
+
+    return result;
+}
+
+template <typename T>
+CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k, int l)
+{
+    assert(iarr.arr.ndim == 4);
+    assert(i >= 0 && i < iarr.shape[0]);
+    assert(j >= 0 && j < iarr.shape[1]);
+    assert(k >= 0 && k < iarr.shape[2]);
+    assert(l >= 0 && l < iarr.shape[3]);
+
+    if (iarr.indices[0])
+    {
+        i = iarr.indices[0][i];
+        assert(i >= 0 && i < iarr.arr.shape[0]);
+    }
+    if (iarr.indices[1])
+    {
+        j = iarr.indices[1][j];
+        assert(j >= 0 && j < iarr.arr.shape[1]);
+    }
+    if (iarr.indices[2])
+    {
+        k = iarr.indices[2][k];
+        assert(k >= 0 && k < iarr.arr.shape[2]);
+    }
+    if (iarr.indices[3])
+    {
+        l = iarr.indices[3][l];
+        assert(l >= 0 && l < iarr.arr.shape[3]);
+    }
+
+    T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j, k, l));
+    FP_VERIFY_FWD_1(result)
+
+    return result;
+}
+
+
+template <typename T>
+CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i)
+{
+    assert(src.ndim > 1);
+    assert(i >= 0 && i < src.shape[0]);
+
+    array_t<T> a;
+    a.data = data_at_byte_offset(src, byte_offset(src, i));
+    a.shape[0] = src.shape[1];
+    a.shape[1] = src.shape[2];
+    a.shape[2] = src.shape[3];
+    a.strides[0] = src.strides[1];
+    a.strides[1] = src.strides[2];
+    a.strides[2] = src.strides[3];
+    a.ndim = src.ndim-1;
+
+    return a;
+}
+
+template <typename T>
+CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j)
+{
+    assert(src.ndim > 2);
+    assert(i >= 0 && i < src.shape[0]);
+    assert(j >= 0 && j < src.shape[1]);
+
+    array_t<T> a;
+    a.data = data_at_byte_offset(src, byte_offset(src, i, j));
+    a.shape[0] = src.shape[2];
+    a.shape[1] = src.shape[3];
+    a.strides[0] = src.strides[2];
+    a.strides[1] = src.strides[3];
+    a.ndim = src.ndim-2;
+
+    return a;
+}
+
+template <typename T>
+CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j, int k)
+{
+    assert(src.ndim > 3);
+    assert(i >= 0 && i < src.shape[0]);
+    assert(j >= 0 && j < src.shape[1]);
+    assert(k >= 0 && k < src.shape[2]);
+
+    array_t<T> a;
+    a.data = data_at_byte_offset(src, byte_offset(src, i, j, k));
+    a.shape[0] = src.shape[3];
+    a.strides[0] = src.strides[3];
+    a.ndim = src.ndim-3;
+
+    return a;
+}
+
+
+template <typename T>
+CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i)
+{
+    assert(src.arr.ndim > 1);
+
+    if (src.indices[0])
+    {
+        assert(i >= 0 && i < src.shape[0]);
+        i = src.indices[0][i];
+    }
+
+    indexedarray_t<T> a;
+    a.arr = view(src.arr, i);
+    a.indices[0] = src.indices[1];
+    a.indices[1] = src.indices[2];
+    a.indices[2] = src.indices[3];
+    a.shape[0] = src.shape[1];
+    a.shape[1] = src.shape[2];
+    a.shape[2] = src.shape[3];
+
+    return a;
+}
+
+template <typename T>
+CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j)
+{
+    assert(src.arr.ndim > 2);
+
+    if (src.indices[0])
+    {
+        assert(i >= 0 && i < src.shape[0]);
+        i = src.indices[0][i];
+    }
+    if (src.indices[1])
+    {
+        assert(j >= 0 && j < src.shape[1]);
+        j = src.indices[1][j];
+    }
+
+    indexedarray_t<T> a;
+    a.arr = view(src.arr, i, j);
+    a.indices[0] = src.indices[2];
+    a.indices[1] = src.indices[3];
+    a.shape[0] = src.shape[2];
+    a.shape[1] = src.shape[3];
+
+    return a;
+}
+
+template <typename T>
+CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j, int k)
+{
+    assert(src.arr.ndim > 3);
+
+    if (src.indices[0])
+    {
+        assert(i >= 0 && i < src.shape[0]);
+        i = src.indices[0][i];
+    }
+    if (src.indices[1])
+    {
+        assert(j >= 0 && j < src.shape[1]);
+        j = src.indices[1][j];
+    }
+    if (src.indices[2])
+    {
+        assert(k >= 0 && k < src.shape[2]);
+        k = src.indices[2][k];
+    }
+
+    indexedarray_t<T> a;
+    a.arr = view(src.arr, i, j, k);
+    a.indices[0] = src.indices[3];
+    a.shape[0] = src.shape[3];
+
+    return a;
+}
+
+template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
+inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, A2<T>& adj_src, int adj_i, A3<T> adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
+inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int adj_i, int adj_j, A3<T> adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, template<typename> class A3, typename T>
+inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T> adj_ret) {}
+
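Editorial note (not part of the diff): view(src, i[, j[, k]]) slices off the leading dimension(s) without copying — the data pointer is advanced by the corresponding byte offset and the remaining shapes and strides shift down, so writes through the view land in the source array. A standalone sketch of that aliasing for a 2d array:

    #include <cassert>

    int main()
    {
        float data[2][3] = {{1, 2, 3}, {4, 5, 6}};
        // analogue of view(src, 1): pointer advanced by one dim-0 stride,
        // shape[0]/strides[0] become the old shape[1]/strides[1]
        float* row1 = &data[1][0];
        assert(row1[2] == 6.0f);
        row1[0] = 40.0f;                            // writes through the view alias the source
        assert(data[1][0] == 40.0f);
        return 0;
    }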
+// TODO: lower_bound() for indexed arrays?
+
+template <typename T>
+CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value)
+{
+    assert(arr.ndim == 1);
+
+    int lower = arr_begin;
+    int upper = arr_end - 1;
+
+    while(lower < upper)
+    {
+        int mid = lower + (upper - lower) / 2;
+
+        if (arr[mid] < value)
+        {
+            lower = mid + 1;
+        }
+        else
+        {
+            upper = mid;
+        }
+    }
+
+    return lower;
+}
+
+template <typename T>
+CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, T value)
+{
+    return lower_bound(arr, 0, arr.shape[0], value);
+}
+
+template <typename T> inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, T value, array_t<T> adj_arr, T adj_value, int adj_ret) {}
+template <typename T> inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value, array_t<T> adj_arr, int adj_arr_begin, int adj_arr_end, T adj_value, int adj_ret) {}
+
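Editorial note (not part of the diff): lower_bound() binary-searches a sorted 1d array for the first element not less than `value` within [arr_begin, arr_end); because `upper` starts at arr_end - 1, the result is clamped to the last index of the range rather than returning a past-the-end position when every element compares less. A standalone copy of the loop with a small check (demo names are illustrative only):

    #include <cassert>

    int demo_lower_bound(const float* arr, int begin, int end, float value)
    {
        int lower = begin;
        int upper = end - 1;
        while (lower < upper)
        {
            int mid = lower + (upper - lower) / 2;
            if (arr[mid] < value)
                lower = mid + 1;
            else
                upper = mid;
        }
        return lower;
    }

    int main()
    {
        float sorted[5] = {1, 3, 3, 7, 9};
        assert(demo_lower_bound(sorted, 0, 5, 3.0f) == 1);    // first element not less than 3
        assert(demo_lower_bound(sorted, 0, 5, 100.0f) == 4);  // clamped to the last index
        return 0;
    }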
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, T value) { return atomic_add(&index(buf, i), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, T value) { return atomic_add(&index(buf, i, j), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, int k, T value) { return atomic_add(&index(buf, i, j, k), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_add(&index(buf, i, j, k, l), value); }
+
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, T value) { return atomic_add(&index(buf, i), -value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, T value) { return atomic_add(&index(buf, i, j), -value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, int k, T value) { return atomic_add(&index(buf, i, j, k), -value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_add(&index(buf, i, j, k, l), -value); }
+
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, T value) { return atomic_min(&index(buf, i), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, T value) { return atomic_min(&index(buf, i, j), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, int k, T value) { return atomic_min(&index(buf, i, j, k), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_min(&index(buf, i, j, k, l), value); }
+
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, T value) { return atomic_max(&index(buf, i), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, T value) { return atomic_max(&index(buf, i, j), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, T value) { return atomic_max(&index(buf, i, j, k), value); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_max(&index(buf, i, j, k, l), value); }
+
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T* address(const A<T>& buf, int i) { return &index(buf, i); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j) { return &index(buf, i, j); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k) { return &index(buf, i, j, k); }
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k, int l) { return &index(buf, i, j, k, l); }
+
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, T value)
+{
+    FP_VERIFY_FWD_1(value)
+
+    index(buf, i) = value;
+}
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, T value)
+{
+    FP_VERIFY_FWD_2(value)
+
+    index(buf, i, j) = value;
+}
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, T value)
+{
+    FP_VERIFY_FWD_3(value)
+
+    index(buf, i, j, k) = value;
+}
+template<template<typename> class A, typename T>
+inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, int l, T value)
+{
+    FP_VERIFY_FWD_4(value)
+
+    index(buf, i, j, k, l) = value;
+}
+
+template<typename T>
+inline CUDA_CALLABLE void store(T* address, T value)
+{
+    FP_VERIFY_FWD(value)
+
+    *address = value;
+}
+
+template<typename T>
+inline CUDA_CALLABLE T load(T* address)
+{
+    T value = *address;
+    FP_VERIFY_FWD(value)
+
+    return value;
+}
+
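Editorial note (not part of the diff): the wrappers above resolve the element address with index() and forward to the pointer-based atomics, with atomic_sub expressed as an atomic_add of the negated value and the previous element value returned to the caller. A host-side sketch of that return-old-then-update contract using std::atomic (illustrative only; the device builds rely on hardware atomics instead):

    #include <atomic>
    #include <cassert>

    // returns the previous value, like the atomic_add() wrappers above
    float demo_atomic_add(std::atomic<float>* addr, float value)
    {
        float old = addr->load();
        while (!addr->compare_exchange_weak(old, old + value)) {}
        return old;
    }

    int main()
    {
        std::atomic<float> x{10.0f};
        demo_atomic_add(&x, 2.0f);                  // atomic_add(buf, i, 2.0) style update
        demo_atomic_add(&x, -3.0f);                 // atomic_sub() is just an add of -value
        assert(x.load() == 9.0f);
        return 0;
    }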
+// select operator to check for array being null
+template <typename T1, typename T2>
+CUDA_CALLABLE inline T2 select(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?b:a; }
+
+template <typename T1, typename T2>
+CUDA_CALLABLE inline void adj_select(const array_t<T1>& arr, const T2& a, const T2& b, const array_t<T1>& adj_cond, T2& adj_a, T2& adj_b, const T2& adj_ret)
+{
+    if (arr.data)
+        adj_b += adj_ret;
+    else
+        adj_a += adj_ret;
+}
+
+// stub for the case where we have a nested array inside a struct and
+// atomic add the whole struct onto an array (e.g.: during backwards pass)
+template <typename T>
+CUDA_CALLABLE inline void atomic_add(array_t<T>*, array_t<T>) {}
+
+// for float and vector types this is just an alias for an atomic add
+template <typename T>
+CUDA_CALLABLE inline void adj_atomic_add(T* buf, T value) { atomic_add(buf, value); }
+
+
+// for integral types we do not accumulate gradients
+CUDA_CALLABLE inline void adj_atomic_add(int8* buf, int8 value) { }
+CUDA_CALLABLE inline void adj_atomic_add(uint8* buf, uint8 value) { }
+CUDA_CALLABLE inline void adj_atomic_add(int16* buf, int16 value) { }
+CUDA_CALLABLE inline void adj_atomic_add(uint16* buf, uint16 value) { }
+CUDA_CALLABLE inline void adj_atomic_add(int32* buf, int32 value) { }
+CUDA_CALLABLE inline void adj_atomic_add(uint32* buf, uint32 value) { }
+CUDA_CALLABLE inline void adj_atomic_add(int64* buf, int64 value) { }
+CUDA_CALLABLE inline void adj_atomic_add(uint64* buf, uint64 value) { }
+
+CUDA_CALLABLE inline void adj_atomic_add(bool* buf, bool value) { }
+
+// only generate gradients for T types
+template<typename T>
+inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, const array_t<T>& adj_buf, int& adj_i, const T& adj_output)
+{
+    if (adj_buf.data)
+        adj_atomic_add(&index(adj_buf, i), adj_output);
+    else if (buf.grad)
+        adj_atomic_add(&index_grad(buf, i), adj_output);
+}
+template<typename T>
+inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, const array_t<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output)
+{
+    if (adj_buf.data)
+        adj_atomic_add(&index(adj_buf, i, j), adj_output);
+    else if (buf.grad)
+        adj_atomic_add(&index_grad(buf, i, j), adj_output);
+}
+template<typename T>
+inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output)
+{
+    if (adj_buf.data)
+        adj_atomic_add(&index(adj_buf, i, j, k), adj_output);
+    else if (buf.grad)
+        adj_atomic_add(&index_grad(buf, i, j, k), adj_output);
+}
+template<typename T>
+inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, int l, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output)
+{
+    if (adj_buf.data)
+        adj_atomic_add(&index(adj_buf, i, j, k, l), adj_output);
+    else if (buf.grad)
+        adj_atomic_add(&index_grad(buf, i, j, k, l), adj_output);
+}
+
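Editorial note (not part of the diff): the adj_address() overloads above and the adj_array_store()/adj_atomic_*() overloads that follow share one pattern — an incoming adjoint is accumulated either into an explicit adjoint array (adj_buf.data) when one was provided, or into the primal array's own .grad buffer, and is dropped when neither exists; the integral adj_atomic_add() stubs above keep integer arrays out of gradient accumulation entirely. A minimal host-side sketch of that routing (demo names are illustrative only):

    #include <cassert>

    struct demo_array
    {
        float* data;
        float* grad;    // optional gradient buffer attached to the primal array
    };

    void demo_adj_address(const demo_array& buf, int i, const demo_array& adj_buf, float adj_output)
    {
        if (adj_buf.data)
            adj_buf.data[i] += adj_output;  // accumulate into the separate adjoint array
        else if (buf.grad)
            buf.grad[i] += adj_output;      // fall back to the array's own grad buffer
    }

    int main()
    {
        float x[2] = {1, 2}, g[2] = {0, 0};
        demo_array buf{x, g};
        demo_array no_adj{nullptr, nullptr};
        demo_adj_address(buf, 1, no_adj, 0.5f);
        assert(g[1] == 0.5f);
        return 0;
    }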
+template<typename T>
+inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value)
+{
+    if (adj_buf.data)
+        adj_value += index(adj_buf, i);
+    else if (buf.grad)
+        adj_value += index_grad(buf, i);
+
+    FP_VERIFY_ADJ_1(value, adj_value)
+}
+template<typename T>
+inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value)
+{
+    if (adj_buf.data)
+        adj_value += index(adj_buf, i, j);
+    else if (buf.grad)
+        adj_value += index_grad(buf, i, j);
+
+    FP_VERIFY_ADJ_2(value, adj_value)
+}
+template<typename T>
+inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value)
+{
+    if (adj_buf.data)
+        adj_value += index(adj_buf, i, j, k);
+    else if (buf.grad)
+        adj_value += index_grad(buf, i, j, k);
+
+    FP_VERIFY_ADJ_3(value, adj_value)
+}
+template<typename T>
+inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value)
+{
+    if (adj_buf.data)
+        adj_value += index(adj_buf, i, j, k, l);
+    else if (buf.grad)
+        adj_value += index_grad(buf, i, j, k, l);
+
+    FP_VERIFY_ADJ_4(value, adj_value)
+}
+
+template<typename T>
+inline CUDA_CALLABLE void adj_store(const T* address, T value, const T& adj_address, T& adj_value)
+{
+    // nop; generic store() operations are not differentiable, only array_store() is
+    FP_VERIFY_ADJ(value, adj_value)
+}
+
+template<typename T>
+inline CUDA_CALLABLE void adj_load(const T* address, const T& adj_address, T& adj_value)
+{
+    // nop; generic load() operations are not differentiable
+}
+
+template<typename T>
+inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret)
+{
+    if (adj_buf.data)
+        adj_value += index(adj_buf, i);
+    else if (buf.grad)
+        adj_value += index_grad(buf, i);
+
+    FP_VERIFY_ADJ_1(value, adj_value)
+}
+template<typename T>
+inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret)
+{
+    if (adj_buf.data)
+        adj_value += index(adj_buf, i, j);
+    else if (buf.grad)
+        adj_value += index_grad(buf, i, j);
+
+    FP_VERIFY_ADJ_2(value, adj_value)
+}
+template<typename T>
+inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret)
+{
+    if (adj_buf.data)
+        adj_value += index(adj_buf, i, j, k);
+    else if (buf.grad)
+        adj_value += index_grad(buf, i, j, k);
+
+    FP_VERIFY_ADJ_3(value, adj_value)
+}
+template<typename T>
+inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret)
+{
+    if (adj_buf.data)
+        adj_value += index(adj_buf, i, j, k, l);
+    else if (buf.grad)
+        adj_value += index_grad(buf, i, j, k, l);
+
+    FP_VERIFY_ADJ_4(value, adj_value)
+}
+
+
+template<typename T>
+inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret)
+{
+    if (adj_buf.data)
+        adj_value -= index(adj_buf, i);
+    else if (buf.grad)
+        adj_value -= index_grad(buf, i);
+
+    FP_VERIFY_ADJ_1(value, adj_value)
+}
+template<typename T>
+inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret)
+{
+    if (adj_buf.data)
+        adj_value -= index(adj_buf, i, j);
+    else if (buf.grad)
+        adj_value -= index_grad(buf, i, j);
+
+    FP_VERIFY_ADJ_2(value, adj_value)
+}
+template<typename T>
+inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret)
+{
+    if (adj_buf.data)
+        adj_value -= index(adj_buf, i, j, k);
+    else if (buf.grad)
+        adj_value -= index_grad(buf, i, j, k);
+
+    FP_VERIFY_ADJ_3(value, adj_value)
+}
+template<typename T>
+inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret)
+{
+    if (adj_buf.data)
+        adj_value -= index(adj_buf, i, j, k, l);
+    else if (buf.grad)
+        adj_value -= index_grad(buf, i, j, k, l);
+
+    FP_VERIFY_ADJ_4(value, adj_value)
+}
+
+// generic array types that do not support gradient computation (indexedarray, etc.)
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, const A2<T>& adj_buf, int& adj_i, const T& adj_output) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, const A2<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, int l, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output) {}
+
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value) {}
+
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
+
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
+
+// generic handler for scalar values
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_minmax(&index(buf, i), &index(adj_buf, i), value, adj_value);
+    else if (buf.grad)
+        adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
+
+    FP_VERIFY_ADJ_1(value, adj_value)
+}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_minmax(&index(buf, i, j), &index(adj_buf, i, j), value, adj_value);
+    else if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
+
+    FP_VERIFY_ADJ_2(value, adj_value)
+}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_minmax(&index(buf, i, j, k), &index(adj_buf, i, j, k), value, adj_value);
+    else if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
+
+    FP_VERIFY_ADJ_3(value, adj_value)
+}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_minmax(&index(buf, i, j, k, l), &index(adj_buf, i, j, k, l), value, adj_value);
+    else if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
+
+    FP_VERIFY_ADJ_4(value, adj_value)
+}
+
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_minmax(&index(buf, i), &index(adj_buf, i), value, adj_value);
+    else if (buf.grad)
+        adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
+
+    FP_VERIFY_ADJ_1(value, adj_value)
+}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_minmax(&index(buf, i, j), &index(adj_buf, i, j), value, adj_value);
+    else if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
+
+    FP_VERIFY_ADJ_2(value, adj_value)
+}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_minmax(&index(buf, i, j, k), &index(adj_buf, i, j, k), value, adj_value);
+    else if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
+
+    FP_VERIFY_ADJ_3(value, adj_value)
+}
+template<template<typename> class A1, template<typename> class A2, typename T>
+inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
+    if (adj_buf.data)
+        adj_atomic_minmax(&index(buf, i, j, k, l), &index(adj_buf, i, j, k, l), value, adj_value);
+    else if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
+
+    FP_VERIFY_ADJ_4(value, adj_value)
+}
+
+} // namespace wp
+
+#include "fabric.h"