warp-lang 1.7.0__py3-none-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (429) hide show
  1. warp/__init__.py +139 -0
  2. warp/__init__.pyi +1 -0
  3. warp/autograd.py +1142 -0
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +557 -0
  7. warp/build_dll.py +405 -0
  8. warp/builtins.py +6855 -0
  9. warp/codegen.py +3969 -0
  10. warp/config.py +158 -0
  11. warp/constants.py +57 -0
  12. warp/context.py +6812 -0
  13. warp/dlpack.py +462 -0
  14. warp/examples/__init__.py +24 -0
  15. warp/examples/assets/bear.usd +0 -0
  16. warp/examples/assets/bunny.usd +0 -0
  17. warp/examples/assets/cartpole.urdf +110 -0
  18. warp/examples/assets/crazyflie.usd +0 -0
  19. warp/examples/assets/cube.usd +0 -0
  20. warp/examples/assets/nonuniform.usd +0 -0
  21. warp/examples/assets/nv_ant.xml +92 -0
  22. warp/examples/assets/nv_humanoid.xml +183 -0
  23. warp/examples/assets/nvidia_logo.png +0 -0
  24. warp/examples/assets/pixel.jpg +0 -0
  25. warp/examples/assets/quadruped.urdf +268 -0
  26. warp/examples/assets/rocks.nvdb +0 -0
  27. warp/examples/assets/rocks.usd +0 -0
  28. warp/examples/assets/sphere.usd +0 -0
  29. warp/examples/assets/square_cloth.usd +0 -0
  30. warp/examples/benchmarks/benchmark_api.py +389 -0
  31. warp/examples/benchmarks/benchmark_cloth.py +296 -0
  32. warp/examples/benchmarks/benchmark_cloth_cupy.py +96 -0
  33. warp/examples/benchmarks/benchmark_cloth_jax.py +105 -0
  34. warp/examples/benchmarks/benchmark_cloth_numba.py +161 -0
  35. warp/examples/benchmarks/benchmark_cloth_numpy.py +85 -0
  36. warp/examples/benchmarks/benchmark_cloth_paddle.py +94 -0
  37. warp/examples/benchmarks/benchmark_cloth_pytorch.py +94 -0
  38. warp/examples/benchmarks/benchmark_cloth_taichi.py +120 -0
  39. warp/examples/benchmarks/benchmark_cloth_warp.py +153 -0
  40. warp/examples/benchmarks/benchmark_gemm.py +164 -0
  41. warp/examples/benchmarks/benchmark_interop_paddle.py +166 -0
  42. warp/examples/benchmarks/benchmark_interop_torch.py +166 -0
  43. warp/examples/benchmarks/benchmark_launches.py +301 -0
  44. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  45. warp/examples/browse.py +37 -0
  46. warp/examples/core/example_cupy.py +86 -0
  47. warp/examples/core/example_dem.py +241 -0
  48. warp/examples/core/example_fluid.py +299 -0
  49. warp/examples/core/example_graph_capture.py +150 -0
  50. warp/examples/core/example_marching_cubes.py +194 -0
  51. warp/examples/core/example_mesh.py +180 -0
  52. warp/examples/core/example_mesh_intersect.py +211 -0
  53. warp/examples/core/example_nvdb.py +182 -0
  54. warp/examples/core/example_raycast.py +111 -0
  55. warp/examples/core/example_raymarch.py +205 -0
  56. warp/examples/core/example_render_opengl.py +193 -0
  57. warp/examples/core/example_sample_mesh.py +300 -0
  58. warp/examples/core/example_sph.py +411 -0
  59. warp/examples/core/example_torch.py +211 -0
  60. warp/examples/core/example_wave.py +269 -0
  61. warp/examples/fem/example_adaptive_grid.py +286 -0
  62. warp/examples/fem/example_apic_fluid.py +423 -0
  63. warp/examples/fem/example_burgers.py +261 -0
  64. warp/examples/fem/example_convection_diffusion.py +178 -0
  65. warp/examples/fem/example_convection_diffusion_dg.py +204 -0
  66. warp/examples/fem/example_deformed_geometry.py +172 -0
  67. warp/examples/fem/example_diffusion.py +196 -0
  68. warp/examples/fem/example_diffusion_3d.py +225 -0
  69. warp/examples/fem/example_diffusion_mgpu.py +220 -0
  70. warp/examples/fem/example_distortion_energy.py +228 -0
  71. warp/examples/fem/example_magnetostatics.py +240 -0
  72. warp/examples/fem/example_mixed_elasticity.py +291 -0
  73. warp/examples/fem/example_navier_stokes.py +261 -0
  74. warp/examples/fem/example_nonconforming_contact.py +298 -0
  75. warp/examples/fem/example_stokes.py +213 -0
  76. warp/examples/fem/example_stokes_transfer.py +262 -0
  77. warp/examples/fem/example_streamlines.py +352 -0
  78. warp/examples/fem/utils.py +1000 -0
  79. warp/examples/interop/example_jax_callable.py +116 -0
  80. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  81. warp/examples/interop/example_jax_kernel.py +205 -0
  82. warp/examples/optim/example_bounce.py +266 -0
  83. warp/examples/optim/example_cloth_throw.py +228 -0
  84. warp/examples/optim/example_diffray.py +561 -0
  85. warp/examples/optim/example_drone.py +870 -0
  86. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  87. warp/examples/optim/example_inverse_kinematics.py +182 -0
  88. warp/examples/optim/example_inverse_kinematics_torch.py +191 -0
  89. warp/examples/optim/example_softbody_properties.py +400 -0
  90. warp/examples/optim/example_spring_cage.py +245 -0
  91. warp/examples/optim/example_trajectory.py +227 -0
  92. warp/examples/sim/example_cartpole.py +143 -0
  93. warp/examples/sim/example_cloth.py +225 -0
  94. warp/examples/sim/example_cloth_self_contact.py +322 -0
  95. warp/examples/sim/example_granular.py +130 -0
  96. warp/examples/sim/example_granular_collision_sdf.py +202 -0
  97. warp/examples/sim/example_jacobian_ik.py +244 -0
  98. warp/examples/sim/example_particle_chain.py +124 -0
  99. warp/examples/sim/example_quadruped.py +203 -0
  100. warp/examples/sim/example_rigid_chain.py +203 -0
  101. warp/examples/sim/example_rigid_contact.py +195 -0
  102. warp/examples/sim/example_rigid_force.py +133 -0
  103. warp/examples/sim/example_rigid_gyroscopic.py +115 -0
  104. warp/examples/sim/example_rigid_soft_contact.py +140 -0
  105. warp/examples/sim/example_soft_body.py +196 -0
  106. warp/examples/tile/example_tile_cholesky.py +87 -0
  107. warp/examples/tile/example_tile_convolution.py +66 -0
  108. warp/examples/tile/example_tile_fft.py +55 -0
  109. warp/examples/tile/example_tile_filtering.py +113 -0
  110. warp/examples/tile/example_tile_matmul.py +85 -0
  111. warp/examples/tile/example_tile_mlp.py +383 -0
  112. warp/examples/tile/example_tile_nbody.py +199 -0
  113. warp/examples/tile/example_tile_walker.py +327 -0
  114. warp/fabric.py +355 -0
  115. warp/fem/__init__.py +106 -0
  116. warp/fem/adaptivity.py +508 -0
  117. warp/fem/cache.py +572 -0
  118. warp/fem/dirichlet.py +202 -0
  119. warp/fem/domain.py +411 -0
  120. warp/fem/field/__init__.py +125 -0
  121. warp/fem/field/field.py +619 -0
  122. warp/fem/field/nodal_field.py +326 -0
  123. warp/fem/field/restriction.py +37 -0
  124. warp/fem/field/virtual.py +848 -0
  125. warp/fem/geometry/__init__.py +32 -0
  126. warp/fem/geometry/adaptive_nanogrid.py +857 -0
  127. warp/fem/geometry/closest_point.py +84 -0
  128. warp/fem/geometry/deformed_geometry.py +221 -0
  129. warp/fem/geometry/element.py +776 -0
  130. warp/fem/geometry/geometry.py +362 -0
  131. warp/fem/geometry/grid_2d.py +392 -0
  132. warp/fem/geometry/grid_3d.py +452 -0
  133. warp/fem/geometry/hexmesh.py +911 -0
  134. warp/fem/geometry/nanogrid.py +571 -0
  135. warp/fem/geometry/partition.py +389 -0
  136. warp/fem/geometry/quadmesh.py +663 -0
  137. warp/fem/geometry/tetmesh.py +855 -0
  138. warp/fem/geometry/trimesh.py +806 -0
  139. warp/fem/integrate.py +2335 -0
  140. warp/fem/linalg.py +419 -0
  141. warp/fem/operator.py +293 -0
  142. warp/fem/polynomial.py +229 -0
  143. warp/fem/quadrature/__init__.py +17 -0
  144. warp/fem/quadrature/pic_quadrature.py +299 -0
  145. warp/fem/quadrature/quadrature.py +591 -0
  146. warp/fem/space/__init__.py +228 -0
  147. warp/fem/space/basis_function_space.py +468 -0
  148. warp/fem/space/basis_space.py +667 -0
  149. warp/fem/space/dof_mapper.py +251 -0
  150. warp/fem/space/function_space.py +309 -0
  151. warp/fem/space/grid_2d_function_space.py +177 -0
  152. warp/fem/space/grid_3d_function_space.py +227 -0
  153. warp/fem/space/hexmesh_function_space.py +257 -0
  154. warp/fem/space/nanogrid_function_space.py +201 -0
  155. warp/fem/space/partition.py +367 -0
  156. warp/fem/space/quadmesh_function_space.py +223 -0
  157. warp/fem/space/restriction.py +179 -0
  158. warp/fem/space/shape/__init__.py +143 -0
  159. warp/fem/space/shape/cube_shape_function.py +1105 -0
  160. warp/fem/space/shape/shape_function.py +133 -0
  161. warp/fem/space/shape/square_shape_function.py +926 -0
  162. warp/fem/space/shape/tet_shape_function.py +834 -0
  163. warp/fem/space/shape/triangle_shape_function.py +672 -0
  164. warp/fem/space/tetmesh_function_space.py +271 -0
  165. warp/fem/space/topology.py +424 -0
  166. warp/fem/space/trimesh_function_space.py +194 -0
  167. warp/fem/types.py +99 -0
  168. warp/fem/utils.py +420 -0
  169. warp/jax.py +187 -0
  170. warp/jax_experimental/__init__.py +16 -0
  171. warp/jax_experimental/custom_call.py +351 -0
  172. warp/jax_experimental/ffi.py +698 -0
  173. warp/jax_experimental/xla_ffi.py +602 -0
  174. warp/math.py +244 -0
  175. warp/native/array.h +1145 -0
  176. warp/native/builtin.h +1800 -0
  177. warp/native/bvh.cpp +492 -0
  178. warp/native/bvh.cu +791 -0
  179. warp/native/bvh.h +554 -0
  180. warp/native/clang/clang.cpp +536 -0
  181. warp/native/coloring.cpp +613 -0
  182. warp/native/crt.cpp +51 -0
  183. warp/native/crt.h +362 -0
  184. warp/native/cuda_crt.h +1058 -0
  185. warp/native/cuda_util.cpp +646 -0
  186. warp/native/cuda_util.h +307 -0
  187. warp/native/error.cpp +77 -0
  188. warp/native/error.h +36 -0
  189. warp/native/exports.h +1878 -0
  190. warp/native/fabric.h +245 -0
  191. warp/native/hashgrid.cpp +311 -0
  192. warp/native/hashgrid.cu +87 -0
  193. warp/native/hashgrid.h +240 -0
  194. warp/native/initializer_array.h +41 -0
  195. warp/native/intersect.h +1230 -0
  196. warp/native/intersect_adj.h +375 -0
  197. warp/native/intersect_tri.h +339 -0
  198. warp/native/marching.cpp +19 -0
  199. warp/native/marching.cu +514 -0
  200. warp/native/marching.h +19 -0
  201. warp/native/mat.h +2220 -0
  202. warp/native/mathdx.cpp +87 -0
  203. warp/native/matnn.h +343 -0
  204. warp/native/mesh.cpp +266 -0
  205. warp/native/mesh.cu +404 -0
  206. warp/native/mesh.h +1980 -0
  207. warp/native/nanovdb/GridHandle.h +366 -0
  208. warp/native/nanovdb/HostBuffer.h +590 -0
  209. warp/native/nanovdb/NanoVDB.h +6624 -0
  210. warp/native/nanovdb/PNanoVDB.h +3390 -0
  211. warp/native/noise.h +859 -0
  212. warp/native/quat.h +1371 -0
  213. warp/native/rand.h +342 -0
  214. warp/native/range.h +139 -0
  215. warp/native/reduce.cpp +174 -0
  216. warp/native/reduce.cu +364 -0
  217. warp/native/runlength_encode.cpp +79 -0
  218. warp/native/runlength_encode.cu +61 -0
  219. warp/native/scan.cpp +47 -0
  220. warp/native/scan.cu +53 -0
  221. warp/native/scan.h +23 -0
  222. warp/native/solid_angle.h +466 -0
  223. warp/native/sort.cpp +251 -0
  224. warp/native/sort.cu +277 -0
  225. warp/native/sort.h +33 -0
  226. warp/native/sparse.cpp +378 -0
  227. warp/native/sparse.cu +524 -0
  228. warp/native/spatial.h +657 -0
  229. warp/native/svd.h +702 -0
  230. warp/native/temp_buffer.h +46 -0
  231. warp/native/tile.h +2584 -0
  232. warp/native/tile_reduce.h +264 -0
  233. warp/native/vec.h +1426 -0
  234. warp/native/volume.cpp +501 -0
  235. warp/native/volume.cu +67 -0
  236. warp/native/volume.h +969 -0
  237. warp/native/volume_builder.cu +477 -0
  238. warp/native/volume_builder.h +52 -0
  239. warp/native/volume_impl.h +70 -0
  240. warp/native/warp.cpp +1082 -0
  241. warp/native/warp.cu +3636 -0
  242. warp/native/warp.h +381 -0
  243. warp/optim/__init__.py +17 -0
  244. warp/optim/adam.py +163 -0
  245. warp/optim/linear.py +1137 -0
  246. warp/optim/sgd.py +112 -0
  247. warp/paddle.py +407 -0
  248. warp/render/__init__.py +18 -0
  249. warp/render/render_opengl.py +3518 -0
  250. warp/render/render_usd.py +784 -0
  251. warp/render/utils.py +160 -0
  252. warp/sim/__init__.py +65 -0
  253. warp/sim/articulation.py +793 -0
  254. warp/sim/collide.py +2395 -0
  255. warp/sim/graph_coloring.py +300 -0
  256. warp/sim/import_mjcf.py +790 -0
  257. warp/sim/import_snu.py +227 -0
  258. warp/sim/import_urdf.py +579 -0
  259. warp/sim/import_usd.py +894 -0
  260. warp/sim/inertia.py +324 -0
  261. warp/sim/integrator.py +242 -0
  262. warp/sim/integrator_euler.py +1997 -0
  263. warp/sim/integrator_featherstone.py +2101 -0
  264. warp/sim/integrator_vbd.py +2048 -0
  265. warp/sim/integrator_xpbd.py +3292 -0
  266. warp/sim/model.py +4791 -0
  267. warp/sim/particles.py +121 -0
  268. warp/sim/render.py +427 -0
  269. warp/sim/utils.py +428 -0
  270. warp/sparse.py +2057 -0
  271. warp/stubs.py +3333 -0
  272. warp/tape.py +1203 -0
  273. warp/tests/__init__.py +1 -0
  274. warp/tests/__main__.py +4 -0
  275. warp/tests/assets/curlnoise_golden.npy +0 -0
  276. warp/tests/assets/mlp_golden.npy +0 -0
  277. warp/tests/assets/pixel.npy +0 -0
  278. warp/tests/assets/pnoise_golden.npy +0 -0
  279. warp/tests/assets/spiky.usd +0 -0
  280. warp/tests/assets/test_grid.nvdb +0 -0
  281. warp/tests/assets/test_index_grid.nvdb +0 -0
  282. warp/tests/assets/test_int32_grid.nvdb +0 -0
  283. warp/tests/assets/test_vec_grid.nvdb +0 -0
  284. warp/tests/assets/torus.nvdb +0 -0
  285. warp/tests/assets/torus.usda +105 -0
  286. warp/tests/aux_test_class_kernel.py +34 -0
  287. warp/tests/aux_test_compile_consts_dummy.py +18 -0
  288. warp/tests/aux_test_conditional_unequal_types_kernels.py +29 -0
  289. warp/tests/aux_test_dependent.py +29 -0
  290. warp/tests/aux_test_grad_customs.py +29 -0
  291. warp/tests/aux_test_instancing_gc.py +26 -0
  292. warp/tests/aux_test_module_unload.py +23 -0
  293. warp/tests/aux_test_name_clash1.py +40 -0
  294. warp/tests/aux_test_name_clash2.py +40 -0
  295. warp/tests/aux_test_reference.py +9 -0
  296. warp/tests/aux_test_reference_reference.py +8 -0
  297. warp/tests/aux_test_square.py +16 -0
  298. warp/tests/aux_test_unresolved_func.py +22 -0
  299. warp/tests/aux_test_unresolved_symbol.py +22 -0
  300. warp/tests/cuda/__init__.py +0 -0
  301. warp/tests/cuda/test_async.py +676 -0
  302. warp/tests/cuda/test_ipc.py +124 -0
  303. warp/tests/cuda/test_mempool.py +233 -0
  304. warp/tests/cuda/test_multigpu.py +169 -0
  305. warp/tests/cuda/test_peer.py +139 -0
  306. warp/tests/cuda/test_pinned.py +84 -0
  307. warp/tests/cuda/test_streams.py +634 -0
  308. warp/tests/geometry/__init__.py +0 -0
  309. warp/tests/geometry/test_bvh.py +200 -0
  310. warp/tests/geometry/test_hash_grid.py +221 -0
  311. warp/tests/geometry/test_marching_cubes.py +74 -0
  312. warp/tests/geometry/test_mesh.py +316 -0
  313. warp/tests/geometry/test_mesh_query_aabb.py +399 -0
  314. warp/tests/geometry/test_mesh_query_point.py +932 -0
  315. warp/tests/geometry/test_mesh_query_ray.py +311 -0
  316. warp/tests/geometry/test_volume.py +1103 -0
  317. warp/tests/geometry/test_volume_write.py +346 -0
  318. warp/tests/interop/__init__.py +0 -0
  319. warp/tests/interop/test_dlpack.py +729 -0
  320. warp/tests/interop/test_jax.py +371 -0
  321. warp/tests/interop/test_paddle.py +800 -0
  322. warp/tests/interop/test_torch.py +1001 -0
  323. warp/tests/run_coverage_serial.py +39 -0
  324. warp/tests/sim/__init__.py +0 -0
  325. warp/tests/sim/disabled_kinematics.py +244 -0
  326. warp/tests/sim/flaky_test_sim_grad.py +290 -0
  327. warp/tests/sim/test_collision.py +604 -0
  328. warp/tests/sim/test_coloring.py +258 -0
  329. warp/tests/sim/test_model.py +224 -0
  330. warp/tests/sim/test_sim_grad_bounce_linear.py +212 -0
  331. warp/tests/sim/test_sim_kinematics.py +98 -0
  332. warp/tests/sim/test_vbd.py +597 -0
  333. warp/tests/test_adam.py +163 -0
  334. warp/tests/test_arithmetic.py +1096 -0
  335. warp/tests/test_array.py +2972 -0
  336. warp/tests/test_array_reduce.py +156 -0
  337. warp/tests/test_assert.py +250 -0
  338. warp/tests/test_atomic.py +153 -0
  339. warp/tests/test_bool.py +220 -0
  340. warp/tests/test_builtins_resolution.py +1298 -0
  341. warp/tests/test_closest_point_edge_edge.py +327 -0
  342. warp/tests/test_codegen.py +810 -0
  343. warp/tests/test_codegen_instancing.py +1495 -0
  344. warp/tests/test_compile_consts.py +215 -0
  345. warp/tests/test_conditional.py +252 -0
  346. warp/tests/test_context.py +42 -0
  347. warp/tests/test_copy.py +238 -0
  348. warp/tests/test_ctypes.py +638 -0
  349. warp/tests/test_dense.py +73 -0
  350. warp/tests/test_devices.py +97 -0
  351. warp/tests/test_examples.py +482 -0
  352. warp/tests/test_fabricarray.py +996 -0
  353. warp/tests/test_fast_math.py +74 -0
  354. warp/tests/test_fem.py +2003 -0
  355. warp/tests/test_fp16.py +136 -0
  356. warp/tests/test_func.py +454 -0
  357. warp/tests/test_future_annotations.py +98 -0
  358. warp/tests/test_generics.py +656 -0
  359. warp/tests/test_grad.py +893 -0
  360. warp/tests/test_grad_customs.py +339 -0
  361. warp/tests/test_grad_debug.py +341 -0
  362. warp/tests/test_implicit_init.py +411 -0
  363. warp/tests/test_import.py +45 -0
  364. warp/tests/test_indexedarray.py +1140 -0
  365. warp/tests/test_intersect.py +73 -0
  366. warp/tests/test_iter.py +76 -0
  367. warp/tests/test_large.py +177 -0
  368. warp/tests/test_launch.py +411 -0
  369. warp/tests/test_lerp.py +151 -0
  370. warp/tests/test_linear_solvers.py +193 -0
  371. warp/tests/test_lvalue.py +427 -0
  372. warp/tests/test_mat.py +2089 -0
  373. warp/tests/test_mat_lite.py +122 -0
  374. warp/tests/test_mat_scalar_ops.py +2913 -0
  375. warp/tests/test_math.py +178 -0
  376. warp/tests/test_mlp.py +282 -0
  377. warp/tests/test_module_hashing.py +258 -0
  378. warp/tests/test_modules_lite.py +44 -0
  379. warp/tests/test_noise.py +252 -0
  380. warp/tests/test_operators.py +299 -0
  381. warp/tests/test_options.py +129 -0
  382. warp/tests/test_overwrite.py +551 -0
  383. warp/tests/test_print.py +339 -0
  384. warp/tests/test_quat.py +2315 -0
  385. warp/tests/test_rand.py +339 -0
  386. warp/tests/test_reload.py +302 -0
  387. warp/tests/test_rounding.py +185 -0
  388. warp/tests/test_runlength_encode.py +196 -0
  389. warp/tests/test_scalar_ops.py +105 -0
  390. warp/tests/test_smoothstep.py +108 -0
  391. warp/tests/test_snippet.py +318 -0
  392. warp/tests/test_sparse.py +582 -0
  393. warp/tests/test_spatial.py +2229 -0
  394. warp/tests/test_special_values.py +361 -0
  395. warp/tests/test_static.py +592 -0
  396. warp/tests/test_struct.py +734 -0
  397. warp/tests/test_tape.py +204 -0
  398. warp/tests/test_transient_module.py +93 -0
  399. warp/tests/test_triangle_closest_point.py +145 -0
  400. warp/tests/test_types.py +562 -0
  401. warp/tests/test_utils.py +588 -0
  402. warp/tests/test_vec.py +1487 -0
  403. warp/tests/test_vec_lite.py +80 -0
  404. warp/tests/test_vec_scalar_ops.py +2327 -0
  405. warp/tests/test_verify_fp.py +100 -0
  406. warp/tests/tile/__init__.py +0 -0
  407. warp/tests/tile/test_tile.py +780 -0
  408. warp/tests/tile/test_tile_load.py +407 -0
  409. warp/tests/tile/test_tile_mathdx.py +208 -0
  410. warp/tests/tile/test_tile_mlp.py +402 -0
  411. warp/tests/tile/test_tile_reduce.py +447 -0
  412. warp/tests/tile/test_tile_shared_memory.py +247 -0
  413. warp/tests/tile/test_tile_view.py +173 -0
  414. warp/tests/unittest_serial.py +47 -0
  415. warp/tests/unittest_suites.py +427 -0
  416. warp/tests/unittest_utils.py +468 -0
  417. warp/tests/walkthrough_debug.py +93 -0
  418. warp/thirdparty/__init__.py +0 -0
  419. warp/thirdparty/appdirs.py +598 -0
  420. warp/thirdparty/dlpack.py +145 -0
  421. warp/thirdparty/unittest_parallel.py +570 -0
  422. warp/torch.py +391 -0
  423. warp/types.py +5230 -0
  424. warp/utils.py +1137 -0
  425. warp_lang-1.7.0.dist-info/METADATA +516 -0
  426. warp_lang-1.7.0.dist-info/RECORD +429 -0
  427. warp_lang-1.7.0.dist-info/WHEEL +5 -0
  428. warp_lang-1.7.0.dist-info/licenses/LICENSE.md +202 -0
  429. warp_lang-1.7.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,66 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Tile Convolution
18
+ #
19
+ # Shows how to write a simple convolution kernel using Warp FFT tile
20
+ # primitives.
21
+ #
22
+ ###########################################################################
23
+
24
+ import numpy as np
25
+
26
+ import warp as wp
27
+
# The example never runs a backward pass, so the module's `enable_backward`
# option is turned off.
wp.set_module_options({"enable_backward": False})

# BLOCK_DIM is handed to wp.launch_tiled() as `block_dim`; (TILE_M, TILE_N) is
# the shape of the tile that the kernel loads.
BLOCK_DIM = 64
TILE_M = 1
TILE_N = 128

# Per-component factor 1/TILE_N used by `filter`. The kernel's FFT -> IFFT
# round trip is left scaled by TILE_N, which this factor cancels.
_inv_tile_n = wp.float64(1 / TILE_N)
scale = wp.vec2d(_inv_tile_n, _inv_tile_n)
35
+
36
+
# Frequency-domain filter applied to every tile element: multiplies the
# (real, imag) pair component-wise by the module-level `scale`.
@wp.func
def filter(x: wp.vec2d):
    scaled = wp.cw_mul(x, scale)
    return scaled
40
+
41
+
# Loads one (TILE_M, TILE_N) tile of complex samples from `x`, transforms it
# with a forward FFT, applies `filter` to each element, transforms back with an
# inverse FFT and writes the result to `y`.
@wp.kernel
def conv_tiled(x: wp.array2d(dtype=wp.vec2d), y: wp.array2d(dtype=wp.vec2d)):
    # Thread indices are queried but not needed: the whole tile is processed
    # cooperatively by the block.
    i, j, _ = wp.tid()

    spectrum = wp.tile_load(x, shape=(TILE_M, TILE_N))
    # tile_fft / tile_ifft operate on the tile in place
    wp.tile_fft(spectrum)

    filtered = wp.tile_map(filter, spectrum)
    wp.tile_ifft(filtered)

    wp.tile_store(y, filtered)
50
+
51
+
if __name__ == "__main__":
    wp.set_device("cuda:0")

    # Random complex input; the trailing axis of length 2 holds (real, imag).
    rng = np.random.default_rng(42)
    x_h = rng.standard_normal((TILE_M, TILE_N, 2), dtype=np.float64)
    y_h = np.zeros_like(x_h)

    # Reinterpret the trailing axis as vec2d elements on the device.
    x_wp = wp.array2d(x_h, dtype=wp.vec2d)
    y_wp = wp.array2d(y_h, dtype=wp.vec2d)

    # A single block processes the single tile.
    wp.launch_tiled(
        conv_tiled,
        dim=[1, 1],
        inputs=[x_wp],
        outputs=[y_wp],
        block_dim=BLOCK_DIM,
    )

    # Since filter is 1/N, conv_tiled is a ~no-op
    assert np.allclose(x_h, y_wp.numpy())
@@ -0,0 +1,55 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Tile FFT
18
+ #
19
+ # Shows how to write a simple FFT kernel using Warp tile primitives.
20
+ #
21
+ ###########################################################################
22
+
23
+ import numpy as np
24
+
25
+ import warp as wp
26
+
# The example never runs a backward pass, so the module's `enable_backward`
# option is turned off.
wp.set_module_options({"enable_backward": False})

# BLOCK_DIM is handed to wp.launch_tiled() as `block_dim`; (TILE_M, TILE_N) is
# the shape of the tile that the kernel loads.
BLOCK_DIM = 8
TILE_M = 1
TILE_N = 32
32
+
33
+
# Forward FFT of a single tile.
#
# Loads one (TILE_M, TILE_N) tile of complex samples (stored as vec2d
# real/imag pairs) from `x`, applies the forward FFT to it and stores the
# spectrum in `y`.
#
# The kernel must NOT apply wp.tile_ifft() afterwards: doing so turns it into an
# FFT/IFFT round trip and the driver below would no longer print the spectrum
# it documents ([32+0i, 0, 0, ...] for an all-ones input of length 32).
@wp.kernel
def fft_tiled(x: wp.array2d(dtype=wp.vec2d), y: wp.array2d(dtype=wp.vec2d)):
    # Thread indices are queried but not needed: the whole tile is processed
    # cooperatively by the block.
    i, j, _ = wp.tid()

    a = wp.tile_load(x, shape=(TILE_M, TILE_N))
    # tile_fft transforms the tile in place
    wp.tile_fft(a)
    wp.tile_store(y, a)
41
+
42
+
if __name__ == "__main__":
    wp.set_device("cuda:0")

    # Input signal: every sample is 1 + 0i (trailing axis holds real, imag).
    x_h = np.ones((TILE_M, TILE_N, 2), dtype=np.float64)
    x_h[:, :, 1] = 0

    # Output buffer starts filled with 3s, a value the kernel does not produce
    # for this input, so stale data is easy to spot.
    y_h = 3 * np.ones((TILE_M, TILE_N, 2), dtype=np.float64)

    x_wp = wp.array2d(x_h, dtype=wp.vec2d)
    y_wp = wp.array2d(y_h, dtype=wp.vec2d)

    # A single block processes the single tile.
    wp.launch_tiled(
        fft_tiled,
        dim=[1, 1],
        inputs=[x_wp],
        outputs=[y_wp],
        block_dim=BLOCK_DIM,
    )

    print("Inputs:\n", x_wp)  # [1+0i, 1+0i, 1+0i, ...]
    print("Output:\n", y_wp)  # [32+0i, 0, 0, ...]
@@ -0,0 +1,113 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Tile Filtering
18
+ #
19
+ # Shows how to write a simple filtering kernel using Warp FFT tile
20
+ # primitives.
21
+ #
22
+ ###########################################################################
23
+
24
+ import numpy as np
25
+
26
+ import warp as wp
27
+
# The example never runs a backward pass, so the module's `enable_backward`
# option is turned off.
wp.set_module_options({"enable_backward": False})

# BLOCK_DIM is handed to wp.launch_tiled() as `block_dim`; (TILE_M, TILE_N) is
# the shape of the tiles that the kernel loads.
BLOCK_DIM = 128
TILE_M = 1
TILE_N = 512

# Per-component factor 1/TILE_N folded into `cplx_prod` so that the kernel's
# output matches NumPy's normalized ifft in the check below.
_inv_tile_n = wp.float64(1 / TILE_N)
scale = wp.vec2d(_inv_tile_n, _inv_tile_n)
35
+
36
+
def cplx(array):
    """Combine (real, imag) pairs stored on the last axis into complex values.

    Args:
        array: Array whose last axis has the real part at index 0 and the
            imaginary part at index 1.

    Returns:
        A complex array with the last axis of ``array`` removed.
    """
    real_part = array[..., 0]
    imag_part = array[..., 1]
    return real_part + 1j * imag_part
39
+
40
+
# Complex product of two (real, imag) pairs, scaled component-wise by the
# module-level `scale`:
#   (a + bi)(c + di) = (ac - bd) + (ad + bc)i
@wp.func
def cplx_prod(x: wp.vec2d, y: wp.vec2d):
    real_part = x[0] * y[0] - x[1] * y[1]
    imag_part = x[0] * y[1] + x[1] * y[0]
    return wp.cw_mul(wp.vec2d(real_part, imag_part), scale)
44
+
45
+
# Filters one tile in the frequency domain.
#
# `x` holds the signal and `y` holds the filter; both are read as
# (TILE_M, TILE_N) tiles of complex (vec2d) values. Only the signal tile is
# transformed: the filter tile is multiplied against the signal's spectrum as
# loaded, so it is expected to be given in the frequency domain already. The
# product is transformed back with an inverse FFT and stored in `z`.
@wp.kernel
def conv_tiled(x: wp.array2d(dtype=wp.vec2d), y: wp.array2d(dtype=wp.vec2d), z: wp.array2d(dtype=wp.vec2d)):
    # Thread indices are queried but not needed: the whole tile is processed
    # cooperatively by the block.
    i, j, _ = wp.tid()

    signal = wp.tile_load(x, shape=(TILE_M, TILE_N))
    response = wp.tile_load(y, shape=(TILE_M, TILE_N))

    # tile_fft / tile_ifft operate on the tile in place
    wp.tile_fft(signal)
    filtered = wp.tile_map(cplx_prod, signal, response)
    wp.tile_ifft(filtered)

    wp.tile_store(z, filtered)
55
+
56
+
if __name__ == "__main__":
    rng = np.random.default_rng(42)

    # Noisy input: one period of a sine wave plus uniform noise in [0, 0.5)
    t = np.linspace(0, 2 * np.pi, TILE_N, dtype=np.float64)
    x = np.sin(t) + 0.5 * rng.random(TILE_N, dtype=np.float64)

    # Low-pass filter defined directly in the frequency domain: bins with
    # |freq| <= 0.05 are kept, all others are zeroed. That is roughly 10% of
    # the spectrum, centred on the zero frequency.
    f = np.ones_like(x)
    freq = np.fft.fftfreq(TILE_N)
    f[np.abs(freq) > 0.05] = 0.0
    f[np.abs(freq) <= 0.05] = 1.0

    # Warp input data: vec2d elements hold complex numbers as (real, imag),
    # so the host arrays get a trailing axis of length 2.
    x_h = np.zeros((TILE_M, TILE_N, 2), dtype=np.float64)
    f_h = np.zeros_like(x_h)
    y_h = np.zeros_like(f_h)

    # Signal and filter are purely real; the imaginary parts stay zero.
    x_h[:, :, 0] = x
    f_h[:, :, 0] = f

    x_wp = wp.array2d(x_h, dtype=wp.vec2d)
    f_wp = wp.array2d(f_h, dtype=wp.vec2d)
    y_wp = wp.array2d(y_h, dtype=wp.vec2d)

    # A single block processes the single tile.
    wp.launch_tiled(
        conv_tiled,
        dim=[1, 1],
        inputs=[x_wp, f_wp],
        outputs=[y_wp],
        block_dim=BLOCK_DIM,
    )

    # Compare the kernel output against the same pipeline computed with NumPy
    x_np = cplx(x_h)
    f_np = cplx(f_h)
    y_test = cplx(y_wp.numpy())
    y_ref = np.fft.ifft(f_np * np.fft.fft(x_np))
    assert np.allclose(y_ref, y_test)

    # Plotting is optional: skip it when matplotlib is not installed.
    try:
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots(figsize=(10, 5))

        ax.plot(x, color="#DDDDDD", linewidth=2, label="Original")
        ax.plot(y_test[0, :].real, color="#76B900", linewidth=3, label="Smoothed")

        ax.legend()
        ax.grid(True)

        plt.tight_layout()
        plt.show()

    except ModuleNotFoundError:
        print("Matplotlib not available; skipping figure")
@@ -0,0 +1,85 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Tile MatMul
18
+ #
19
+ # Shows how to write a simple GEMM kernel using Warp tile primitives.
20
+ #
21
+ ###########################################################################
22
+
23
+ import numpy as np
24
+
25
+ import warp as wp
26
+
# Tile sizes used by the kernel: it loads (TILE_M, TILE_K) tiles of A and
# (TILE_K, TILE_N) tiles of B, and accumulates a (TILE_M, TILE_N) tile of C.
TILE_M = wp.constant(8)
TILE_N = wp.constant(4)
TILE_K = wp.constant(8)

# Number of threads per tile; passed to wp.launch_tiled() as `block_dim`.
TILE_THREADS = 64
34
+
35
+
# Tiled matrix multiply C = A @ B with mixed input precision.
#
# Each launched tile computes one (TILE_M, TILE_N) block of C: it walks the
# shared K dimension in steps of TILE_K, loads the matching tiles of A and B,
# and accumulates their product into a float64 tile that is stored once at the
# end. Only int(K / TILE_K) full steps are taken, so a K remainder that does not
# fill a tile is not accumulated.
@wp.kernel
def tile_gemm(A: wp.array2d(dtype=wp.float32), B: wp.array2d(dtype=wp.float16), C: wp.array2d(dtype=wp.float64)):
    # (row, column) index of the output tile handled here
    i, j = wp.tid()

    # float64 accumulator for the output tile
    acc = wp.tile_zeros(shape=(TILE_M, TILE_N), dtype=wp.float64)

    # Problem sizes; only K is needed to drive the loop
    _M = A.shape[0]
    _N = B.shape[1]
    K = A.shape[1]

    num_k_tiles = int(K / TILE_K)

    for k in range(num_k_tiles):
        a_tile = wp.tile_load(A, shape=(TILE_M, TILE_K), offset=(i * TILE_M, k * TILE_K))
        b_tile = wp.tile_load(B, shape=(TILE_K, TILE_N), offset=(k * TILE_K, j * TILE_N))

        # acc += a_tile * b_tile
        wp.tile_matmul(a_tile, b_tile, acc)

    wp.tile_store(C, acc, offset=(i * TILE_M, j * TILE_N))
57
+
58
+
if __name__ == "__main__":
    # Matrix dimensions chosen as exact multiples of the tile sizes
    M = TILE_M * 7
    K = TILE_K * 6
    N = TILE_N * 5

    # Host data: A in float32, B in float16 (drawn as float32, then cast),
    # C as a float64 zero matrix.
    rng = np.random.default_rng(42)
    A = rng.random((M, K), dtype=np.float32)
    B = rng.random((K, N), dtype=np.float32).astype(np.float16)
    C = np.zeros((M, N), dtype=np.float64)

    A_wp = wp.array(A, requires_grad=True)
    B_wp = wp.array(B, requires_grad=True)
    C_wp = wp.array(C, requires_grad=True)

    # One launched tile per (TILE_M, TILE_N) block of the output
    with wp.Tape() as tape:
        wp.launch_tiled(
            tile_gemm,
            dim=(M // TILE_M, N // TILE_N),
            inputs=[A_wp, B_wp],
            outputs=[C_wp],
            block_dim=TILE_THREADS,
        )

    # Check the kernel result against NumPy's matrix product
    assert np.allclose(C_wp.numpy(), A @ B, atol=1.0e-4)

    print("Example matrix multiplication passed")
@@ -0,0 +1,383 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ ###########################################################################
17
+ # Example Image Multilayer Perceptron (MLP)
18
+ #
19
+ # Shows how to train a coordinate-based MLP on an image to predict the RGB
20
+ # color at a given input position. By default, a positional encoding is
21
+ # applied to the input coordinates to improve the ability of the MLP to
22
+ # represent higher-frequency content. Note that this script always applies
23
+ # the encoding; it does not define a '--no_encoding' option.
24
+ #
25
+ # References:
26
+ # Ben Mildenhall et al. 2021. NeRF: representing scenes
27
+ # as neural radiance fields for view synthesis. Commun. ACM 65, 1
28
+ # (January 2022), 99–106. https://doi.org/10.1145/3503250
29
+ #
30
+ ###########################################################################
31
+
32
+ import math
33
+ import os
34
+
35
+ import numpy as np
36
+ from PIL import Image
37
+
38
+ import warp as wp
39
+ import warp.examples
40
+ import warp.optim
41
+
42
+ rng = np.random.default_rng(45)
43
+
44
+
45
def create_layer(dim_in, dim_hid, dtype=float):
    """Create the trainable parameters of one dense layer.

    A ``(dim_hid, dim_in)`` weight matrix and a ``(dim_hid, 1)`` bias column
    are drawn uniformly from ``[-1/sqrt(dim_in), 1/sqrt(dim_in))`` using the
    module-level ``rng`` and wrapped as Warp arrays of ``dtype`` with
    gradients enabled.

    Returns:
        A ``(weights, bias)`` tuple.
    """
    bound = 1.0 / np.sqrt(dim_in)

    # Weights are sampled before the bias so the random stream is consumed
    # in the same order on every run.
    weights = wp.array(rng.uniform(-bound, bound, (dim_hid, dim_in)), dtype=dtype, requires_grad=True)
    bias = wp.array(rng.uniform(-bound, bound, (dim_hid, 1)), dtype=dtype, requires_grad=True)

    return weights, bias
53
+
54
+
55
def create_array(dim_in, dim_hid, dtype=float):
    """Return a ``(dim_hid, dim_in)`` Warp array with gradients enabled.

    Values are drawn uniformly from ``[-1/sqrt(dim_in), 1/sqrt(dim_in))``
    using the module-level ``rng``.
    """
    bound = 1.0 / np.sqrt(dim_in)
    return wp.array(rng.uniform(-bound, bound, (dim_hid, dim_in)), dtype=dtype, requires_grad=True)
60
+
61
+
62
# number of frequencies for the positional encoding
NUM_FREQ = wp.constant(8)

# network layer widths
DIM_IN = wp.constant(4 * NUM_FREQ)  # sin,cos for both x,y at each frequency
DIM_HID = 32
DIM_OUT = 3  # one output per RGB channel

# threads per-block
# Also used as the column count of the tiles built inside the kernel.
NUM_THREADS = 32

# size the reference image is resized to
IMG_WIDTH = 512
IMG_HEIGHT = 512

# pixels per training batch: at most 1024, and at most 1/8 of the image
BATCH_SIZE = min(1024, int((IMG_WIDTH * IMG_HEIGHT) / 8))

# dtype for our weights and bias matrices
dtype = wp.float16
79
+
80
+
81
@wp.func
def relu(x: dtype):
    # Rectified linear unit: max(x, 0) evaluated in the parameter dtype.
    return wp.max(x, dtype(0.0))
84
+
85
+
86
@wp.kernel
def compute(
    indices: wp.array(dtype=int),
    weights_0: wp.array2d(dtype=dtype),
    bias_0: wp.array2d(dtype=dtype),
    weights_1: wp.array2d(dtype=dtype),
    bias_1: wp.array2d(dtype=dtype),
    weights_2: wp.array2d(dtype=dtype),
    bias_2: wp.array2d(dtype=dtype),
    weights_3: wp.array2d(dtype=dtype),
    bias_3: wp.array2d(dtype=dtype),
    reference: wp.array2d(dtype=float),
    loss: wp.array1d(dtype=float),
    out: wp.array2d(dtype=float),
):
    # Evaluates the MLP for one pixel per thread: encodes the pixel position,
    # runs the four layers as tile operations, then accumulates the loss
    # and/or writes the predicted color, depending on which of `loss` and
    # `out` were supplied.

    # batch indices
    linear = indices[wp.tid()]

    # split the linear pixel index into a row and a column
    row = linear / IMG_WIDTH
    col = linear % IMG_WIDTH

    # normalize input coordinates to [-1, 1)
    # NOTE(review): row is scaled by IMG_WIDTH and col by IMG_HEIGHT; with both
    # set to 512 this makes no difference, but confirm the intended pairing
    # before using a non-square image.
    x = (float(row) / float(IMG_WIDTH) - 0.5) * 2.0
    y = (float(col) / float(IMG_HEIGHT) - 0.5) * 2.0

    # per-thread feature vector holding the encoded position
    local = wp.vector(dtype=dtype, length=DIM_IN)

    # construct positional encoding
    for s in range(NUM_FREQ):
        # frequency doubles at every step: pi, 2*pi, 4*pi, ...
        scale = wp.pow(2.0, float(s)) * wp.pi

        # x-coord
        local[s * 4 + 0] = dtype(wp.sin(x * scale))
        local[s * 4 + 1] = dtype(wp.cos(x * scale))
        # y-coord
        local[s * 4 + 2] = dtype(wp.sin(y * scale))
        local[s * 4 + 3] = dtype(wp.cos(y * scale))

    # tile feature vectors across the block, returns [dim(f), NUM_THREADS]
    f = wp.tile(local)

    # input layer
    w0 = wp.tile_load(weights_0, shape=(DIM_HID, DIM_IN))
    b0 = wp.tile_load(bias_0, shape=(DIM_HID, 1))
    z = wp.tile_map(relu, wp.tile_matmul(w0, f) + wp.tile_broadcast(b0, shape=(DIM_HID, NUM_THREADS)))

    # hidden layer
    w1 = wp.tile_load(weights_1, shape=(DIM_HID, DIM_HID))
    b1 = wp.tile_load(bias_1, shape=(DIM_HID, 1))
    z = wp.tile_map(relu, wp.tile_matmul(w1, z) + wp.tile_broadcast(b1, shape=(DIM_HID, NUM_THREADS)))

    # second hidden layer
    w2 = wp.tile_load(weights_2, shape=(DIM_HID, DIM_HID))
    b2 = wp.tile_load(bias_2, shape=(DIM_HID, 1))
    z = wp.tile_map(relu, wp.tile_matmul(w2, z) + wp.tile_broadcast(b2, shape=(DIM_HID, NUM_THREADS)))

    # output layer (relu is applied here as well)
    w3 = wp.tile_load(weights_3, shape=(DIM_OUT, DIM_HID))
    b3 = wp.tile_load(bias_3, shape=(DIM_OUT, 1))
    o = wp.tile_map(relu, wp.tile_matmul(w3, z) + wp.tile_broadcast(b3, shape=(DIM_OUT, NUM_THREADS)))

    # untile back to SIMT
    output = wp.untile(o)

    # compute error
    # reference is indexed as [channel, pixel]
    error = wp.vec3(
        float(output[0]) - reference[0, linear],
        float(output[1]) - reference[1, linear],
        float(output[2]) - reference[2, linear],
    )

    # write MSE loss
    # skipped when no loss array is passed; each thread adds its share of the
    # mean over 3 channels x BATCH_SIZE pixels
    if loss:
        wp.atomic_add(loss, 0, wp.length_sq(error) / float(3 * BATCH_SIZE))

    # write image output
    # skipped when no output array is passed (as in the training launches)
    if out:
        for i in range(DIM_OUT):
            out[i, linear] = float(output[i])
164
+
165
+
166
class Example:
    """Trains a small coordinate-based MLP to reproduce a reference image.

    The network parameters are created as Warp arrays. ``train_warp`` trains
    them with the ``compute`` kernel and Warp's Adam optimizer, while
    ``train_torch`` runs an equivalent PyTorch loop on the same parameters.
    """

    def __init__(self, train_iters):
        """Allocate parameters, load the reference image and plan the batches.

        Args:
            train_iters: Total number of training iterations (batches). It is
                converted into a whole number of epochs, with a minimum of one.
        """
        self.weights_0, self.bias_0 = create_layer(DIM_IN, DIM_HID, dtype=dtype)
        self.weights_1, self.bias_1 = create_layer(DIM_HID, DIM_HID, dtype=dtype)
        self.weights_2, self.bias_2 = create_layer(DIM_HID, DIM_HID, dtype=dtype)
        self.weights_3, self.bias_3 = create_layer(DIM_HID, DIM_OUT, dtype=dtype)

        # reference
        # Loaded as RGB in [0, 1] and stored channel-major as (3, num_pixels).
        reference_path = os.path.join(wp.examples.get_asset_directory(), "pixel.jpg")
        with Image.open(reference_path) as im:
            reference_image = np.asarray(im.resize((IMG_WIDTH, IMG_HEIGHT)).convert("RGB")) / 255.0
        self.reference = wp.array(reference_image.reshape(IMG_WIDTH * IMG_HEIGHT, 3).T, dtype=float)

        # create randomized batch indices
        indices = np.arange(0, IMG_WIDTH * IMG_HEIGHT, dtype=np.int32)
        rng.shuffle(indices)
        self.indices = wp.array(indices)

        self.num_batches = int((IMG_WIDTH * IMG_HEIGHT) / BATCH_SIZE)
        self.max_iters = train_iters
        self.max_epochs = max(1, int(self.max_iters / self.num_batches))

    def train_warp(self):
        """Train with Warp and save the prediction as ``example_tile_mlp.jpg``.

        One full epoch (forward, backward and optimizer step for every batch)
        is captured into a graph once and then replayed ``max_epochs`` times.
        """
        params = [
            self.weights_0,
            self.bias_0,
            self.weights_1,
            self.bias_1,
            self.weights_2,
            self.bias_2,
            self.weights_3,
            self.bias_3,
        ]

        # The optimizer works on flattened parameters and gradients.
        optimizer_grads = [p.grad.flatten() for p in params]
        optimizer_inputs = [p.flatten() for p in params]
        optimizer = warp.optim.Adam(optimizer_inputs, lr=0.01)

        loss = wp.zeros(1, dtype=float, requires_grad=True)
        # (DIM_OUT, num_pixels) buffer that receives the predicted image
        output = create_array(IMG_WIDTH * IMG_HEIGHT, DIM_OUT)

        # capture graph for whole epoch
        wp.capture_begin()

        for b in range(0, IMG_WIDTH * IMG_HEIGHT, BATCH_SIZE):
            loss.zero_()

            with wp.Tape() as tape:
                wp.launch(
                    compute,
                    dim=[BATCH_SIZE],
                    inputs=[
                        self.indices[b : b + BATCH_SIZE],
                        self.weights_0,
                        self.bias_0,
                        self.weights_1,
                        self.bias_1,
                        self.weights_2,
                        self.bias_2,
                        self.weights_3,
                        self.bias_3,
                        self.reference,
                        loss,
                        None,  # no image output while training
                    ],
                    block_dim=NUM_THREADS,
                )

            tape.backward(loss)
            optimizer.step(optimizer_grads)
            tape.zero()

        graph = wp.capture_end()

        with wp.ScopedTimer("Training"):
            for i in range(self.max_epochs):
                with wp.ScopedTimer("Epoch"):
                    wp.capture_launch(graph)
                    # loss holds the value of the last batch of the epoch
                    print(f"Epoch: {i} Loss: {loss.numpy()}")

        # evaluate full image
        wp.launch(
            compute,
            dim=[IMG_WIDTH * IMG_HEIGHT],
            inputs=[
                self.indices,
                self.weights_0,
                self.bias_0,
                self.weights_1,
                self.bias_1,
                self.weights_2,
                self.bias_2,
                self.weights_3,
                self.bias_3,
                self.reference,
                loss,
                output,
            ],
            block_dim=NUM_THREADS,
        )

        self.save_image("example_tile_mlp.jpg", output.numpy())

    def train_torch(self):
        """Train with PyTorch and save ``example_tile_mlp_torch.jpg``.

        The Warp parameter arrays are wrapped as torch parameters. One epoch
        is captured into a CUDA graph after a warm-up step on a side stream
        and then replayed ``max_epochs`` times.
        """
        import torch as tc

        weights_0 = tc.nn.Parameter(wp.to_torch(self.weights_0))
        weights_1 = tc.nn.Parameter(wp.to_torch(self.weights_1))
        weights_2 = tc.nn.Parameter(wp.to_torch(self.weights_2))
        weights_3 = tc.nn.Parameter(wp.to_torch(self.weights_3))

        bias_0 = tc.nn.Parameter(wp.to_torch(self.bias_0))
        bias_1 = tc.nn.Parameter(wp.to_torch(self.bias_1))
        bias_2 = tc.nn.Parameter(wp.to_torch(self.bias_2))
        bias_3 = tc.nn.Parameter(wp.to_torch(self.bias_3))

        indices = wp.to_torch(self.indices)
        reference = wp.to_torch(self.reference)

        optimizer = tc.optim.Adam(
            [weights_0, bias_0, weights_1, bias_1, weights_2, bias_2, weights_3, bias_3],
            capturable=True,
            lr=0.0001,
            betas=(0.9, 0.95),
            eps=1.0e-6,
        )

        # generate frequency space encoding of pixels
        # based on their linear index in the image
        def encode(linear):
            row = (linear // IMG_WIDTH).float()
            col = (linear % IMG_WIDTH).float()

            x = (row / float(IMG_WIDTH) - 0.5) * 2.0
            y = (col / float(IMG_HEIGHT) - 0.5) * 2.0

            encoding = tc.zeros((NUM_FREQ * 4, len(linear)), dtype=tc.float16, device="cuda")

            for s in range(NUM_FREQ):
                scale = math.pow(2.0, float(s)) * math.pi

                # Directly write the computed values into the encoding tensor
                encoding[s * 4 + 0, :] = tc.sin(scale * x)
                encoding[s * 4 + 1, :] = tc.cos(scale * x)
                encoding[s * 4 + 2, :] = tc.sin(scale * y)
                encoding[s * 4 + 3, :] = tc.cos(scale * y)

            return encoding

        stream = tc.cuda.Stream()
        graph = tc.cuda.CUDAGraph()

        # warm-up
        # one full training step on random data before the graph is captured
        with tc.cuda.stream(stream):
            f = tc.rand((NUM_FREQ * 4, BATCH_SIZE), dtype=tc.float16, device="cuda")
            z = tc.relu(weights_0 @ f + bias_0)
            z = tc.relu(weights_1 @ z + bias_1)
            z = tc.relu(weights_2 @ z + bias_2)
            z = tc.relu(weights_3 @ z + bias_3)
            ref = tc.rand((3, BATCH_SIZE), dtype=tc.float16, device="cuda")
            loss = tc.mean((z - ref) ** 2)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # capture one epoch: every batch's forward, backward and update
        with tc.cuda.graph(graph):
            for b in range(0, IMG_WIDTH * IMG_HEIGHT, BATCH_SIZE):
                linear = indices[b : b + BATCH_SIZE]

                f = encode(linear)

                z = tc.relu(weights_0 @ f + bias_0)
                z = tc.relu(weights_1 @ z + bias_1)
                z = tc.relu(weights_2 @ z + bias_2)
                z = tc.relu(weights_3 @ z + bias_3)

                ref = reference[:, linear]
                loss = tc.mean((z - ref) ** 2)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        with wp.ScopedTimer("Training (Torch)"):
            for _i in range(self.max_epochs):
                with wp.ScopedTimer("Epoch"):
                    graph.replay()

                    print(loss)

        # evaluate the trained network on every pixel, in image order
        f = encode(tc.arange(0, IMG_WIDTH * IMG_HEIGHT))
        z = tc.relu(weights_0 @ f + bias_0)
        z = tc.relu(weights_1 @ z + bias_1)
        z = tc.relu(weights_2 @ z + bias_2)
        z = tc.relu(weights_3 @ z + bias_3)

        self.save_image("example_tile_mlp_torch.jpg", z.detach().cpu().numpy())

    def save_image(self, name, output):
        """Write a channel-major prediction to an image file.

        Args:
            name: Destination file name passed to ``PIL.Image.save``.
            output: Array of shape ``(3, IMG_WIDTH * IMG_HEIGHT)`` with colors
                nominally in ``[0, 1]``.
        """
        # Undo the (H, W, 3) -> (3, H*W) layout used for the reference image.
        predicted_image = output.T.reshape(IMG_HEIGHT, IMG_WIDTH, 3)

        # The output layer is only bounded below (relu), so clamp before the
        # uint8 cast; otherwise values above 1.0 wrap around.
        predicted_image = (np.clip(predicted_image, 0.0, 1.0) * 255).astype(np.uint8)

        predicted_image_pil = Image.fromarray(predicted_image)
        predicted_image_pil.save(name)
370
+
371
+
372
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli.add_argument("--train_iters", type=int, default=20000, help="Total number of training iterations.")

    # Unrecognized command-line arguments are tolerated and ignored.
    args, _unknown = cli.parse_known_args()

    # Everything runs on the first CUDA device.
    with wp.ScopedDevice("cuda:0"):
        example = Example(args.train_iters)
        example.train_warp()
        # example.train_torch()