warp-lang 1.0.2__py3-none-win_amd64.whl → 1.1.0__py3-none-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +108 -97
- warp/__init__.pyi +1 -1
- warp/bin/warp-clang.dll +0 -0
- warp/bin/warp.dll +0 -0
- warp/build.py +115 -113
- warp/build_dll.py +383 -375
- warp/builtins.py +3425 -3354
- warp/codegen.py +2878 -2792
- warp/config.py +40 -36
- warp/constants.py +45 -45
- warp/context.py +5194 -5102
- warp/dlpack.py +442 -442
- warp/examples/__init__.py +16 -16
- warp/examples/assets/bear.usd +0 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/assets/cartpole.urdf +110 -110
- warp/examples/assets/crazyflie.usd +0 -0
- warp/examples/assets/cube.usd +0 -0
- warp/examples/assets/nv_ant.xml +92 -92
- warp/examples/assets/nv_humanoid.xml +183 -183
- warp/examples/assets/quadruped.urdf +267 -267
- warp/examples/assets/rocks.nvdb +0 -0
- warp/examples/assets/rocks.usd +0 -0
- warp/examples/assets/sphere.usd +0 -0
- warp/examples/benchmarks/benchmark_api.py +383 -383
- warp/examples/benchmarks/benchmark_cloth.py +278 -277
- warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -88
- warp/examples/benchmarks/benchmark_cloth_jax.py +97 -100
- warp/examples/benchmarks/benchmark_cloth_numba.py +146 -142
- warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -77
- warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -86
- warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -112
- warp/examples/benchmarks/benchmark_cloth_warp.py +146 -146
- warp/examples/benchmarks/benchmark_launches.py +295 -295
- warp/examples/browse.py +29 -29
- warp/examples/core/example_dem.py +234 -219
- warp/examples/core/example_fluid.py +293 -267
- warp/examples/core/example_graph_capture.py +144 -126
- warp/examples/core/example_marching_cubes.py +188 -174
- warp/examples/core/example_mesh.py +174 -155
- warp/examples/core/example_mesh_intersect.py +205 -193
- warp/examples/core/example_nvdb.py +176 -170
- warp/examples/core/example_raycast.py +105 -90
- warp/examples/core/example_raymarch.py +199 -178
- warp/examples/core/example_render_opengl.py +185 -141
- warp/examples/core/example_sph.py +405 -387
- warp/examples/core/example_torch.py +222 -181
- warp/examples/core/example_wave.py +263 -248
- warp/examples/fem/bsr_utils.py +378 -380
- warp/examples/fem/example_apic_fluid.py +407 -389
- warp/examples/fem/example_convection_diffusion.py +182 -168
- warp/examples/fem/example_convection_diffusion_dg.py +219 -209
- warp/examples/fem/example_convection_diffusion_dg0.py +204 -194
- warp/examples/fem/example_deformed_geometry.py +177 -159
- warp/examples/fem/example_diffusion.py +201 -173
- warp/examples/fem/example_diffusion_3d.py +177 -152
- warp/examples/fem/example_diffusion_mgpu.py +221 -214
- warp/examples/fem/example_mixed_elasticity.py +244 -222
- warp/examples/fem/example_navier_stokes.py +259 -243
- warp/examples/fem/example_stokes.py +220 -192
- warp/examples/fem/example_stokes_transfer.py +265 -249
- warp/examples/fem/mesh_utils.py +133 -109
- warp/examples/fem/plot_utils.py +292 -287
- warp/examples/optim/example_bounce.py +260 -246
- warp/examples/optim/example_cloth_throw.py +222 -209
- warp/examples/optim/example_diffray.py +566 -536
- warp/examples/optim/example_drone.py +864 -835
- warp/examples/optim/example_inverse_kinematics.py +176 -168
- warp/examples/optim/example_inverse_kinematics_torch.py +185 -169
- warp/examples/optim/example_spring_cage.py +239 -231
- warp/examples/optim/example_trajectory.py +223 -199
- warp/examples/optim/example_walker.py +306 -293
- warp/examples/sim/example_cartpole.py +139 -129
- warp/examples/sim/example_cloth.py +196 -186
- warp/examples/sim/example_granular.py +124 -111
- warp/examples/sim/example_granular_collision_sdf.py +197 -186
- warp/examples/sim/example_jacobian_ik.py +236 -214
- warp/examples/sim/example_particle_chain.py +118 -105
- warp/examples/sim/example_quadruped.py +193 -180
- warp/examples/sim/example_rigid_chain.py +197 -187
- warp/examples/sim/example_rigid_contact.py +189 -177
- warp/examples/sim/example_rigid_force.py +127 -125
- warp/examples/sim/example_rigid_gyroscopic.py +109 -95
- warp/examples/sim/example_rigid_soft_contact.py +134 -122
- warp/examples/sim/example_soft_body.py +190 -177
- warp/fabric.py +337 -335
- warp/fem/__init__.py +60 -27
- warp/fem/cache.py +401 -388
- warp/fem/dirichlet.py +178 -179
- warp/fem/domain.py +262 -263
- warp/fem/field/__init__.py +100 -101
- warp/fem/field/field.py +148 -149
- warp/fem/field/nodal_field.py +298 -299
- warp/fem/field/restriction.py +22 -21
- warp/fem/field/test.py +180 -181
- warp/fem/field/trial.py +183 -183
- warp/fem/geometry/__init__.py +15 -19
- warp/fem/geometry/closest_point.py +69 -70
- warp/fem/geometry/deformed_geometry.py +270 -271
- warp/fem/geometry/element.py +744 -744
- warp/fem/geometry/geometry.py +184 -186
- warp/fem/geometry/grid_2d.py +380 -373
- warp/fem/geometry/grid_3d.py +441 -435
- warp/fem/geometry/hexmesh.py +953 -953
- warp/fem/geometry/partition.py +374 -376
- warp/fem/geometry/quadmesh_2d.py +532 -532
- warp/fem/geometry/tetmesh.py +840 -840
- warp/fem/geometry/trimesh_2d.py +577 -577
- warp/fem/integrate.py +1630 -1615
- warp/fem/operator.py +190 -191
- warp/fem/polynomial.py +214 -213
- warp/fem/quadrature/__init__.py +2 -2
- warp/fem/quadrature/pic_quadrature.py +243 -245
- warp/fem/quadrature/quadrature.py +295 -294
- warp/fem/space/__init__.py +294 -292
- warp/fem/space/basis_space.py +488 -489
- warp/fem/space/collocated_function_space.py +100 -105
- warp/fem/space/dof_mapper.py +236 -236
- warp/fem/space/function_space.py +148 -145
- warp/fem/space/grid_2d_function_space.py +267 -267
- warp/fem/space/grid_3d_function_space.py +305 -306
- warp/fem/space/hexmesh_function_space.py +350 -352
- warp/fem/space/partition.py +350 -350
- warp/fem/space/quadmesh_2d_function_space.py +368 -369
- warp/fem/space/restriction.py +158 -160
- warp/fem/space/shape/__init__.py +13 -15
- warp/fem/space/shape/cube_shape_function.py +738 -738
- warp/fem/space/shape/shape_function.py +102 -103
- warp/fem/space/shape/square_shape_function.py +611 -611
- warp/fem/space/shape/tet_shape_function.py +565 -567
- warp/fem/space/shape/triangle_shape_function.py +429 -429
- warp/fem/space/tetmesh_function_space.py +294 -292
- warp/fem/space/topology.py +297 -295
- warp/fem/space/trimesh_2d_function_space.py +223 -221
- warp/fem/types.py +77 -77
- warp/fem/utils.py +495 -495
- warp/jax.py +166 -141
- warp/jax_experimental.py +341 -339
- warp/native/array.h +1072 -1025
- warp/native/builtin.h +1560 -1560
- warp/native/bvh.cpp +398 -398
- warp/native/bvh.cu +525 -525
- warp/native/bvh.h +429 -429
- warp/native/clang/clang.cpp +495 -464
- warp/native/crt.cpp +31 -31
- warp/native/crt.h +334 -334
- warp/native/cuda_crt.h +1049 -1049
- warp/native/cuda_util.cpp +549 -540
- warp/native/cuda_util.h +288 -203
- warp/native/cutlass_gemm.cpp +34 -34
- warp/native/cutlass_gemm.cu +372 -372
- warp/native/error.cpp +66 -66
- warp/native/error.h +27 -27
- warp/native/fabric.h +228 -228
- warp/native/hashgrid.cpp +301 -278
- warp/native/hashgrid.cu +78 -77
- warp/native/hashgrid.h +227 -227
- warp/native/initializer_array.h +32 -32
- warp/native/intersect.h +1204 -1204
- warp/native/intersect_adj.h +365 -365
- warp/native/intersect_tri.h +322 -322
- warp/native/marching.cpp +2 -2
- warp/native/marching.cu +497 -497
- warp/native/marching.h +2 -2
- warp/native/mat.h +1498 -1498
- warp/native/matnn.h +333 -333
- warp/native/mesh.cpp +203 -203
- warp/native/mesh.cu +293 -293
- warp/native/mesh.h +1887 -1887
- warp/native/nanovdb/NanoVDB.h +4782 -4782
- warp/native/nanovdb/PNanoVDB.h +2553 -2553
- warp/native/nanovdb/PNanoVDBWrite.h +294 -294
- warp/native/noise.h +850 -850
- warp/native/quat.h +1084 -1084
- warp/native/rand.h +299 -299
- warp/native/range.h +108 -108
- warp/native/reduce.cpp +156 -156
- warp/native/reduce.cu +348 -348
- warp/native/runlength_encode.cpp +61 -61
- warp/native/runlength_encode.cu +46 -46
- warp/native/scan.cpp +30 -30
- warp/native/scan.cu +36 -36
- warp/native/scan.h +7 -7
- warp/native/solid_angle.h +442 -442
- warp/native/sort.cpp +94 -94
- warp/native/sort.cu +97 -97
- warp/native/sort.h +14 -14
- warp/native/sparse.cpp +337 -337
- warp/native/sparse.cu +544 -544
- warp/native/spatial.h +630 -630
- warp/native/svd.h +562 -562
- warp/native/temp_buffer.h +30 -30
- warp/native/vec.h +1132 -1132
- warp/native/volume.cpp +297 -297
- warp/native/volume.cu +32 -32
- warp/native/volume.h +538 -538
- warp/native/volume_builder.cu +425 -425
- warp/native/volume_builder.h +19 -19
- warp/native/warp.cpp +1057 -1052
- warp/native/warp.cu +2943 -2828
- warp/native/warp.h +313 -305
- warp/optim/__init__.py +9 -9
- warp/optim/adam.py +120 -120
- warp/optim/linear.py +1104 -939
- warp/optim/sgd.py +104 -92
- warp/render/__init__.py +10 -10
- warp/render/render_opengl.py +3217 -3204
- warp/render/render_usd.py +768 -749
- warp/render/utils.py +152 -150
- warp/sim/__init__.py +52 -59
- warp/sim/articulation.py +685 -685
- warp/sim/collide.py +1594 -1590
- warp/sim/import_mjcf.py +489 -481
- warp/sim/import_snu.py +220 -221
- warp/sim/import_urdf.py +536 -516
- warp/sim/import_usd.py +887 -881
- warp/sim/inertia.py +316 -317
- warp/sim/integrator.py +234 -233
- warp/sim/integrator_euler.py +1956 -1956
- warp/sim/integrator_featherstone.py +1910 -1991
- warp/sim/integrator_xpbd.py +3294 -3312
- warp/sim/model.py +4473 -4314
- warp/sim/particles.py +113 -112
- warp/sim/render.py +417 -403
- warp/sim/utils.py +413 -410
- warp/sparse.py +1227 -1227
- warp/stubs.py +2109 -2469
- warp/tape.py +1162 -225
- warp/tests/__init__.py +1 -1
- warp/tests/__main__.py +4 -4
- warp/tests/assets/torus.usda +105 -105
- warp/tests/aux_test_class_kernel.py +26 -26
- warp/tests/aux_test_compile_consts_dummy.py +10 -10
- warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -21
- warp/tests/aux_test_dependent.py +22 -22
- warp/tests/aux_test_grad_customs.py +23 -23
- warp/tests/aux_test_reference.py +11 -11
- warp/tests/aux_test_reference_reference.py +10 -10
- warp/tests/aux_test_square.py +17 -17
- warp/tests/aux_test_unresolved_func.py +14 -14
- warp/tests/aux_test_unresolved_symbol.py +14 -14
- warp/tests/disabled_kinematics.py +239 -239
- warp/tests/run_coverage_serial.py +31 -31
- warp/tests/test_adam.py +157 -157
- warp/tests/test_arithmetic.py +1124 -1124
- warp/tests/test_array.py +2417 -2326
- warp/tests/test_array_reduce.py +150 -150
- warp/tests/test_async.py +668 -656
- warp/tests/test_atomic.py +141 -141
- warp/tests/test_bool.py +204 -149
- warp/tests/test_builtins_resolution.py +1292 -1292
- warp/tests/test_bvh.py +164 -171
- warp/tests/test_closest_point_edge_edge.py +228 -228
- warp/tests/test_codegen.py +566 -553
- warp/tests/test_compile_consts.py +97 -101
- warp/tests/test_conditional.py +246 -246
- warp/tests/test_copy.py +232 -215
- warp/tests/test_ctypes.py +632 -632
- warp/tests/test_dense.py +67 -67
- warp/tests/test_devices.py +91 -98
- warp/tests/test_dlpack.py +530 -529
- warp/tests/test_examples.py +400 -378
- warp/tests/test_fabricarray.py +955 -955
- warp/tests/test_fast_math.py +62 -54
- warp/tests/test_fem.py +1277 -1278
- warp/tests/test_fp16.py +130 -130
- warp/tests/test_func.py +338 -337
- warp/tests/test_generics.py +571 -571
- warp/tests/test_grad.py +746 -640
- warp/tests/test_grad_customs.py +333 -336
- warp/tests/test_hash_grid.py +210 -164
- warp/tests/test_import.py +39 -39
- warp/tests/test_indexedarray.py +1134 -1134
- warp/tests/test_intersect.py +67 -67
- warp/tests/test_jax.py +307 -307
- warp/tests/test_large.py +167 -164
- warp/tests/test_launch.py +354 -354
- warp/tests/test_lerp.py +261 -261
- warp/tests/test_linear_solvers.py +191 -171
- warp/tests/test_lvalue.py +421 -493
- warp/tests/test_marching_cubes.py +65 -65
- warp/tests/test_mat.py +1801 -1827
- warp/tests/test_mat_lite.py +115 -115
- warp/tests/test_mat_scalar_ops.py +2907 -2889
- warp/tests/test_math.py +126 -193
- warp/tests/test_matmul.py +500 -499
- warp/tests/test_matmul_lite.py +410 -410
- warp/tests/test_mempool.py +188 -190
- warp/tests/test_mesh.py +284 -324
- warp/tests/test_mesh_query_aabb.py +228 -241
- warp/tests/test_mesh_query_point.py +692 -702
- warp/tests/test_mesh_query_ray.py +292 -303
- warp/tests/test_mlp.py +276 -276
- warp/tests/test_model.py +110 -110
- warp/tests/test_modules_lite.py +39 -39
- warp/tests/test_multigpu.py +163 -163
- warp/tests/test_noise.py +248 -248
- warp/tests/test_operators.py +250 -250
- warp/tests/test_options.py +123 -125
- warp/tests/test_peer.py +133 -137
- warp/tests/test_pinned.py +78 -78
- warp/tests/test_print.py +54 -54
- warp/tests/test_quat.py +2086 -2086
- warp/tests/test_rand.py +288 -288
- warp/tests/test_reload.py +217 -217
- warp/tests/test_rounding.py +179 -179
- warp/tests/test_runlength_encode.py +190 -190
- warp/tests/test_sim_grad.py +243 -0
- warp/tests/test_sim_kinematics.py +91 -97
- warp/tests/test_smoothstep.py +168 -168
- warp/tests/test_snippet.py +305 -266
- warp/tests/test_sparse.py +468 -460
- warp/tests/test_spatial.py +2148 -2148
- warp/tests/test_streams.py +486 -473
- warp/tests/test_struct.py +710 -675
- warp/tests/test_tape.py +173 -148
- warp/tests/test_torch.py +743 -743
- warp/tests/test_transient_module.py +87 -87
- warp/tests/test_types.py +556 -659
- warp/tests/test_utils.py +490 -499
- warp/tests/test_vec.py +1264 -1268
- warp/tests/test_vec_lite.py +73 -73
- warp/tests/test_vec_scalar_ops.py +2099 -2099
- warp/tests/test_verify_fp.py +94 -94
- warp/tests/test_volume.py +737 -736
- warp/tests/test_volume_write.py +255 -265
- warp/tests/unittest_serial.py +37 -37
- warp/tests/unittest_suites.py +363 -359
- warp/tests/unittest_utils.py +603 -578
- warp/tests/unused_test_misc.py +71 -71
- warp/tests/walkthrough_debug.py +85 -85
- warp/thirdparty/appdirs.py +598 -598
- warp/thirdparty/dlpack.py +143 -143
- warp/thirdparty/unittest_parallel.py +566 -561
- warp/torch.py +321 -295
- warp/types.py +4504 -4450
- warp/utils.py +1008 -821
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/LICENSE.md +126 -126
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/METADATA +338 -400
- warp_lang-1.1.0.dist-info/RECORD +352 -0
- warp/examples/assets/cube.usda +0 -42
- warp/examples/assets/sphere.usda +0 -56
- warp/examples/assets/torus.usda +0 -105
- warp_lang-1.0.2.dist-info/RECORD +0 -352
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.0.2.dist-info → warp_lang-1.1.0.dist-info}/top_level.txt +0 -0
warp/optim/__init__.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
-
# and proprietary rights in and to this software, related documentation
|
|
4
|
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
-
# distribution of this software and related documentation without an express
|
|
6
|
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
-
|
|
8
|
-
from .adam import Adam
|
|
9
|
-
from .sgd import SGD
|
|
1
|
+
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
# and proprietary rights in and to this software, related documentation
|
|
4
|
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
# distribution of this software and related documentation without an express
|
|
6
|
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
|
|
8
|
+
from .adam import Adam
|
|
9
|
+
from .sgd import SGD
|
warp/optim/adam.py
CHANGED
|
@@ -1,120 +1,120 @@
|
|
|
1
|
-
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
-
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
-
# and proprietary rights in and to this software, related documentation
|
|
4
|
-
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
-
# distribution of this software and related documentation without an express
|
|
6
|
-
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
-
|
|
8
|
-
import warp as wp
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
@wp.kernel
|
|
12
|
-
def adam_step_kernel_vec3(
|
|
13
|
-
g: wp.array(dtype=wp.vec3),
|
|
14
|
-
m: wp.array(dtype=wp.vec3),
|
|
15
|
-
v: wp.array(dtype=wp.vec3),
|
|
16
|
-
lr: float,
|
|
17
|
-
beta1: float,
|
|
18
|
-
beta2: float,
|
|
19
|
-
t: float,
|
|
20
|
-
eps: float,
|
|
21
|
-
params: wp.array(dtype=wp.vec3),
|
|
22
|
-
):
|
|
23
|
-
i = wp.tid()
|
|
24
|
-
m[i] = beta1 * m[i] + (1.0 - beta1) * g[i]
|
|
25
|
-
v[i] = beta2 * v[i] + (1.0 - beta2) * wp.cw_mul(g[i], g[i])
|
|
26
|
-
mhat = m[i] / (1.0 - wp.pow(beta1, (t + 1.0)))
|
|
27
|
-
vhat = v[i] / (1.0 - wp.pow(beta2, (t + 1.0)))
|
|
28
|
-
sqrt_vhat = wp.vec3(wp.sqrt(vhat[0]), wp.sqrt(vhat[1]), wp.sqrt(vhat[2]))
|
|
29
|
-
eps_vec3 = wp.vec3(eps, eps, eps)
|
|
30
|
-
params[i] = params[i] - lr * wp.cw_div(mhat, (sqrt_vhat + eps_vec3))
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
@wp.kernel
|
|
34
|
-
def adam_step_kernel_float(
|
|
35
|
-
g: wp.array(dtype=float),
|
|
36
|
-
m: wp.array(dtype=float),
|
|
37
|
-
v: wp.array(dtype=float),
|
|
38
|
-
lr: float,
|
|
39
|
-
beta1: float,
|
|
40
|
-
beta2: float,
|
|
41
|
-
t: float,
|
|
42
|
-
eps: float,
|
|
43
|
-
params: wp.array(dtype=float),
|
|
44
|
-
):
|
|
45
|
-
i = wp.tid()
|
|
46
|
-
m[i] = beta1 * m[i] + (1.0 - beta1) * g[i]
|
|
47
|
-
v[i] = beta2 * v[i] + (1.0 - beta2) * g[i] * g[i]
|
|
48
|
-
mhat = m[i] / (1.0 - wp.pow(beta1, (t + 1.0)))
|
|
49
|
-
vhat = v[i] / (1.0 - wp.pow(beta2, (t + 1.0)))
|
|
50
|
-
params[i] = params[i] - lr * mhat / (wp.sqrt(vhat) + eps)
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
class Adam:
|
|
54
|
-
"""An implementation of the Adam Optimizer
|
|
55
|
-
It is designed to mimic Pytorch's version.
|
|
56
|
-
https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
|
|
57
|
-
"""
|
|
58
|
-
|
|
59
|
-
def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08):
|
|
60
|
-
self.m = [] # first moment
|
|
61
|
-
self.v = [] # second moment
|
|
62
|
-
self.set_params(params)
|
|
63
|
-
self.lr = lr
|
|
64
|
-
self.beta1 = betas[0]
|
|
65
|
-
self.beta2 = betas[1]
|
|
66
|
-
self.eps = eps
|
|
67
|
-
self.t = 0
|
|
68
|
-
|
|
69
|
-
def set_params(self, params):
|
|
70
|
-
self.params = params
|
|
71
|
-
if params
|
|
72
|
-
if len(self.m) != len(params):
|
|
73
|
-
self.m = [None] * len(params) # reset first moment
|
|
74
|
-
if len(self.v) != len(params):
|
|
75
|
-
self.v = [None] * len(params) # reset second moment
|
|
76
|
-
for i in range(len(params)):
|
|
77
|
-
param = params[i]
|
|
78
|
-
if self.m[i]
|
|
79
|
-
self.m[i] = wp.zeros_like(param)
|
|
80
|
-
if self.v[i]
|
|
81
|
-
self.v[i] = wp.zeros_like(param)
|
|
82
|
-
|
|
83
|
-
def reset_internal_state(self):
|
|
84
|
-
for m_i in self.m:
|
|
85
|
-
m_i.zero_()
|
|
86
|
-
for v_i in self.v:
|
|
87
|
-
v_i.zero_()
|
|
88
|
-
self.t = 0
|
|
89
|
-
|
|
90
|
-
def step(self, grad):
|
|
91
|
-
assert self.params
|
|
92
|
-
for i in range(len(self.params)):
|
|
93
|
-
Adam.step_detail(
|
|
94
|
-
grad[i], self.m[i], self.v[i], self.lr, self.beta1, self.beta2, self.t, self.eps, self.params[i]
|
|
95
|
-
)
|
|
96
|
-
self.t = self.t + 1
|
|
97
|
-
|
|
98
|
-
@staticmethod
|
|
99
|
-
def step_detail(g, m, v, lr, beta1, beta2, t, eps, params):
|
|
100
|
-
assert params.dtype == g.dtype
|
|
101
|
-
assert params.dtype == m.dtype
|
|
102
|
-
assert params.dtype == v.dtype
|
|
103
|
-
assert params.shape == g.shape
|
|
104
|
-
kernel_inputs = [g, m, v, lr, beta1, beta2, t, eps, params]
|
|
105
|
-
if params.dtype == wp.types.float32:
|
|
106
|
-
wp.launch(
|
|
107
|
-
kernel=adam_step_kernel_float,
|
|
108
|
-
dim=len(params),
|
|
109
|
-
inputs=kernel_inputs,
|
|
110
|
-
device=params.device,
|
|
111
|
-
)
|
|
112
|
-
elif params.dtype == wp.types.vec3:
|
|
113
|
-
wp.launch(
|
|
114
|
-
kernel=adam_step_kernel_vec3,
|
|
115
|
-
dim=len(params),
|
|
116
|
-
inputs=kernel_inputs,
|
|
117
|
-
device=params.device,
|
|
118
|
-
)
|
|
119
|
-
else:
|
|
120
|
-
raise RuntimeError("Params data type not supported in Adam step kernels.")
|
|
1
|
+
# Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
|
|
2
|
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
|
3
|
+
# and proprietary rights in and to this software, related documentation
|
|
4
|
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
|
5
|
+
# distribution of this software and related documentation without an express
|
|
6
|
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
|
7
|
+
|
|
8
|
+
import warp as wp
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@wp.kernel
|
|
12
|
+
def adam_step_kernel_vec3(
|
|
13
|
+
g: wp.array(dtype=wp.vec3),
|
|
14
|
+
m: wp.array(dtype=wp.vec3),
|
|
15
|
+
v: wp.array(dtype=wp.vec3),
|
|
16
|
+
lr: float,
|
|
17
|
+
beta1: float,
|
|
18
|
+
beta2: float,
|
|
19
|
+
t: float,
|
|
20
|
+
eps: float,
|
|
21
|
+
params: wp.array(dtype=wp.vec3),
|
|
22
|
+
):
|
|
23
|
+
i = wp.tid()
|
|
24
|
+
m[i] = beta1 * m[i] + (1.0 - beta1) * g[i]
|
|
25
|
+
v[i] = beta2 * v[i] + (1.0 - beta2) * wp.cw_mul(g[i], g[i])
|
|
26
|
+
mhat = m[i] / (1.0 - wp.pow(beta1, (t + 1.0)))
|
|
27
|
+
vhat = v[i] / (1.0 - wp.pow(beta2, (t + 1.0)))
|
|
28
|
+
sqrt_vhat = wp.vec3(wp.sqrt(vhat[0]), wp.sqrt(vhat[1]), wp.sqrt(vhat[2]))
|
|
29
|
+
eps_vec3 = wp.vec3(eps, eps, eps)
|
|
30
|
+
params[i] = params[i] - lr * wp.cw_div(mhat, (sqrt_vhat + eps_vec3))
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@wp.kernel
|
|
34
|
+
def adam_step_kernel_float(
|
|
35
|
+
g: wp.array(dtype=float),
|
|
36
|
+
m: wp.array(dtype=float),
|
|
37
|
+
v: wp.array(dtype=float),
|
|
38
|
+
lr: float,
|
|
39
|
+
beta1: float,
|
|
40
|
+
beta2: float,
|
|
41
|
+
t: float,
|
|
42
|
+
eps: float,
|
|
43
|
+
params: wp.array(dtype=float),
|
|
44
|
+
):
|
|
45
|
+
i = wp.tid()
|
|
46
|
+
m[i] = beta1 * m[i] + (1.0 - beta1) * g[i]
|
|
47
|
+
v[i] = beta2 * v[i] + (1.0 - beta2) * g[i] * g[i]
|
|
48
|
+
mhat = m[i] / (1.0 - wp.pow(beta1, (t + 1.0)))
|
|
49
|
+
vhat = v[i] / (1.0 - wp.pow(beta2, (t + 1.0)))
|
|
50
|
+
params[i] = params[i] - lr * mhat / (wp.sqrt(vhat) + eps)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class Adam:
|
|
54
|
+
"""An implementation of the Adam Optimizer
|
|
55
|
+
It is designed to mimic Pytorch's version.
|
|
56
|
+
https://pytorch.org/docs/stable/generated/torch.optim.Adam.html#torch.optim.Adam
|
|
57
|
+
"""
|
|
58
|
+
|
|
59
|
+
def __init__(self, params=None, lr=0.001, betas=(0.9, 0.999), eps=1e-08):
|
|
60
|
+
self.m = [] # first moment
|
|
61
|
+
self.v = [] # second moment
|
|
62
|
+
self.set_params(params)
|
|
63
|
+
self.lr = lr
|
|
64
|
+
self.beta1 = betas[0]
|
|
65
|
+
self.beta2 = betas[1]
|
|
66
|
+
self.eps = eps
|
|
67
|
+
self.t = 0
|
|
68
|
+
|
|
69
|
+
def set_params(self, params):
|
|
70
|
+
self.params = params
|
|
71
|
+
if params is not None and isinstance(params, list) and len(params) > 0:
|
|
72
|
+
if len(self.m) != len(params):
|
|
73
|
+
self.m = [None] * len(params) # reset first moment
|
|
74
|
+
if len(self.v) != len(params):
|
|
75
|
+
self.v = [None] * len(params) # reset second moment
|
|
76
|
+
for i in range(len(params)):
|
|
77
|
+
param = params[i]
|
|
78
|
+
if self.m[i] is None or self.m[i].shape != param.shape or self.m[i].dtype != param.dtype:
|
|
79
|
+
self.m[i] = wp.zeros_like(param)
|
|
80
|
+
if self.v[i] is None or self.v[i].shape != param.shape or self.v[i].dtype != param.dtype:
|
|
81
|
+
self.v[i] = wp.zeros_like(param)
|
|
82
|
+
|
|
83
|
+
def reset_internal_state(self):
|
|
84
|
+
for m_i in self.m:
|
|
85
|
+
m_i.zero_()
|
|
86
|
+
for v_i in self.v:
|
|
87
|
+
v_i.zero_()
|
|
88
|
+
self.t = 0
|
|
89
|
+
|
|
90
|
+
def step(self, grad):
|
|
91
|
+
assert self.params is not None
|
|
92
|
+
for i in range(len(self.params)):
|
|
93
|
+
Adam.step_detail(
|
|
94
|
+
grad[i], self.m[i], self.v[i], self.lr, self.beta1, self.beta2, self.t, self.eps, self.params[i]
|
|
95
|
+
)
|
|
96
|
+
self.t = self.t + 1
|
|
97
|
+
|
|
98
|
+
@staticmethod
|
|
99
|
+
def step_detail(g, m, v, lr, beta1, beta2, t, eps, params):
|
|
100
|
+
assert params.dtype == g.dtype
|
|
101
|
+
assert params.dtype == m.dtype
|
|
102
|
+
assert params.dtype == v.dtype
|
|
103
|
+
assert params.shape == g.shape
|
|
104
|
+
kernel_inputs = [g, m, v, lr, beta1, beta2, t, eps, params]
|
|
105
|
+
if params.dtype == wp.types.float32:
|
|
106
|
+
wp.launch(
|
|
107
|
+
kernel=adam_step_kernel_float,
|
|
108
|
+
dim=len(params),
|
|
109
|
+
inputs=kernel_inputs,
|
|
110
|
+
device=params.device,
|
|
111
|
+
)
|
|
112
|
+
elif params.dtype == wp.types.vec3:
|
|
113
|
+
wp.launch(
|
|
114
|
+
kernel=adam_step_kernel_vec3,
|
|
115
|
+
dim=len(params),
|
|
116
|
+
inputs=kernel_inputs,
|
|
117
|
+
device=params.device,
|
|
118
|
+
)
|
|
119
|
+
else:
|
|
120
|
+
raise RuntimeError("Params data type not supported in Adam step kernels.")
|