warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +15 -7
- warp/__init__.pyi +1 -0
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +22 -443
- warp/build_dll.py +384 -0
- warp/builtins.py +998 -488
- warp/codegen.py +1307 -739
- warp/config.py +5 -3
- warp/constants.py +6 -0
- warp/context.py +1291 -548
- warp/dlpack.py +31 -31
- warp/fabric.py +326 -0
- warp/fem/__init__.py +27 -0
- warp/fem/cache.py +389 -0
- warp/fem/dirichlet.py +181 -0
- warp/fem/domain.py +263 -0
- warp/fem/field/__init__.py +101 -0
- warp/fem/field/field.py +149 -0
- warp/fem/field/nodal_field.py +299 -0
- warp/fem/field/restriction.py +21 -0
- warp/fem/field/test.py +181 -0
- warp/fem/field/trial.py +183 -0
- warp/fem/geometry/__init__.py +19 -0
- warp/fem/geometry/closest_point.py +70 -0
- warp/fem/geometry/deformed_geometry.py +271 -0
- warp/fem/geometry/element.py +744 -0
- warp/fem/geometry/geometry.py +186 -0
- warp/fem/geometry/grid_2d.py +373 -0
- warp/fem/geometry/grid_3d.py +435 -0
- warp/fem/geometry/hexmesh.py +953 -0
- warp/fem/geometry/partition.py +376 -0
- warp/fem/geometry/quadmesh_2d.py +532 -0
- warp/fem/geometry/tetmesh.py +840 -0
- warp/fem/geometry/trimesh_2d.py +577 -0
- warp/fem/integrate.py +1616 -0
- warp/fem/operator.py +191 -0
- warp/fem/polynomial.py +213 -0
- warp/fem/quadrature/__init__.py +2 -0
- warp/fem/quadrature/pic_quadrature.py +245 -0
- warp/fem/quadrature/quadrature.py +294 -0
- warp/fem/space/__init__.py +292 -0
- warp/fem/space/basis_space.py +489 -0
- warp/fem/space/collocated_function_space.py +105 -0
- warp/fem/space/dof_mapper.py +236 -0
- warp/fem/space/function_space.py +145 -0
- warp/fem/space/grid_2d_function_space.py +267 -0
- warp/fem/space/grid_3d_function_space.py +306 -0
- warp/fem/space/hexmesh_function_space.py +352 -0
- warp/fem/space/partition.py +350 -0
- warp/fem/space/quadmesh_2d_function_space.py +369 -0
- warp/fem/space/restriction.py +160 -0
- warp/fem/space/shape/__init__.py +15 -0
- warp/fem/space/shape/cube_shape_function.py +738 -0
- warp/fem/space/shape/shape_function.py +103 -0
- warp/fem/space/shape/square_shape_function.py +611 -0
- warp/fem/space/shape/tet_shape_function.py +567 -0
- warp/fem/space/shape/triangle_shape_function.py +429 -0
- warp/fem/space/tetmesh_function_space.py +292 -0
- warp/fem/space/topology.py +295 -0
- warp/fem/space/trimesh_2d_function_space.py +221 -0
- warp/fem/types.py +77 -0
- warp/fem/utils.py +495 -0
- warp/native/array.h +164 -55
- warp/native/builtin.h +150 -174
- warp/native/bvh.cpp +75 -328
- warp/native/bvh.cu +406 -23
- warp/native/bvh.h +37 -45
- warp/native/clang/clang.cpp +136 -24
- warp/native/crt.cpp +1 -76
- warp/native/crt.h +111 -104
- warp/native/cuda_crt.h +1049 -0
- warp/native/cuda_util.cpp +15 -3
- warp/native/cuda_util.h +3 -1
- warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
- warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
- warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
- warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
- warp/native/cutlass/tools/library/scripts/library.py +799 -0
- warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
- warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
- warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
- warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
- warp/native/cutlass/tools/library/scripts/rt.py +796 -0
- warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
- warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
- warp/native/cutlass_gemm.cu +5 -3
- warp/native/exports.h +1240 -949
- warp/native/fabric.h +228 -0
- warp/native/hashgrid.cpp +4 -4
- warp/native/hashgrid.h +22 -2
- warp/native/initializer_array.h +2 -2
- warp/native/intersect.h +22 -7
- warp/native/intersect_adj.h +8 -8
- warp/native/intersect_tri.h +13 -16
- warp/native/marching.cu +157 -161
- warp/native/mat.h +119 -19
- warp/native/matnn.h +2 -2
- warp/native/mesh.cpp +108 -83
- warp/native/mesh.cu +243 -6
- warp/native/mesh.h +1547 -458
- warp/native/nanovdb/NanoVDB.h +1 -1
- warp/native/noise.h +272 -329
- warp/native/quat.h +51 -8
- warp/native/rand.h +45 -35
- warp/native/range.h +6 -2
- warp/native/reduce.cpp +157 -0
- warp/native/reduce.cu +348 -0
- warp/native/runlength_encode.cpp +62 -0
- warp/native/runlength_encode.cu +46 -0
- warp/native/scan.cu +11 -13
- warp/native/scan.h +1 -0
- warp/native/solid_angle.h +442 -0
- warp/native/sort.cpp +13 -0
- warp/native/sort.cu +9 -1
- warp/native/sparse.cpp +338 -0
- warp/native/sparse.cu +545 -0
- warp/native/spatial.h +2 -2
- warp/native/temp_buffer.h +30 -0
- warp/native/vec.h +126 -24
- warp/native/volume.h +120 -0
- warp/native/warp.cpp +658 -53
- warp/native/warp.cu +660 -68
- warp/native/warp.h +112 -12
- warp/optim/__init__.py +1 -0
- warp/optim/linear.py +922 -0
- warp/optim/sgd.py +92 -0
- warp/render/render_opengl.py +392 -152
- warp/render/render_usd.py +11 -11
- warp/sim/__init__.py +2 -2
- warp/sim/articulation.py +385 -185
- warp/sim/collide.py +21 -8
- warp/sim/import_mjcf.py +297 -106
- warp/sim/import_urdf.py +389 -210
- warp/sim/import_usd.py +198 -97
- warp/sim/inertia.py +17 -18
- warp/sim/integrator_euler.py +14 -8
- warp/sim/integrator_xpbd.py +161 -19
- warp/sim/model.py +795 -291
- warp/sim/optimizer.py +2 -6
- warp/sim/render.py +65 -3
- warp/sim/utils.py +3 -0
- warp/sparse.py +1227 -0
- warp/stubs.py +665 -223
- warp/tape.py +66 -15
- warp/tests/__main__.py +3 -6
- warp/tests/assets/curlnoise_golden.npy +0 -0
- warp/tests/assets/pnoise_golden.npy +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
- warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
- warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
- warp/tests/aux_test_unresolved_func.py +14 -0
- warp/tests/aux_test_unresolved_symbol.py +14 -0
- warp/tests/disabled_kinematics.py +239 -0
- warp/tests/run_coverage_serial.py +31 -0
- warp/tests/test_adam.py +103 -106
- warp/tests/test_arithmetic.py +128 -74
- warp/tests/test_array.py +1497 -211
- warp/tests/test_array_reduce.py +150 -0
- warp/tests/test_atomic.py +64 -28
- warp/tests/test_bool.py +99 -0
- warp/tests/test_builtins_resolution.py +1292 -0
- warp/tests/test_bvh.py +75 -43
- warp/tests/test_closest_point_edge_edge.py +54 -57
- warp/tests/test_codegen.py +233 -128
- warp/tests/test_compile_consts.py +28 -20
- warp/tests/test_conditional.py +108 -24
- warp/tests/test_copy.py +10 -12
- warp/tests/test_ctypes.py +112 -88
- warp/tests/test_dense.py +21 -14
- warp/tests/test_devices.py +98 -0
- warp/tests/test_dlpack.py +136 -108
- warp/tests/test_examples.py +277 -0
- warp/tests/test_fabricarray.py +955 -0
- warp/tests/test_fast_math.py +15 -11
- warp/tests/test_fem.py +1271 -0
- warp/tests/test_fp16.py +53 -19
- warp/tests/test_func.py +187 -74
- warp/tests/test_generics.py +194 -49
- warp/tests/test_grad.py +180 -116
- warp/tests/test_grad_customs.py +176 -0
- warp/tests/test_hash_grid.py +52 -37
- warp/tests/test_import.py +10 -23
- warp/tests/test_indexedarray.py +577 -24
- warp/tests/test_intersect.py +18 -9
- warp/tests/test_large.py +141 -0
- warp/tests/test_launch.py +251 -15
- warp/tests/test_lerp.py +64 -65
- warp/tests/test_linear_solvers.py +154 -0
- warp/tests/test_lvalue.py +493 -0
- warp/tests/test_marching_cubes.py +12 -13
- warp/tests/test_mat.py +508 -2778
- warp/tests/test_mat_lite.py +115 -0
- warp/tests/test_mat_scalar_ops.py +2889 -0
- warp/tests/test_math.py +103 -9
- warp/tests/test_matmul.py +305 -69
- warp/tests/test_matmul_lite.py +410 -0
- warp/tests/test_mesh.py +71 -14
- warp/tests/test_mesh_query_aabb.py +41 -25
- warp/tests/test_mesh_query_point.py +325 -34
- warp/tests/test_mesh_query_ray.py +39 -22
- warp/tests/test_mlp.py +30 -22
- warp/tests/test_model.py +92 -89
- warp/tests/test_modules_lite.py +39 -0
- warp/tests/test_multigpu.py +88 -114
- warp/tests/test_noise.py +12 -11
- warp/tests/test_operators.py +16 -20
- warp/tests/test_options.py +11 -11
- warp/tests/test_pinned.py +17 -18
- warp/tests/test_print.py +32 -11
- warp/tests/test_quat.py +275 -129
- warp/tests/test_rand.py +18 -16
- warp/tests/test_reload.py +38 -34
- warp/tests/test_rounding.py +50 -43
- warp/tests/test_runlength_encode.py +190 -0
- warp/tests/test_smoothstep.py +9 -11
- warp/tests/test_snippet.py +143 -0
- warp/tests/test_sparse.py +460 -0
- warp/tests/test_spatial.py +276 -243
- warp/tests/test_streams.py +110 -85
- warp/tests/test_struct.py +331 -85
- warp/tests/test_tape.py +39 -21
- warp/tests/test_torch.py +118 -89
- warp/tests/test_transient_module.py +12 -13
- warp/tests/test_types.py +614 -0
- warp/tests/test_utils.py +494 -0
- warp/tests/test_vec.py +354 -1987
- warp/tests/test_vec_lite.py +73 -0
- warp/tests/test_vec_scalar_ops.py +2099 -0
- warp/tests/test_volume.py +457 -293
- warp/tests/test_volume_write.py +124 -134
- warp/tests/unittest_serial.py +35 -0
- warp/tests/unittest_suites.py +341 -0
- warp/tests/unittest_utils.py +568 -0
- warp/tests/unused_test_misc.py +71 -0
- warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
- warp/thirdparty/appdirs.py +36 -45
- warp/thirdparty/unittest_parallel.py +549 -0
- warp/torch.py +72 -30
- warp/types.py +1744 -713
- warp/utils.py +360 -350
- warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
- warp_lang-0.11.0.dist-info/METADATA +238 -0
- warp_lang-0.11.0.dist-info/RECORD +332 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
- warp/bin/warp-clang.exp +0 -0
- warp/bin/warp-clang.lib +0 -0
- warp/bin/warp.exp +0 -0
- warp/bin/warp.lib +0 -0
- warp/tests/test_all.py +0 -215
- warp/tests/test_array_scan.py +0 -60
- warp/tests/test_base.py +0 -208
- warp/tests/test_unresolved_func.py +0 -7
- warp/tests/test_unresolved_symbol.py +0 -7
- warp_lang-0.9.0.dist-info/METADATA +0 -20
- warp_lang-0.9.0.dist-info/RECORD +0 -177
- /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
- /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
- /warp/tests/{test_square.py → aux_test_square.py} +0 -0
- {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/array.h
CHANGED
|
@@ -19,6 +19,12 @@ namespace wp
|
|
|
19
19
|
printf(")\n"); \
|
|
20
20
|
assert(0); \
|
|
21
21
|
|
|
22
|
+
#define FP_VERIFY_FWD(value) \
|
|
23
|
+
if (!isfinite(value)) { \
|
|
24
|
+
printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
|
|
25
|
+
FP_ASSERT_FWD(value) \
|
|
26
|
+
} \
|
|
27
|
+
|
|
22
28
|
#define FP_VERIFY_FWD_1(value) \
|
|
23
29
|
if (!isfinite(value)) { \
|
|
24
30
|
printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
|
|
@@ -43,6 +49,13 @@ namespace wp
|
|
|
43
49
|
FP_ASSERT_FWD(value) \
|
|
44
50
|
} \
|
|
45
51
|
|
|
52
|
+
#define FP_VERIFY_ADJ(value, adj_value) \
|
|
53
|
+
if (!isfinite(value) || !isfinite(adj_value)) \
|
|
54
|
+
{ \
|
|
55
|
+
printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
|
|
56
|
+
FP_ASSERT_ADJ(value, adj_value); \
|
|
57
|
+
} \
|
|
58
|
+
|
|
46
59
|
#define FP_VERIFY_ADJ_1(value, adj_value) \
|
|
47
60
|
if (!isfinite(value) || !isfinite(adj_value)) \
|
|
48
61
|
{ \
|
|
@@ -74,11 +87,13 @@ namespace wp
|
|
|
74
87
|
|
|
75
88
|
#else
|
|
76
89
|
|
|
90
|
+
#define FP_VERIFY_FWD(value) {}
|
|
77
91
|
#define FP_VERIFY_FWD_1(value) {}
|
|
78
92
|
#define FP_VERIFY_FWD_2(value) {}
|
|
79
93
|
#define FP_VERIFY_FWD_3(value) {}
|
|
80
94
|
#define FP_VERIFY_FWD_4(value) {}
|
|
81
95
|
|
|
96
|
+
#define FP_VERIFY_ADJ(value, adj_value) {}
|
|
82
97
|
#define FP_VERIFY_ADJ_1(value, adj_value) {}
|
|
83
98
|
#define FP_VERIFY_ADJ_2(value, adj_value) {}
|
|
84
99
|
#define FP_VERIFY_ADJ_3(value, adj_value) {}
|
|
@@ -88,14 +103,19 @@ namespace wp
|
|
|
88
103
|
|
|
89
104
|
const int ARRAY_MAX_DIMS = 4; // must match constant in types.py
|
|
90
105
|
|
|
91
|
-
|
|
92
|
-
const int
|
|
106
|
+
// must match constants in types.py
|
|
107
|
+
const int ARRAY_TYPE_REGULAR = 0;
|
|
108
|
+
const int ARRAY_TYPE_INDEXED = 1;
|
|
109
|
+
const int ARRAY_TYPE_FABRIC = 2;
|
|
110
|
+
const int ARRAY_TYPE_FABRIC_INDEXED = 3;
|
|
93
111
|
|
|
94
112
|
struct shape_t
|
|
95
113
|
{
|
|
96
114
|
int dims[ARRAY_MAX_DIMS];
|
|
97
115
|
|
|
98
|
-
CUDA_CALLABLE inline shape_t()
|
|
116
|
+
CUDA_CALLABLE inline shape_t()
|
|
117
|
+
: dims()
|
|
118
|
+
{}
|
|
99
119
|
|
|
100
120
|
CUDA_CALLABLE inline int operator[](int i) const
|
|
101
121
|
{
|
|
@@ -110,12 +130,12 @@ struct shape_t
|
|
|
110
130
|
}
|
|
111
131
|
};
|
|
112
132
|
|
|
113
|
-
CUDA_CALLABLE inline int
|
|
133
|
+
CUDA_CALLABLE inline int extract(const shape_t& s, int i)
|
|
114
134
|
{
|
|
115
135
|
return s.dims[i];
|
|
116
136
|
}
|
|
117
137
|
|
|
118
|
-
CUDA_CALLABLE inline void
|
|
138
|
+
CUDA_CALLABLE inline void adj_extract(const shape_t& s, int i, const shape_t& adj_s, int adj_i, int adj_ret) {}
|
|
119
139
|
|
|
120
140
|
inline CUDA_CALLABLE void print(shape_t s)
|
|
121
141
|
{
|
|
@@ -130,10 +150,15 @@ inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& shape_t) {}
|
|
|
130
150
|
template <typename T>
|
|
131
151
|
struct array_t
|
|
132
152
|
{
|
|
133
|
-
CUDA_CALLABLE inline array_t()
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
153
|
+
CUDA_CALLABLE inline array_t()
|
|
154
|
+
: data(nullptr),
|
|
155
|
+
grad(nullptr),
|
|
156
|
+
shape(),
|
|
157
|
+
strides(),
|
|
158
|
+
ndim(0)
|
|
159
|
+
{}
|
|
160
|
+
|
|
161
|
+
CUDA_CALLABLE array_t(T* data, int size, T* grad=nullptr) : data(data), grad(grad) {
|
|
137
162
|
// constructor for 1d array
|
|
138
163
|
shape.dims[0] = size;
|
|
139
164
|
shape.dims[1] = 0;
|
|
@@ -145,7 +170,7 @@ struct array_t
|
|
|
145
170
|
strides[2] = 0;
|
|
146
171
|
strides[3] = 0;
|
|
147
172
|
}
|
|
148
|
-
array_t(T* data, int dim0, int dim1, T* grad=nullptr) : data(data), grad(grad) {
|
|
173
|
+
CUDA_CALLABLE array_t(T* data, int dim0, int dim1, T* grad=nullptr) : data(data), grad(grad) {
|
|
149
174
|
// constructor for 2d array
|
|
150
175
|
shape.dims[0] = dim0;
|
|
151
176
|
shape.dims[1] = dim1;
|
|
@@ -157,7 +182,7 @@ struct array_t
|
|
|
157
182
|
strides[2] = 0;
|
|
158
183
|
strides[3] = 0;
|
|
159
184
|
}
|
|
160
|
-
array_t(T* data, int dim0, int dim1, int dim2, T* grad=nullptr) : data(data), grad(grad) {
|
|
185
|
+
CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, T* grad=nullptr) : data(data), grad(grad) {
|
|
161
186
|
// constructor for 3d array
|
|
162
187
|
shape.dims[0] = dim0;
|
|
163
188
|
shape.dims[1] = dim1;
|
|
@@ -169,7 +194,7 @@ struct array_t
|
|
|
169
194
|
strides[2] = sizeof(T);
|
|
170
195
|
strides[3] = 0;
|
|
171
196
|
}
|
|
172
|
-
array_t(T* data, int dim0, int dim1, int dim2, int dim3, T* grad=nullptr) : data(data), grad(grad) {
|
|
197
|
+
CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, int dim3, T* grad=nullptr) : data(data), grad(grad) {
|
|
173
198
|
// constructor for 4d array
|
|
174
199
|
shape.dims[0] = dim0;
|
|
175
200
|
shape.dims[1] = dim1;
|
|
@@ -182,10 +207,10 @@ struct array_t
|
|
|
182
207
|
strides[3] = sizeof(T);
|
|
183
208
|
}
|
|
184
209
|
|
|
185
|
-
inline bool empty() const { return !data; }
|
|
210
|
+
CUDA_CALLABLE inline bool empty() const { return !data; }
|
|
186
211
|
|
|
187
|
-
T* data
|
|
188
|
-
T* grad
|
|
212
|
+
T* data;
|
|
213
|
+
T* grad;
|
|
189
214
|
shape_t shape;
|
|
190
215
|
int strides[ARRAY_MAX_DIMS];
|
|
191
216
|
int ndim;
|
|
@@ -200,10 +225,13 @@ struct array_t
|
|
|
200
225
|
template <typename T>
|
|
201
226
|
struct indexedarray_t
|
|
202
227
|
{
|
|
203
|
-
CUDA_CALLABLE inline indexedarray_t()
|
|
204
|
-
|
|
228
|
+
CUDA_CALLABLE inline indexedarray_t()
|
|
229
|
+
: arr(),
|
|
230
|
+
indices(),
|
|
231
|
+
shape()
|
|
232
|
+
{}
|
|
205
233
|
|
|
206
|
-
inline bool empty() const { return !arr.data; }
|
|
234
|
+
CUDA_CALLABLE inline bool empty() const { return !arr.data; }
|
|
207
235
|
|
|
208
236
|
array_t<T> arr;
|
|
209
237
|
int* indices[ARRAY_MAX_DIMS]; // index array per dimension (can be NULL)
|
|
@@ -597,13 +625,12 @@ inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_s
|
|
|
597
625
|
// TODO: lower_bound() for indexed arrays?
|
|
598
626
|
|
|
599
627
|
template <typename T>
|
|
600
|
-
CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, T value)
|
|
628
|
+
CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value)
|
|
601
629
|
{
|
|
602
630
|
assert(arr.ndim == 1);
|
|
603
|
-
int n = arr.shape[0];
|
|
604
631
|
|
|
605
|
-
int lower =
|
|
606
|
-
int upper =
|
|
632
|
+
int lower = arr_begin;
|
|
633
|
+
int upper = arr_end - 1;
|
|
607
634
|
|
|
608
635
|
while(lower < upper)
|
|
609
636
|
{
|
|
@@ -622,7 +649,14 @@ CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, T value)
|
|
|
622
649
|
return lower;
|
|
623
650
|
}
|
|
624
651
|
|
|
652
|
+
template <typename T>
|
|
653
|
+
CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, T value)
|
|
654
|
+
{
|
|
655
|
+
return lower_bound(arr, 0, arr.shape[0], value);
|
|
656
|
+
}
|
|
657
|
+
|
|
625
658
|
template <typename T> inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, T value, array_t<T> adj_arr, T adj_value, int adj_ret) {}
|
|
659
|
+
template <typename T> inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value, array_t<T> adj_arr, int adj_arr_begin, int adj_arr_end, T adj_value, int adj_ret) {}
|
|
626
660
|
|
|
627
661
|
template<template<typename> class A, typename T>
|
|
628
662
|
inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, T value) { return atomic_add(&index(buf, i), value); }
|
|
@@ -661,43 +695,60 @@ template<template<typename> class A, typename T>
|
|
|
661
695
|
inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_max(&index(buf, i, j, k, l), value); }
|
|
662
696
|
|
|
663
697
|
template<template<typename> class A, typename T>
|
|
664
|
-
inline CUDA_CALLABLE T
|
|
698
|
+
inline CUDA_CALLABLE T* address(const A<T>& buf, int i) { return &index(buf, i); }
|
|
665
699
|
template<template<typename> class A, typename T>
|
|
666
|
-
inline CUDA_CALLABLE T
|
|
700
|
+
inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j) { return &index(buf, i, j); }
|
|
667
701
|
template<template<typename> class A, typename T>
|
|
668
|
-
inline CUDA_CALLABLE T
|
|
702
|
+
inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k) { return &index(buf, i, j, k); }
|
|
669
703
|
template<template<typename> class A, typename T>
|
|
670
|
-
inline CUDA_CALLABLE T
|
|
704
|
+
inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k, int l) { return &index(buf, i, j, k, l); }
|
|
671
705
|
|
|
672
706
|
template<template<typename> class A, typename T>
|
|
673
|
-
inline CUDA_CALLABLE void
|
|
707
|
+
inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, T value)
|
|
674
708
|
{
|
|
675
709
|
FP_VERIFY_FWD_1(value)
|
|
676
710
|
|
|
677
711
|
index(buf, i) = value;
|
|
678
712
|
}
|
|
679
713
|
template<template<typename> class A, typename T>
|
|
680
|
-
inline CUDA_CALLABLE void
|
|
714
|
+
inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, T value)
|
|
681
715
|
{
|
|
682
716
|
FP_VERIFY_FWD_2(value)
|
|
683
717
|
|
|
684
718
|
index(buf, i, j) = value;
|
|
685
719
|
}
|
|
686
720
|
template<template<typename> class A, typename T>
|
|
687
|
-
inline CUDA_CALLABLE void
|
|
721
|
+
inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, T value)
|
|
688
722
|
{
|
|
689
723
|
FP_VERIFY_FWD_3(value)
|
|
690
724
|
|
|
691
725
|
index(buf, i, j, k) = value;
|
|
692
726
|
}
|
|
693
727
|
template<template<typename> class A, typename T>
|
|
694
|
-
inline CUDA_CALLABLE void
|
|
728
|
+
inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, int l, T value)
|
|
695
729
|
{
|
|
696
730
|
FP_VERIFY_FWD_4(value)
|
|
697
731
|
|
|
698
732
|
index(buf, i, j, k, l) = value;
|
|
699
733
|
}
|
|
700
734
|
|
|
735
|
+
template<typename T>
|
|
736
|
+
inline CUDA_CALLABLE void store(T* address, T value)
|
|
737
|
+
{
|
|
738
|
+
FP_VERIFY_FWD(value)
|
|
739
|
+
|
|
740
|
+
*address = value;
|
|
741
|
+
}
|
|
742
|
+
|
|
743
|
+
template<typename T>
|
|
744
|
+
inline CUDA_CALLABLE T load(T* address)
|
|
745
|
+
{
|
|
746
|
+
T value = *address;
|
|
747
|
+
FP_VERIFY_FWD(value)
|
|
748
|
+
|
|
749
|
+
return value;
|
|
750
|
+
}
|
|
751
|
+
|
|
701
752
|
// select operator to check for array being null
|
|
702
753
|
template <typename T1, typename T2>
|
|
703
754
|
CUDA_CALLABLE inline T2 select(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?b:a; }
|
|
@@ -731,34 +782,36 @@ CUDA_CALLABLE inline void adj_atomic_add(uint32* buf, uint32 value) { }
|
|
|
731
782
|
CUDA_CALLABLE inline void adj_atomic_add(int64* buf, int64 value) { }
|
|
732
783
|
CUDA_CALLABLE inline void adj_atomic_add(uint64* buf, uint64 value) { }
|
|
733
784
|
|
|
785
|
+
CUDA_CALLABLE inline void adj_atomic_add(bool* buf, bool value) { }
|
|
786
|
+
|
|
734
787
|
// only generate gradients for T types
|
|
735
788
|
template<typename T>
|
|
736
|
-
inline CUDA_CALLABLE void
|
|
789
|
+
inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, const array_t<T>& adj_buf, int& adj_i, const T& adj_output)
|
|
737
790
|
{
|
|
738
791
|
if (buf.grad)
|
|
739
792
|
adj_atomic_add(&index_grad(buf, i), adj_output);
|
|
740
793
|
}
|
|
741
794
|
template<typename T>
|
|
742
|
-
inline CUDA_CALLABLE void
|
|
795
|
+
inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, const array_t<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output)
|
|
743
796
|
{
|
|
744
797
|
if (buf.grad)
|
|
745
798
|
adj_atomic_add(&index_grad(buf, i, j), adj_output);
|
|
746
799
|
}
|
|
747
800
|
template<typename T>
|
|
748
|
-
inline CUDA_CALLABLE void
|
|
801
|
+
inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output)
|
|
749
802
|
{
|
|
750
803
|
if (buf.grad)
|
|
751
804
|
adj_atomic_add(&index_grad(buf, i, j, k), adj_output);
|
|
752
805
|
}
|
|
753
806
|
template<typename T>
|
|
754
|
-
inline CUDA_CALLABLE void
|
|
807
|
+
inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, int l, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output)
|
|
755
808
|
{
|
|
756
809
|
if (buf.grad)
|
|
757
810
|
adj_atomic_add(&index_grad(buf, i, j, k, l), adj_output);
|
|
758
811
|
}
|
|
759
812
|
|
|
760
813
|
template<typename T>
|
|
761
|
-
inline CUDA_CALLABLE void
|
|
814
|
+
inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value)
|
|
762
815
|
{
|
|
763
816
|
if (buf.grad)
|
|
764
817
|
adj_value += index_grad(buf, i);
|
|
@@ -766,7 +819,7 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, T value, const
|
|
|
766
819
|
FP_VERIFY_ADJ_1(value, adj_value)
|
|
767
820
|
}
|
|
768
821
|
template<typename T>
|
|
769
|
-
inline CUDA_CALLABLE void
|
|
822
|
+
inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value)
|
|
770
823
|
{
|
|
771
824
|
if (buf.grad)
|
|
772
825
|
adj_value += index_grad(buf, i, j);
|
|
@@ -775,7 +828,7 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, T value
|
|
|
775
828
|
|
|
776
829
|
}
|
|
777
830
|
template<typename T>
|
|
778
|
-
inline CUDA_CALLABLE void
|
|
831
|
+
inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value)
|
|
779
832
|
{
|
|
780
833
|
if (buf.grad)
|
|
781
834
|
adj_value += index_grad(buf, i, j, k);
|
|
@@ -783,7 +836,7 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k,
|
|
|
783
836
|
FP_VERIFY_ADJ_3(value, adj_value)
|
|
784
837
|
}
|
|
785
838
|
template<typename T>
|
|
786
|
-
inline CUDA_CALLABLE void
|
|
839
|
+
inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value)
|
|
787
840
|
{
|
|
788
841
|
if (buf.grad)
|
|
789
842
|
adj_value += index_grad(buf, i, j, k, l);
|
|
@@ -791,6 +844,19 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k,
|
|
|
791
844
|
FP_VERIFY_ADJ_4(value, adj_value)
|
|
792
845
|
}
|
|
793
846
|
|
|
847
|
+
template<typename T>
|
|
848
|
+
inline CUDA_CALLABLE void adj_store(const T* address, T value, const T& adj_address, T& adj_value)
|
|
849
|
+
{
|
|
850
|
+
// nop; generic store() operations are not differentiable, only array_store() is
|
|
851
|
+
FP_VERIFY_ADJ(value, adj_value)
|
|
852
|
+
}
|
|
853
|
+
|
|
854
|
+
template<typename T>
|
|
855
|
+
inline CUDA_CALLABLE void adj_load(const T* address, const T& adj_address, T& adj_value)
|
|
856
|
+
{
|
|
857
|
+
// nop; generic load() operations are not differentiable
|
|
858
|
+
}
|
|
859
|
+
|
|
794
860
|
template<typename T>
|
|
795
861
|
inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret)
|
|
796
862
|
{
|
|
@@ -860,22 +926,22 @@ inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, in
|
|
|
860
926
|
|
|
861
927
|
// generic array types that do not support gradient computation (indexedarray, etc.)
|
|
862
928
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
863
|
-
inline CUDA_CALLABLE void
|
|
929
|
+
inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, const A2<T>& adj_buf, int& adj_i, const T& adj_output) {}
|
|
864
930
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
865
|
-
inline CUDA_CALLABLE void
|
|
931
|
+
inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, const A2<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output) {}
|
|
866
932
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
867
|
-
inline CUDA_CALLABLE void
|
|
933
|
+
inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output) {}
|
|
868
934
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
869
|
-
inline CUDA_CALLABLE void
|
|
935
|
+
inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, int l, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output) {}
|
|
870
936
|
|
|
871
937
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
872
|
-
inline CUDA_CALLABLE void
|
|
938
|
+
inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value) {}
|
|
873
939
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
874
|
-
inline CUDA_CALLABLE void
|
|
940
|
+
inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value) {}
|
|
875
941
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
876
|
-
inline CUDA_CALLABLE void
|
|
942
|
+
inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value) {}
|
|
877
943
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
878
|
-
inline CUDA_CALLABLE void
|
|
944
|
+
inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value) {}
|
|
879
945
|
|
|
880
946
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
881
947
|
inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
|
|
@@ -895,22 +961,65 @@ inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k,
|
|
|
895
961
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
896
962
|
inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
|
|
897
963
|
|
|
964
|
+
// generic handler for scalar values
|
|
898
965
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
899
|
-
inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
|
|
966
|
+
inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
|
|
967
|
+
if (buf.grad)
|
|
968
|
+
adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
|
|
969
|
+
|
|
970
|
+
FP_VERIFY_ADJ_1(value, adj_value)
|
|
971
|
+
}
|
|
900
972
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
901
|
-
inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
|
|
973
|
+
inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
|
|
974
|
+
if (buf.grad)
|
|
975
|
+
adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
|
|
976
|
+
|
|
977
|
+
FP_VERIFY_ADJ_2(value, adj_value)
|
|
978
|
+
}
|
|
902
979
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
903
|
-
inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
|
|
980
|
+
inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
|
|
981
|
+
if (buf.grad)
|
|
982
|
+
adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
|
|
983
|
+
|
|
984
|
+
FP_VERIFY_ADJ_3(value, adj_value)
|
|
985
|
+
}
|
|
904
986
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
905
|
-
inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
|
|
987
|
+
inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
|
|
988
|
+
if (buf.grad)
|
|
989
|
+
adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
|
|
990
|
+
|
|
991
|
+
FP_VERIFY_ADJ_4(value, adj_value)
|
|
992
|
+
}
|
|
906
993
|
|
|
907
994
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
908
|
-
inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
|
|
995
|
+
inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
|
|
996
|
+
if (buf.grad)
|
|
997
|
+
adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
|
|
998
|
+
|
|
999
|
+
FP_VERIFY_ADJ_1(value, adj_value)
|
|
1000
|
+
}
|
|
909
1001
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
910
|
-
inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
|
|
1002
|
+
inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
|
|
1003
|
+
if (buf.grad)
|
|
1004
|
+
adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
|
|
1005
|
+
|
|
1006
|
+
FP_VERIFY_ADJ_2(value, adj_value)
|
|
1007
|
+
}
|
|
911
1008
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
912
|
-
inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
|
|
1009
|
+
inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
|
|
1010
|
+
if (buf.grad)
|
|
1011
|
+
adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
|
|
1012
|
+
|
|
1013
|
+
FP_VERIFY_ADJ_3(value, adj_value)
|
|
1014
|
+
}
|
|
913
1015
|
template<template<typename> class A1, template<typename> class A2, typename T>
|
|
914
|
-
inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
|
|
1016
|
+
inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
|
|
1017
|
+
if (buf.grad)
|
|
1018
|
+
adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
|
|
1019
|
+
|
|
1020
|
+
FP_VERIFY_ADJ_4(value, adj_value)
|
|
1021
|
+
}
|
|
1022
|
+
|
|
1023
|
+
} // namespace wp
|
|
915
1024
|
|
|
916
|
-
|
|
1025
|
+
#include "fabric.h"
|