warp-lang 1.0.2__py3-none-win_amd64.whl → 1.2.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +88 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3693 -3354
- warp/codegen.py +2925 -2792
- warp/config.py +40 -36
- warp/constants.py +49 -45
- warp/context.py +5409 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +381 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -277
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +145 -146
- warp/examples/benchmarks/benchmark_launches.py +293 -295
- warp/examples/browse.py +29 -29
- warp/examples/core/example_dem.py +232 -219
- warp/examples/core/example_fluid.py +291 -267
- warp/examples/core/example_graph_capture.py +142 -126
- warp/examples/core/example_marching_cubes.py +186 -174
- warp/examples/core/example_mesh.py +172 -155
- warp/examples/core/example_mesh_intersect.py +203 -193
- warp/examples/core/example_nvdb.py +174 -170
- warp/examples/core/example_raycast.py +103 -90
- warp/examples/core/example_raymarch.py +197 -178
- warp/examples/core/example_render_opengl.py +183 -141
- warp/examples/core/example_sph.py +403 -387
- warp/examples/core/example_torch.py +219 -181
- warp/examples/core/example_wave.py +261 -248
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +432 -389
- warp/examples/fem/example_burgers.py +262 -0
- warp/examples/fem/example_convection_diffusion.py +180 -168
- warp/examples/fem/example_convection_diffusion_dg.py +217 -209
- warp/examples/fem/example_deformed_geometry.py +175 -159
- warp/examples/fem/example_diffusion.py +199 -173
- warp/examples/fem/example_diffusion_3d.py +178 -152
- warp/examples/fem/example_diffusion_mgpu.py +219 -214
- warp/examples/fem/example_mixed_elasticity.py +242 -222
- warp/examples/fem/example_navier_stokes.py +257 -243
- warp/examples/fem/example_stokes.py +218 -192
- warp/examples/fem/example_stokes_transfer.py +263 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +258 -246
- warp/examples/optim/example_cloth_throw.py +220 -209
- warp/examples/optim/example_diffray.py +564 -536
- warp/examples/optim/example_drone.py +862 -835
- warp/examples/optim/example_inverse_kinematics.py +174 -168
- warp/examples/optim/example_inverse_kinematics_torch.py +183 -169
- warp/examples/optim/example_spring_cage.py +237 -231
- warp/examples/optim/example_trajectory.py +221 -199
- warp/examples/optim/example_walker.py +304 -293
- warp/examples/sim/example_cartpole.py +137 -129
- warp/examples/sim/example_cloth.py +194 -186
- warp/examples/sim/example_granular.py +122 -111
- warp/examples/sim/example_granular_collision_sdf.py +195 -186
- warp/examples/sim/example_jacobian_ik.py +234 -214
- warp/examples/sim/example_particle_chain.py +116 -105
- warp/examples/sim/example_quadruped.py +191 -180
- warp/examples/sim/example_rigid_chain.py +195 -187
- warp/examples/sim/example_rigid_contact.py +187 -177
- warp/examples/sim/example_rigid_force.py +125 -125
- warp/examples/sim/example_rigid_gyroscopic.py +107 -95
- warp/examples/sim/example_rigid_soft_contact.py +132 -122
- warp/examples/sim/example_soft_body.py +188 -177
- warp/fabric.py +337 -335
- warp/fem/__init__.py +61 -27
- warp/fem/cache.py +403 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +16 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +748 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +437 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/nanogrid.py +455 -0
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1684 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +179 -292
- warp/fem/space/basis_space.py +522 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +148 -267
- warp/fem/space/grid_3d_function_space.py +167 -306
- warp/fem/space/hexmesh_function_space.py +253 -352
- warp/fem/space/nanogrid_function_space.py +202 -0
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +261 -369
- warp/fem/space/restriction.py +161 -160
- warp/fem/space/shape/__init__.py +90 -15
- warp/fem/space/shape/cube_shape_function.py +728 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +224 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +153 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1081 -1025
- warp/native/builtin.h +1603 -1560
- warp/native/bvh.cpp +402 -398
- warp/native/bvh.cu +533 -525
- warp/native/bvh.h +430 -429
- warp/native/clang/clang.cpp +496 -464
- warp/native/crt.cpp +42 -32
- warp/native/crt.h +352 -335
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/exports.h +187 -0
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1545 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +292 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/GridHandle.h +366 -0
- warp/native/nanovdb/HostBuffer.h +590 -0
- warp/native/nanovdb/NanoVDB.h +6624 -4782
- warp/native/nanovdb/PNanoVDB.h +3390 -2553
- warp/native/noise.h +850 -850
- warp/native/quat.h +1112 -1085
- warp/native/rand.h +303 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1177 -1133
- warp/native/volume.cpp +529 -297
- warp/native/volume.cu +58 -32
- warp/native/volume.h +960 -538
- warp/native/volume_builder.cu +446 -425
- warp/native/volume_builder.h +34 -19
- warp/native/volume_impl.h +61 -0
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2949 -2828
- warp/native/warp.h +321 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3356 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1917 -1991
- warp/sim/integrator_xpbd.py +3288 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1289 -1227
- warp/stubs.py +2192 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/test_index_grid.nvdb +0 -0
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +20 -22
- warp/tests/aux_test_grad_customs.py +21 -23
- warp/tests/aux_test_reference.py +9 -11
- warp/tests/aux_test_reference_reference.py +8 -10
- warp/tests/aux_test_square.py +15 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +237 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +155 -157
- warp/tests/test_arithmetic.py +1088 -1124
- warp/tests/test_array.py +2415 -2326
- warp/tests/test_array_reduce.py +148 -150
- warp/tests/test_async.py +666 -656
- warp/tests/test_atomic.py +139 -141
- warp/tests/test_bool.py +212 -149
- warp/tests/test_builtins_resolution.py +1290 -1292
- warp/tests/test_bvh.py +162 -171
- warp/tests/test_closest_point_edge_edge.py +227 -228
- warp/tests/test_codegen.py +562 -553
- warp/tests/test_compile_consts.py +217 -101
- warp/tests/test_conditional.py +244 -246
- warp/tests/test_copy.py +230 -215
- warp/tests/test_ctypes.py +630 -632
- warp/tests/test_dense.py +65 -67
- warp/tests/test_devices.py +89 -98
- warp/tests/test_dlpack.py +528 -529
- warp/tests/test_examples.py +403 -378
- warp/tests/test_fabricarray.py +952 -955
- warp/tests/test_fast_math.py +60 -54
- warp/tests/test_fem.py +1298 -1278
- warp/tests/test_fp16.py +128 -130
- warp/tests/test_func.py +336 -337
- warp/tests/test_generics.py +596 -571
- warp/tests/test_grad.py +885 -640
- warp/tests/test_grad_customs.py +331 -336
- warp/tests/test_hash_grid.py +208 -164
- warp/tests/test_import.py +37 -39
- warp/tests/test_indexedarray.py +1132 -1134
- warp/tests/test_intersect.py +65 -67
- warp/tests/test_jax.py +305 -307
- warp/tests/test_large.py +169 -164
- warp/tests/test_launch.py +352 -354
- warp/tests/test_lerp.py +217 -261
- warp/tests/test_linear_solvers.py +189 -171
- warp/tests/test_lvalue.py +419 -493
- warp/tests/test_marching_cubes.py +63 -65
- warp/tests/test_mat.py +1799 -1827
- warp/tests/test_mat_lite.py +113 -115
- warp/tests/test_mat_scalar_ops.py +2905 -2889
- warp/tests/test_math.py +124 -193
- warp/tests/test_matmul.py +498 -499
- warp/tests/test_matmul_lite.py +408 -410
- warp/tests/test_mempool.py +186 -190
- warp/tests/test_mesh.py +281 -324
- warp/tests/test_mesh_query_aabb.py +226 -241
- warp/tests/test_mesh_query_point.py +690 -702
- warp/tests/test_mesh_query_ray.py +290 -303
- warp/tests/test_mlp.py +274 -276
- warp/tests/test_model.py +108 -110
- warp/tests/test_module_hashing.py +111 -0
- warp/tests/test_modules_lite.py +36 -39
- warp/tests/test_multigpu.py +161 -163
- warp/tests/test_noise.py +244 -248
- warp/tests/test_operators.py +248 -250
- warp/tests/test_options.py +121 -125
- warp/tests/test_peer.py +131 -137
- warp/tests/test_pinned.py +76 -78
- warp/tests/test_print.py +52 -54
- warp/tests/test_quat.py +2084 -2086
- warp/tests/test_rand.py +324 -288
- warp/tests/test_reload.py +207 -217
- warp/tests/test_rounding.py +177 -179
- warp/tests/test_runlength_encode.py +188 -190
- warp/tests/test_sim_grad.py +241 -0
- warp/tests/test_sim_kinematics.py +89 -97
- warp/tests/test_smoothstep.py +166 -168
- warp/tests/test_snippet.py +303 -266
- warp/tests/test_sparse.py +466 -460
- warp/tests/test_spatial.py +2146 -2148
- warp/tests/test_special_values.py +362 -0
- warp/tests/test_streams.py +484 -473
- warp/tests/test_struct.py +708 -675
- warp/tests/test_tape.py +171 -148
- warp/tests/test_torch.py +741 -743
- warp/tests/test_transient_module.py +85 -87
- warp/tests/test_types.py +554 -659
- warp/tests/test_utils.py +488 -499
- warp/tests/test_vec.py +1262 -1268
- warp/tests/test_vec_lite.py +71 -73
- warp/tests/test_vec_scalar_ops.py +2097 -2099
- warp/tests/test_verify_fp.py +92 -94
- warp/tests/test_volume.py +961 -736
- warp/tests/test_volume_write.py +338 -265
- warp/tests/unittest_serial.py +38 -37
- warp/tests/unittest_suites.py +367 -359
- warp/tests/unittest_utils.py +434 -578
- warp/tests/unused_test_misc.py +69 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +563 -561
- warp/torch.py +321 -295
- warp/types.py +4941 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/METADATA +365 -400
- warp_lang-1.2.0.dist-info/RECORD +359 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp/examples/fem/example_convection_diffusion_dg0.py +0 -194
- warp/native/nanovdb/PNanoVDBWrite.h +0 -295
- warp_lang-1.0.2.dist-info/RECORD +0 -352
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.2.dist-info → warp_lang-1.2.0.dist-info}/top_level.txt +0 -0
warp/native/matnn.h
CHANGED
|
@@ -1,334 +1,334 @@
|
|
|
1
|
-
/** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
-
* NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
-
* and proprietary rights in and to this software, related documentation
|
|
4
|
-
* and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
-
* distribution of this software and related documentation without an express
|
|
6
|
-
* license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
-
*/
|
|
8
|
-
|
|
9
|
-
#pragma once
|
|
10
|
-
|
|
11
|
-
namespace wp
|
|
12
|
-
{
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
CUDA_CALLABLE inline int dense_index(int stride, int i, int j)
|
|
16
|
-
{
|
|
17
|
-
return i*stride + j;
|
|
18
|
-
}
|
|
19
|
-
|
|
20
|
-
template <bool transpose>
|
|
21
|
-
CUDA_CALLABLE inline int dense_index(int rows, int cols, int i, int j)
|
|
22
|
-
{
|
|
23
|
-
if (transpose)
|
|
24
|
-
return j*rows + i;
|
|
25
|
-
else
|
|
26
|
-
return i*cols + j;
|
|
27
|
-
}
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
template <bool t1, bool t2, bool add>
|
|
32
|
-
CUDA_CALLABLE inline void dense_gemm_impl(int m, int n, int p, const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C)
|
|
33
|
-
{
|
|
34
|
-
for (int i=0; i < m; i++)
|
|
35
|
-
{
|
|
36
|
-
for (int j=0; j < n; ++j)
|
|
37
|
-
{
|
|
38
|
-
float sum = 0.0f;
|
|
39
|
-
|
|
40
|
-
for (int k=0; k < p; ++k)
|
|
41
|
-
{
|
|
42
|
-
sum += A[dense_index<t1>(m, p, i, k)]*B[dense_index<t2>(p, n, k, j)];
|
|
43
|
-
}
|
|
44
|
-
|
|
45
|
-
if (add)
|
|
46
|
-
C[i*n + j] += sum;
|
|
47
|
-
else
|
|
48
|
-
C[i*n + j] = sum;
|
|
49
|
-
}
|
|
50
|
-
}
|
|
51
|
-
}
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
template <bool add=false>
|
|
55
|
-
CUDA_CALLABLE inline void dense_gemm(int m, int n, int p, int t1, int t2, const array_t<float>& A, const array_t<float>& B, array_t<float>& C)
|
|
56
|
-
{
|
|
57
|
-
if (t1 == 0 && t2 == 0)
|
|
58
|
-
dense_gemm_impl<false, false, add>(m, n, p, A.data, B.data, C.data);
|
|
59
|
-
else if (t1 == 1 && t2 == 0)
|
|
60
|
-
dense_gemm_impl<true, false, add>(m, n, p, A.data, B.data, C.data);
|
|
61
|
-
else if (t1 == 0 && t2 == 1)
|
|
62
|
-
dense_gemm_impl<false, true, add>(m, n, p, A.data, B.data, C.data);
|
|
63
|
-
else if (t1 == 1 && t2 == 1)
|
|
64
|
-
dense_gemm_impl<true, true, add>(m, n, p, A.data, B.data, C.data);
|
|
65
|
-
}
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
void CUDA_CALLABLE inline dense_chol(int n, const array_t<float>& A, float regularization, array_t<float>& L)
|
|
71
|
-
{
|
|
72
|
-
for (int j=0; j < n; ++j)
|
|
73
|
-
{
|
|
74
|
-
float s = A.data[dense_index(n, j, j)] + regularization;
|
|
75
|
-
|
|
76
|
-
for (int k=0; k < j; ++k)
|
|
77
|
-
{
|
|
78
|
-
float r = L.data[dense_index(n, j, k)];
|
|
79
|
-
s -= r*r;
|
|
80
|
-
}
|
|
81
|
-
|
|
82
|
-
s = sqrt(s);
|
|
83
|
-
const float invS = 1.0f/s;
|
|
84
|
-
|
|
85
|
-
L.data[dense_index(n, j, j)] = s;
|
|
86
|
-
|
|
87
|
-
for (int i=j+1; i < n; ++i)
|
|
88
|
-
{
|
|
89
|
-
s = A.data[dense_index(n, i, j)];
|
|
90
|
-
|
|
91
|
-
for (int k=0; k < j; ++k)
|
|
92
|
-
{
|
|
93
|
-
s -= L.data[dense_index(n, i, k)]*L.data[dense_index(n, j, k)];
|
|
94
|
-
}
|
|
95
|
-
|
|
96
|
-
L.data[dense_index(n, i, j)] = s*invS;
|
|
97
|
-
}
|
|
98
|
-
}
|
|
99
|
-
}
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
// Solves (L*L^T)x = b given the Cholesky factor L
|
|
105
|
-
CUDA_CALLABLE inline void dense_subs(int n, const array_t<float>& L, const array_t<float>& b, array_t<float>& x)
|
|
106
|
-
{
|
|
107
|
-
// forward substitution
|
|
108
|
-
for (int i=0; i < n; ++i)
|
|
109
|
-
{
|
|
110
|
-
float s = b.data[i];
|
|
111
|
-
|
|
112
|
-
for (int j=0; j < i; ++j)
|
|
113
|
-
{
|
|
114
|
-
s -= L.data[dense_index(n, i, j)]*x.data[j];
|
|
115
|
-
}
|
|
116
|
-
|
|
117
|
-
x.data[i] = s/L.data[dense_index(n, i, i)];
|
|
118
|
-
}
|
|
119
|
-
|
|
120
|
-
// backward substitution
|
|
121
|
-
for (int i=n-1; i >= 0; --i)
|
|
122
|
-
{
|
|
123
|
-
float s = x.data[i];
|
|
124
|
-
|
|
125
|
-
for (int j=i+1; j < n; ++j)
|
|
126
|
-
{
|
|
127
|
-
s -= L.data[dense_index(n, j, i)]*x.data[j];
|
|
128
|
-
}
|
|
129
|
-
|
|
130
|
-
x.data[i] = s/L.data[dense_index(n, i, i)];
|
|
131
|
-
}
|
|
132
|
-
}
|
|
133
|
-
|
|
134
|
-
CUDA_CALLABLE inline void dense_solve(int n, const array_t<float>& A, const array_t<float>& L, const array_t<float>& b, array_t<float>& x)
|
|
135
|
-
{
|
|
136
|
-
dense_subs(n, L, b, x);
|
|
137
|
-
}
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
// CUDA_CALLABLE inline void print_matrix(const char* name, int m, int n, const float* data)
|
|
141
|
-
// {
|
|
142
|
-
// printf("%s = [", name);
|
|
143
|
-
|
|
144
|
-
// for (int i=0; i < m; ++i)
|
|
145
|
-
// {
|
|
146
|
-
// for (int j=0; j < n; ++j)
|
|
147
|
-
// {
|
|
148
|
-
// printf("%f ", data[dense_index(n, i, j)]);
|
|
149
|
-
// }
|
|
150
|
-
|
|
151
|
-
// printf(";\n");
|
|
152
|
-
// }
|
|
153
|
-
|
|
154
|
-
// printf("]\n");
|
|
155
|
-
// }
|
|
156
|
-
|
|
157
|
-
// adjoint methods
|
|
158
|
-
CUDA_CALLABLE inline void adj_dense_gemm(
|
|
159
|
-
int m, int n, int p, int t1, int t2, const array_t<float>& A, const array_t<float>& B, array_t<float>& C,
|
|
160
|
-
int adj_m, int adj_n, int adj_p, int adj_t1, int adj_t2, array_t<float>& adj_A, array_t<float>& adj_B, const array_t<float>& adj_C)
|
|
161
|
-
{
|
|
162
|
-
|
|
163
|
-
// print_matrix("A", m, p, A);
|
|
164
|
-
// print_matrix("B", p, n, B);
|
|
165
|
-
// printf("t1: %d t2: %d\n", t1, t2);
|
|
166
|
-
|
|
167
|
-
if (t1)
|
|
168
|
-
{
|
|
169
|
-
dense_gemm<true>(p, m, n, 0, 1, B, adj_C, adj_A);
|
|
170
|
-
dense_gemm<true>(p, n, m, int(!t1), 0, A, adj_C, adj_B);
|
|
171
|
-
}
|
|
172
|
-
else
|
|
173
|
-
{
|
|
174
|
-
dense_gemm<true>(m, p, n, 0, int(!t2), adj_C, B, adj_A);
|
|
175
|
-
dense_gemm<true>(p, n, m, int(!t1), 0, A, adj_C, adj_B);
|
|
176
|
-
}
|
|
177
|
-
}
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
CUDA_CALLABLE inline void adj_dense_chol(
|
|
181
|
-
int n, const array_t<float>& A, float regularization, array_t<float>& L,
|
|
182
|
-
int adj_n, const array_t<float>& adj_A, float adj_regularization, array_t<float>& adj_L)
|
|
183
|
-
{
|
|
184
|
-
// nop, use dense_solve to differentiate through (A^-1)b = x
|
|
185
|
-
}
|
|
186
|
-
|
|
187
|
-
CUDA_CALLABLE inline void adj_dense_subs(
|
|
188
|
-
int n, const array_t<float>& L, const array_t<float>& b, array_t<float>& x,
|
|
189
|
-
int adj_n, const array_t<float>& adj_L, const array_t<float>& adj_b, array_t<float>& adj_x)
|
|
190
|
-
{
|
|
191
|
-
// nop, use dense_solve to differentiate through (A^-1)b = x
|
|
192
|
-
}
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
CUDA_CALLABLE inline void adj_dense_solve(int n,
|
|
196
|
-
const array_t<float>& A, const array_t<float>& L, const array_t<float>& b, const array_t<float>& x,
|
|
197
|
-
int adj_n, array_t<float>& adj_A, array_t<float>& adj_L, array_t<float>& adj_b, const array_t<float>& adj_x)
|
|
198
|
-
{
|
|
199
|
-
// see https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pwp, section 2.3.1
|
|
200
|
-
dense_subs(n, L, adj_x, adj_b);
|
|
201
|
-
|
|
202
|
-
// A* = -adj_b*x^T
|
|
203
|
-
for (int i=0; i < n; ++i)
|
|
204
|
-
{
|
|
205
|
-
for (int j=0; j < n; ++j)
|
|
206
|
-
{
|
|
207
|
-
adj_A.data[dense_index(n, i, j)] += -adj_b.data[i]*x.data[j];
|
|
208
|
-
}
|
|
209
|
-
}
|
|
210
|
-
}
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
template <typename F>
|
|
214
|
-
CUDA_CALLABLE inline void mlp(const array_t<float>& weights, const array_t<float>& bias, F activation, int index, const array_t<float>& x, array_t<float>& out)
|
|
215
|
-
{
|
|
216
|
-
const int m = weights.shape[0];
|
|
217
|
-
const int n = weights.shape[1];
|
|
218
|
-
const int b = x.shape[1];
|
|
219
|
-
|
|
220
|
-
for (int i=0; i < m; ++i)
|
|
221
|
-
{
|
|
222
|
-
float tmp = bias.data[i];
|
|
223
|
-
|
|
224
|
-
for(int j=0; j < n; ++j)
|
|
225
|
-
{
|
|
226
|
-
tmp += weights.data[i*n + j]*x.data[index + b*j];
|
|
227
|
-
}
|
|
228
|
-
|
|
229
|
-
out.data[index + b*i] = activation(tmp);
|
|
230
|
-
}
|
|
231
|
-
}
|
|
232
|
-
|
|
233
|
-
template <typename F, typename AdjF>
|
|
234
|
-
CUDA_CALLABLE inline void adj_mlp(const array_t<float>& weights, const array_t<float>& bias, F activation, int index, const array_t<float>& x, array_t<float>& out,
|
|
235
|
-
array_t<float>& adj_weights, array_t<float>& adj_bias, AdjF adj_activation, int adj_index, array_t<float>& adj_x, array_t<float>& adj_out)
|
|
236
|
-
{
|
|
237
|
-
const int m = weights.shape[0];
|
|
238
|
-
const int n = weights.shape[1];
|
|
239
|
-
const int b = x.shape[1];
|
|
240
|
-
|
|
241
|
-
for (int i=0; i < m; ++i)
|
|
242
|
-
{
|
|
243
|
-
// recompute forward pass so we don't have to store pre-activation outputs
|
|
244
|
-
float tmp = bias.data[i];
|
|
245
|
-
|
|
246
|
-
for(int j=0; j < n; ++j)
|
|
247
|
-
{
|
|
248
|
-
tmp += weights.data[i*n + j]*x.data[index + b*j];
|
|
249
|
-
}
|
|
250
|
-
|
|
251
|
-
// adjoint w.r.t to activation
|
|
252
|
-
float adj_f = 0.0f;
|
|
253
|
-
|
|
254
|
-
if (adj_out.data)
|
|
255
|
-
adj_activation(tmp, adj_f, adj_out.data[index + b*i]);
|
|
256
|
-
|
|
257
|
-
for (int j=0; j < n; ++j)
|
|
258
|
-
{
|
|
259
|
-
// adjoint w.r.t M_i
|
|
260
|
-
if (adj_weights.data)
|
|
261
|
-
atomic_add(&adj_weights.data[i*n + j], x.data[index + b*j]*adj_f); // todo: reduce these atomic stores using warp/block level reductions
|
|
262
|
-
|
|
263
|
-
// adjoint w.r.t x
|
|
264
|
-
if (adj_x.data)
|
|
265
|
-
atomic_add(&adj_x.data[index + b*j], weights.data[i*n + j]*adj_f);
|
|
266
|
-
}
|
|
267
|
-
|
|
268
|
-
// adjoint w.r.t b
|
|
269
|
-
if (adj_bias.data)
|
|
270
|
-
atomic_add(&adj_bias.data[i], adj_f);
|
|
271
|
-
|
|
272
|
-
}
|
|
273
|
-
}
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
// template <typename F>
|
|
277
|
-
// CUDA_CALLABLE inline void mlp(const array_t<float>& weights, const array_t<float>& bias, F activation, int m, int n, int b, int index, const array_t<float>& x, array_t<float>& out)
|
|
278
|
-
// {
|
|
279
|
-
// x += index*n;
|
|
280
|
-
// out += index*m;
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
// for (int i=0; i < m; ++i)
|
|
284
|
-
// {
|
|
285
|
-
// float tmp = bias[i];
|
|
286
|
-
|
|
287
|
-
// for(int j=0; j < n; ++j)
|
|
288
|
-
// {
|
|
289
|
-
// tmp += weights[i*n + j]*x[j];
|
|
290
|
-
// }
|
|
291
|
-
|
|
292
|
-
// out[i] = activation(tmp);
|
|
293
|
-
// }
|
|
294
|
-
// }
|
|
295
|
-
|
|
296
|
-
// template <typename F, typename AdjF>
|
|
297
|
-
// CUDA_CALLABLE inline void adj_mlp(const array_t<float>& weights, const array_t<float>& bias, F activation, int m, int n, int b, int index, const array_t<float>& x, const array_t<float>& out,
|
|
298
|
-
// array_t<float>& adj_weights, array_t<float>& adj_bias, AdjF adj_activation, int adj_m, int adj_n, int adj_b, int adj_index, array_t<float>& adj_x, array_t<float>& adj_out)
|
|
299
|
-
// {
|
|
300
|
-
// x += index*n;
|
|
301
|
-
// out += index*m;
|
|
302
|
-
|
|
303
|
-
// adj_x += index*n;
|
|
304
|
-
// adj_out += index*m;
|
|
305
|
-
|
|
306
|
-
// for (int i=0; i < m; ++i)
|
|
307
|
-
// {
|
|
308
|
-
// // recompute forward pass so we don't have to store pre-activation outputs
|
|
309
|
-
// float tmp = bias[i];
|
|
310
|
-
|
|
311
|
-
// for(int j=0; j < n; ++j)
|
|
312
|
-
// {
|
|
313
|
-
// tmp += weights[i*n + j]*x[index + b*j];
|
|
314
|
-
// }
|
|
315
|
-
|
|
316
|
-
// // adjoint w.r.t to activation
|
|
317
|
-
// float adj_f = 0.0f;
|
|
318
|
-
// adj_activation(tmp, adj_f, adj_out[index + b*i]);
|
|
319
|
-
|
|
320
|
-
// for (int j=0; j < n; ++j)
|
|
321
|
-
// {
|
|
322
|
-
// // adjoint w.r.t M_i
|
|
323
|
-
// adj_weights[i*n + j] += x[j]*adj_f;
|
|
324
|
-
|
|
325
|
-
// // adjoint w.r.t x
|
|
326
|
-
// adj_x[index + b*j] += weights[i*n + j]*adj_f;
|
|
327
|
-
// }
|
|
328
|
-
|
|
329
|
-
// // adjoint w.r.t b
|
|
330
|
-
// adj_bias[i] += adj_f;
|
|
331
|
-
// }
|
|
332
|
-
// }
|
|
333
|
-
|
|
1
|
+
/** Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
* NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
* and proprietary rights in and to this software, related documentation
|
|
4
|
+
* and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
* distribution of this software and related documentation without an express
|
|
6
|
+
* license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
*/
|
|
8
|
+
|
|
9
|
+
#pragma once
|
|
10
|
+
|
|
11
|
+
namespace wp
|
|
12
|
+
{
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
CUDA_CALLABLE inline int dense_index(int stride, int i, int j)
|
|
16
|
+
{
|
|
17
|
+
return i*stride + j;
|
|
18
|
+
}
|
|
19
|
+
|
|
20
|
+
template <bool transpose>
|
|
21
|
+
CUDA_CALLABLE inline int dense_index(int rows, int cols, int i, int j)
|
|
22
|
+
{
|
|
23
|
+
if (transpose)
|
|
24
|
+
return j*rows + i;
|
|
25
|
+
else
|
|
26
|
+
return i*cols + j;
|
|
27
|
+
}
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
template <bool t1, bool t2, bool add>
|
|
32
|
+
CUDA_CALLABLE inline void dense_gemm_impl(int m, int n, int p, const float* __restrict__ A, const float* __restrict__ B, float* __restrict__ C)
|
|
33
|
+
{
|
|
34
|
+
for (int i=0; i < m; i++)
|
|
35
|
+
{
|
|
36
|
+
for (int j=0; j < n; ++j)
|
|
37
|
+
{
|
|
38
|
+
float sum = 0.0f;
|
|
39
|
+
|
|
40
|
+
for (int k=0; k < p; ++k)
|
|
41
|
+
{
|
|
42
|
+
sum += A[dense_index<t1>(m, p, i, k)]*B[dense_index<t2>(p, n, k, j)];
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
if (add)
|
|
46
|
+
C[i*n + j] += sum;
|
|
47
|
+
else
|
|
48
|
+
C[i*n + j] = sum;
|
|
49
|
+
}
|
|
50
|
+
}
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
template <bool add=false>
|
|
55
|
+
CUDA_CALLABLE inline void dense_gemm(int m, int n, int p, int t1, int t2, const array_t<float>& A, const array_t<float>& B, array_t<float>& C)
|
|
56
|
+
{
|
|
57
|
+
if (t1 == 0 && t2 == 0)
|
|
58
|
+
dense_gemm_impl<false, false, add>(m, n, p, A.data, B.data, C.data);
|
|
59
|
+
else if (t1 == 1 && t2 == 0)
|
|
60
|
+
dense_gemm_impl<true, false, add>(m, n, p, A.data, B.data, C.data);
|
|
61
|
+
else if (t1 == 0 && t2 == 1)
|
|
62
|
+
dense_gemm_impl<false, true, add>(m, n, p, A.data, B.data, C.data);
|
|
63
|
+
else if (t1 == 1 && t2 == 1)
|
|
64
|
+
dense_gemm_impl<true, true, add>(m, n, p, A.data, B.data, C.data);
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
void CUDA_CALLABLE inline dense_chol(int n, const array_t<float>& A, float regularization, array_t<float>& L)
|
|
71
|
+
{
|
|
72
|
+
for (int j=0; j < n; ++j)
|
|
73
|
+
{
|
|
74
|
+
float s = A.data[dense_index(n, j, j)] + regularization;
|
|
75
|
+
|
|
76
|
+
for (int k=0; k < j; ++k)
|
|
77
|
+
{
|
|
78
|
+
float r = L.data[dense_index(n, j, k)];
|
|
79
|
+
s -= r*r;
|
|
80
|
+
}
|
|
81
|
+
|
|
82
|
+
s = sqrt(s);
|
|
83
|
+
const float invS = 1.0f/s;
|
|
84
|
+
|
|
85
|
+
L.data[dense_index(n, j, j)] = s;
|
|
86
|
+
|
|
87
|
+
for (int i=j+1; i < n; ++i)
|
|
88
|
+
{
|
|
89
|
+
s = A.data[dense_index(n, i, j)];
|
|
90
|
+
|
|
91
|
+
for (int k=0; k < j; ++k)
|
|
92
|
+
{
|
|
93
|
+
s -= L.data[dense_index(n, i, k)]*L.data[dense_index(n, j, k)];
|
|
94
|
+
}
|
|
95
|
+
|
|
96
|
+
L.data[dense_index(n, i, j)] = s*invS;
|
|
97
|
+
}
|
|
98
|
+
}
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
// Solves (L*L^T)x = b given the Cholesky factor L
|
|
105
|
+
CUDA_CALLABLE inline void dense_subs(int n, const array_t<float>& L, const array_t<float>& b, array_t<float>& x)
|
|
106
|
+
{
|
|
107
|
+
// forward substitution
|
|
108
|
+
for (int i=0; i < n; ++i)
|
|
109
|
+
{
|
|
110
|
+
float s = b.data[i];
|
|
111
|
+
|
|
112
|
+
for (int j=0; j < i; ++j)
|
|
113
|
+
{
|
|
114
|
+
s -= L.data[dense_index(n, i, j)]*x.data[j];
|
|
115
|
+
}
|
|
116
|
+
|
|
117
|
+
x.data[i] = s/L.data[dense_index(n, i, i)];
|
|
118
|
+
}
|
|
119
|
+
|
|
120
|
+
// backward substitution
|
|
121
|
+
for (int i=n-1; i >= 0; --i)
|
|
122
|
+
{
|
|
123
|
+
float s = x.data[i];
|
|
124
|
+
|
|
125
|
+
for (int j=i+1; j < n; ++j)
|
|
126
|
+
{
|
|
127
|
+
s -= L.data[dense_index(n, j, i)]*x.data[j];
|
|
128
|
+
}
|
|
129
|
+
|
|
130
|
+
x.data[i] = s/L.data[dense_index(n, i, i)];
|
|
131
|
+
}
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
CUDA_CALLABLE inline void dense_solve(int n, const array_t<float>& A, const array_t<float>& L, const array_t<float>& b, array_t<float>& x)
|
|
135
|
+
{
|
|
136
|
+
dense_subs(n, L, b, x);
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
// CUDA_CALLABLE inline void print_matrix(const char* name, int m, int n, const float* data)
|
|
141
|
+
// {
|
|
142
|
+
// printf("%s = [", name);
|
|
143
|
+
|
|
144
|
+
// for (int i=0; i < m; ++i)
|
|
145
|
+
// {
|
|
146
|
+
// for (int j=0; j < n; ++j)
|
|
147
|
+
// {
|
|
148
|
+
// printf("%f ", data[dense_index(n, i, j)]);
|
|
149
|
+
// }
|
|
150
|
+
|
|
151
|
+
// printf(";\n");
|
|
152
|
+
// }
|
|
153
|
+
|
|
154
|
+
// printf("]\n");
|
|
155
|
+
// }
|
|
156
|
+
|
|
157
|
+
// adjoint methods
|
|
158
|
+
// Adjoint of dense_gemm: given adj_C, accumulates gradients into adj_A and
// adj_B, selecting the transposed product forms based on how A was used in
// the forward pass (t1/t2 transpose flags).
CUDA_CALLABLE inline void adj_dense_gemm(
    int m, int n, int p, int t1, int t2, const array_t<float>& A, const array_t<float>& B, array_t<float>& C,
    int adj_m, int adj_n, int adj_p, int adj_t1, int adj_t2, array_t<float>& adj_A, array_t<float>& adj_B, const array_t<float>& adj_C)
{
    if (t1)
    {
        // forward pass used A^T
        dense_gemm<true>(p, m, n, 0, 1, B, adj_C, adj_A);
        dense_gemm<true>(p, n, m, int(!t1), 0, A, adj_C, adj_B);
    }
    else
    {
        dense_gemm<true>(m, p, n, 0, int(!t2), adj_C, B, adj_A);
        dense_gemm<true>(p, n, m, int(!t1), 0, A, adj_C, adj_B);
    }
}
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
// Adjoint of dense_chol: intentionally a no-op. The factorization is not
// differentiated directly; gradients flow through adj_dense_solve instead,
// which differentiates x = (A^-1)b as a whole.
CUDA_CALLABLE inline void adj_dense_chol(
    int n, const array_t<float>& A, float regularization, array_t<float>& L,
    int adj_n, const array_t<float>& adj_A, float adj_regularization, array_t<float>& adj_L)
{
    // nop, use dense_solve to differentiate through (A^-1)b = x
}
|
|
186
|
+
|
|
187
|
+
// Adjoint of dense_subs: intentionally a no-op. Like adj_dense_chol,
// substitution is not differentiated on its own; adj_dense_solve handles
// the full (A^-1)b = x derivative.
CUDA_CALLABLE inline void adj_dense_subs(
    int n, const array_t<float>& L, const array_t<float>& b, array_t<float>& x,
    int adj_n, const array_t<float>& adj_L, const array_t<float>& adj_b, array_t<float>& adj_x)
{
    // nop, use dense_solve to differentiate through (A^-1)b = x
}
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
// Adjoint of dense_solve, following Giles, NA-08-01, section 2.3.1:
//   adj_b += A^-1 adj_x    (A symmetric, so reuse the same factor L)
//   adj_A += -adj_b x^T
CUDA_CALLABLE inline void adj_dense_solve(int n,
    const array_t<float>& A, const array_t<float>& L, const array_t<float>& b, const array_t<float>& x,
    int adj_n, array_t<float>& adj_A, array_t<float>& adj_L, array_t<float>& adj_b, const array_t<float>& adj_x)
{
    // see https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pwp, section 2.3.1
    // gradient w.r.t. b: solve (L L^T) adj_b = adj_x
    dense_subs(n, L, adj_x, adj_b);

    // gradient w.r.t. A: adj_A += -adj_b * x^T
    for (int row=0; row < n; ++row)
    {
        const float g = adj_b.data[row];

        for (int col=0; col < n; ++col)
        {
            adj_A.data[dense_index(n, row, col)] -= g*x.data[col];
        }
    }
}
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
template <typename F>
|
|
214
|
+
CUDA_CALLABLE inline void mlp(const array_t<float>& weights, const array_t<float>& bias, F activation, int index, const array_t<float>& x, array_t<float>& out)
|
|
215
|
+
{
|
|
216
|
+
const int m = weights.shape[0];
|
|
217
|
+
const int n = weights.shape[1];
|
|
218
|
+
const int b = x.shape[1];
|
|
219
|
+
|
|
220
|
+
for (int i=0; i < m; ++i)
|
|
221
|
+
{
|
|
222
|
+
float tmp = bias.data[i];
|
|
223
|
+
|
|
224
|
+
for(int j=0; j < n; ++j)
|
|
225
|
+
{
|
|
226
|
+
tmp += weights.data[i*n + j]*x.data[index + b*j];
|
|
227
|
+
}
|
|
228
|
+
|
|
229
|
+
out.data[index + b*i] = activation(tmp);
|
|
230
|
+
}
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
template <typename F, typename AdjF>
|
|
234
|
+
CUDA_CALLABLE inline void adj_mlp(const array_t<float>& weights, const array_t<float>& bias, F activation, int index, const array_t<float>& x, array_t<float>& out,
|
|
235
|
+
array_t<float>& adj_weights, array_t<float>& adj_bias, AdjF adj_activation, int adj_index, array_t<float>& adj_x, array_t<float>& adj_out)
|
|
236
|
+
{
|
|
237
|
+
const int m = weights.shape[0];
|
|
238
|
+
const int n = weights.shape[1];
|
|
239
|
+
const int b = x.shape[1];
|
|
240
|
+
|
|
241
|
+
for (int i=0; i < m; ++i)
|
|
242
|
+
{
|
|
243
|
+
// recompute forward pass so we don't have to store pre-activation outputs
|
|
244
|
+
float tmp = bias.data[i];
|
|
245
|
+
|
|
246
|
+
for(int j=0; j < n; ++j)
|
|
247
|
+
{
|
|
248
|
+
tmp += weights.data[i*n + j]*x.data[index + b*j];
|
|
249
|
+
}
|
|
250
|
+
|
|
251
|
+
// adjoint w.r.t to activation
|
|
252
|
+
float adj_f = 0.0f;
|
|
253
|
+
|
|
254
|
+
if (adj_out.data)
|
|
255
|
+
adj_activation(tmp, adj_f, adj_out.data[index + b*i]);
|
|
256
|
+
|
|
257
|
+
for (int j=0; j < n; ++j)
|
|
258
|
+
{
|
|
259
|
+
// adjoint w.r.t M_i
|
|
260
|
+
if (adj_weights.data)
|
|
261
|
+
atomic_add(&adj_weights.data[i*n + j], x.data[index + b*j]*adj_f); // todo: reduce these atomic stores using warp/block level reductions
|
|
262
|
+
|
|
263
|
+
// adjoint w.r.t x
|
|
264
|
+
if (adj_x.data)
|
|
265
|
+
atomic_add(&adj_x.data[index + b*j], weights.data[i*n + j]*adj_f);
|
|
266
|
+
}
|
|
267
|
+
|
|
268
|
+
// adjoint w.r.t b
|
|
269
|
+
if (adj_bias.data)
|
|
270
|
+
atomic_add(&adj_bias.data[i], adj_f);
|
|
271
|
+
|
|
272
|
+
}
|
|
273
|
+
}
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
// template <typename F>
|
|
277
|
+
// CUDA_CALLABLE inline void mlp(const array_t<float>& weights, const array_t<float>& bias, F activation, int m, int n, int b, int index, const array_t<float>& x, array_t<float>& out)
|
|
278
|
+
// {
|
|
279
|
+
// x += index*n;
|
|
280
|
+
// out += index*m;
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
// for (int i=0; i < m; ++i)
|
|
284
|
+
// {
|
|
285
|
+
// float tmp = bias[i];
|
|
286
|
+
|
|
287
|
+
// for(int j=0; j < n; ++j)
|
|
288
|
+
// {
|
|
289
|
+
// tmp += weights[i*n + j]*x[j];
|
|
290
|
+
// }
|
|
291
|
+
|
|
292
|
+
// out[i] = activation(tmp);
|
|
293
|
+
// }
|
|
294
|
+
// }
|
|
295
|
+
|
|
296
|
+
// template <typename F, typename AdjF>
|
|
297
|
+
// CUDA_CALLABLE inline void adj_mlp(const array_t<float>& weights, const array_t<float>& bias, F activation, int m, int n, int b, int index, const array_t<float>& x, const array_t<float>& out,
|
|
298
|
+
// array_t<float>& adj_weights, array_t<float>& adj_bias, AdjF adj_activation, int adj_m, int adj_n, int adj_b, int adj_index, array_t<float>& adj_x, array_t<float>& adj_out)
|
|
299
|
+
// {
|
|
300
|
+
// x += index*n;
|
|
301
|
+
// out += index*m;
|
|
302
|
+
|
|
303
|
+
// adj_x += index*n;
|
|
304
|
+
// adj_out += index*m;
|
|
305
|
+
|
|
306
|
+
// for (int i=0; i < m; ++i)
|
|
307
|
+
// {
|
|
308
|
+
// // recompute forward pass so we don't have to store pre-activation outputs
|
|
309
|
+
// float tmp = bias[i];
|
|
310
|
+
|
|
311
|
+
// for(int j=0; j < n; ++j)
|
|
312
|
+
// {
|
|
313
|
+
// tmp += weights[i*n + j]*x[index + b*j];
|
|
314
|
+
// }
|
|
315
|
+
|
|
316
|
+
// // adjoint w.r.t to activation
|
|
317
|
+
// float adj_f = 0.0f;
|
|
318
|
+
// adj_activation(tmp, adj_f, adj_out[index + b*i]);
|
|
319
|
+
|
|
320
|
+
// for (int j=0; j < n; ++j)
|
|
321
|
+
// {
|
|
322
|
+
// // adjoint w.r.t M_i
|
|
323
|
+
// adj_weights[i*n + j] += x[j]*adj_f;
|
|
324
|
+
|
|
325
|
+
// // adjoint w.r.t x
|
|
326
|
+
// adj_x[index + b*j] += weights[i*n + j]*adj_f;
|
|
327
|
+
// }
|
|
328
|
+
|
|
329
|
+
// // adjoint w.r.t b
|
|
330
|
+
// adj_bias[i] += adj_f;
|
|
331
|
+
// }
|
|
332
|
+
// }
|
|
333
|
+
|
|
334
334
|
} // namespace wp
|