warp-lang 1.6.0__py3-none-manylinux2014_x86_64.whl → 1.6.2__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (385)
  1. warp/__init__.py +14 -6
  2. warp/autograd.py +14 -6
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +14 -6
  6. warp/build_dll.py +14 -6
  7. warp/builtins.py +16 -7
  8. warp/codegen.py +24 -9
  9. warp/config.py +79 -27
  10. warp/constants.py +14 -6
  11. warp/context.py +236 -71
  12. warp/dlpack.py +14 -6
  13. warp/examples/__init__.py +14 -6
  14. warp/examples/benchmarks/benchmark_api.py +14 -6
  15. warp/examples/benchmarks/benchmark_cloth.py +14 -6
  16. warp/examples/benchmarks/benchmark_cloth_cupy.py +14 -6
  17. warp/examples/benchmarks/benchmark_cloth_jax.py +14 -6
  18. warp/examples/benchmarks/benchmark_cloth_numba.py +15 -0
  19. warp/examples/benchmarks/benchmark_cloth_numpy.py +14 -6
  20. warp/examples/benchmarks/benchmark_cloth_paddle.py +14 -6
  21. warp/examples/benchmarks/benchmark_cloth_pytorch.py +14 -6
  22. warp/examples/benchmarks/benchmark_cloth_taichi.py +14 -6
  23. warp/examples/benchmarks/benchmark_cloth_warp.py +14 -6
  24. warp/examples/benchmarks/benchmark_gemm.py +82 -48
  25. warp/examples/benchmarks/benchmark_interop_paddle.py +14 -6
  26. warp/examples/benchmarks/benchmark_interop_torch.py +14 -6
  27. warp/examples/benchmarks/benchmark_launches.py +14 -6
  28. warp/examples/browse.py +14 -6
  29. warp/examples/core/example_cupy.py +14 -6
  30. warp/examples/core/example_dem.py +14 -6
  31. warp/examples/core/example_fluid.py +14 -6
  32. warp/examples/core/example_graph_capture.py +14 -6
  33. warp/examples/core/example_marching_cubes.py +15 -7
  34. warp/examples/core/example_mesh.py +15 -7
  35. warp/examples/core/example_mesh_intersect.py +14 -6
  36. warp/examples/core/example_nvdb.py +14 -6
  37. warp/examples/core/example_raycast.py +14 -6
  38. warp/examples/core/example_raymarch.py +14 -6
  39. warp/examples/core/example_render_opengl.py +14 -6
  40. warp/examples/core/example_sph.py +14 -6
  41. warp/examples/core/example_torch.py +14 -6
  42. warp/examples/core/example_wave.py +15 -7
  43. warp/examples/fem/example_adaptive_grid.py +14 -6
  44. warp/examples/fem/example_apic_fluid.py +14 -6
  45. warp/examples/fem/example_burgers.py +14 -6
  46. warp/examples/fem/example_convection_diffusion.py +14 -6
  47. warp/examples/fem/example_convection_diffusion_dg.py +14 -6
  48. warp/examples/fem/example_deformed_geometry.py +14 -6
  49. warp/examples/fem/example_diffusion.py +14 -6
  50. warp/examples/fem/example_diffusion_3d.py +14 -6
  51. warp/examples/fem/example_diffusion_mgpu.py +14 -6
  52. warp/examples/fem/example_distortion_energy.py +14 -6
  53. warp/examples/fem/example_magnetostatics.py +14 -6
  54. warp/examples/fem/example_mixed_elasticity.py +14 -6
  55. warp/examples/fem/example_navier_stokes.py +14 -6
  56. warp/examples/fem/example_nonconforming_contact.py +14 -6
  57. warp/examples/fem/example_stokes.py +14 -6
  58. warp/examples/fem/example_stokes_transfer.py +14 -6
  59. warp/examples/fem/example_streamlines.py +14 -6
  60. warp/examples/fem/utils.py +15 -0
  61. warp/examples/optim/example_bounce.py +14 -6
  62. warp/examples/optim/example_cloth_throw.py +14 -6
  63. warp/examples/optim/example_diffray.py +14 -6
  64. warp/examples/optim/example_drone.py +14 -6
  65. warp/examples/optim/example_inverse_kinematics.py +14 -6
  66. warp/examples/optim/example_inverse_kinematics_torch.py +14 -6
  67. warp/examples/optim/example_softbody_properties.py +14 -6
  68. warp/examples/optim/example_spring_cage.py +14 -6
  69. warp/examples/optim/example_trajectory.py +14 -6
  70. warp/examples/optim/example_walker.py +14 -6
  71. warp/examples/sim/example_cartpole.py +14 -6
  72. warp/examples/sim/example_cloth.py +14 -6
  73. warp/examples/sim/example_cloth_self_contact.py +95 -33
  74. warp/examples/sim/example_granular.py +14 -6
  75. warp/examples/sim/example_granular_collision_sdf.py +14 -6
  76. warp/examples/sim/example_jacobian_ik.py +14 -6
  77. warp/examples/sim/example_particle_chain.py +14 -6
  78. warp/examples/sim/example_quadruped.py +14 -6
  79. warp/examples/sim/example_rigid_chain.py +14 -6
  80. warp/examples/sim/example_rigid_contact.py +14 -6
  81. warp/examples/sim/example_rigid_force.py +14 -6
  82. warp/examples/sim/example_rigid_gyroscopic.py +14 -6
  83. warp/examples/sim/example_rigid_soft_contact.py +14 -6
  84. warp/examples/sim/example_soft_body.py +14 -6
  85. warp/examples/tile/example_tile_cholesky.py +14 -6
  86. warp/examples/tile/example_tile_convolution.py +14 -6
  87. warp/examples/tile/example_tile_fft.py +14 -6
  88. warp/examples/tile/example_tile_filtering.py +14 -6
  89. warp/examples/tile/example_tile_matmul.py +14 -6
  90. warp/examples/tile/example_tile_mlp.py +14 -6
  91. warp/examples/tile/example_tile_nbody.py +40 -21
  92. warp/examples/tile/example_tile_walker.py +14 -6
  93. warp/fabric.py +15 -0
  94. warp/fem/__init__.py +15 -0
  95. warp/fem/adaptivity.py +15 -0
  96. warp/fem/cache.py +15 -0
  97. warp/fem/dirichlet.py +15 -0
  98. warp/fem/domain.py +15 -0
  99. warp/fem/field/__init__.py +15 -0
  100. warp/fem/field/field.py +15 -0
  101. warp/fem/field/nodal_field.py +15 -0
  102. warp/fem/field/restriction.py +15 -0
  103. warp/fem/field/virtual.py +15 -0
  104. warp/fem/geometry/__init__.py +15 -0
  105. warp/fem/geometry/adaptive_nanogrid.py +15 -0
  106. warp/fem/geometry/closest_point.py +15 -0
  107. warp/fem/geometry/deformed_geometry.py +15 -0
  108. warp/fem/geometry/element.py +15 -0
  109. warp/fem/geometry/geometry.py +15 -0
  110. warp/fem/geometry/grid_2d.py +15 -0
  111. warp/fem/geometry/grid_3d.py +15 -0
  112. warp/fem/geometry/hexmesh.py +15 -0
  113. warp/fem/geometry/nanogrid.py +15 -0
  114. warp/fem/geometry/partition.py +15 -0
  115. warp/fem/geometry/quadmesh.py +15 -0
  116. warp/fem/geometry/tetmesh.py +15 -0
  117. warp/fem/geometry/trimesh.py +15 -0
  118. warp/fem/integrate.py +15 -0
  119. warp/fem/linalg.py +15 -0
  120. warp/fem/operator.py +15 -0
  121. warp/fem/polynomial.py +15 -0
  122. warp/fem/quadrature/__init__.py +15 -0
  123. warp/fem/quadrature/pic_quadrature.py +15 -0
  124. warp/fem/quadrature/quadrature.py +15 -0
  125. warp/fem/space/__init__.py +15 -0
  126. warp/fem/space/basis_function_space.py +15 -0
  127. warp/fem/space/basis_space.py +15 -0
  128. warp/fem/space/dof_mapper.py +15 -0
  129. warp/fem/space/function_space.py +15 -0
  130. warp/fem/space/grid_2d_function_space.py +15 -0
  131. warp/fem/space/grid_3d_function_space.py +15 -0
  132. warp/fem/space/hexmesh_function_space.py +15 -0
  133. warp/fem/space/nanogrid_function_space.py +15 -0
  134. warp/fem/space/partition.py +15 -0
  135. warp/fem/space/quadmesh_function_space.py +15 -0
  136. warp/fem/space/restriction.py +15 -0
  137. warp/fem/space/shape/__init__.py +15 -0
  138. warp/fem/space/shape/cube_shape_function.py +15 -0
  139. warp/fem/space/shape/shape_function.py +15 -0
  140. warp/fem/space/shape/square_shape_function.py +15 -0
  141. warp/fem/space/shape/tet_shape_function.py +15 -0
  142. warp/fem/space/shape/triangle_shape_function.py +15 -0
  143. warp/fem/space/tetmesh_function_space.py +15 -0
  144. warp/fem/space/topology.py +15 -0
  145. warp/fem/space/trimesh_function_space.py +15 -0
  146. warp/fem/types.py +15 -0
  147. warp/fem/utils.py +15 -0
  148. warp/jax.py +14 -6
  149. warp/jax_experimental.py +14 -6
  150. warp/math.py +14 -6
  151. warp/native/array.h +15 -6
  152. warp/native/builtin.h +15 -6
  153. warp/native/bvh.cpp +15 -6
  154. warp/native/bvh.cu +15 -6
  155. warp/native/bvh.h +15 -6
  156. warp/native/clang/clang.cpp +16 -7
  157. warp/native/coloring.cpp +15 -6
  158. warp/native/crt.cpp +15 -6
  159. warp/native/crt.h +16 -6
  160. warp/native/cuda_crt.h +15 -6
  161. warp/native/cuda_util.cpp +15 -6
  162. warp/native/cuda_util.h +15 -6
  163. warp/native/cutlass_gemm.cpp +15 -6
  164. warp/native/cutlass_gemm.cu +16 -7
  165. warp/native/error.cpp +15 -6
  166. warp/native/error.h +15 -6
  167. warp/native/exports.h +17 -0
  168. warp/native/fabric.h +15 -6
  169. warp/native/hashgrid.cpp +15 -6
  170. warp/native/hashgrid.cu +15 -6
  171. warp/native/hashgrid.h +15 -6
  172. warp/native/initializer_array.h +15 -6
  173. warp/native/intersect.h +15 -6
  174. warp/native/intersect_adj.h +15 -6
  175. warp/native/intersect_tri.h +17 -0
  176. warp/native/marching.cpp +16 -0
  177. warp/native/marching.cu +15 -6
  178. warp/native/marching.h +17 -0
  179. warp/native/mat.h +31 -9
  180. warp/native/mathdx.cpp +15 -6
  181. warp/native/matnn.h +15 -6
  182. warp/native/mesh.cpp +15 -6
  183. warp/native/mesh.cu +15 -6
  184. warp/native/mesh.h +15 -6
  185. warp/native/noise.h +15 -6
  186. warp/native/quat.h +15 -6
  187. warp/native/rand.h +15 -6
  188. warp/native/range.h +15 -6
  189. warp/native/reduce.cpp +15 -6
  190. warp/native/reduce.cu +15 -6
  191. warp/native/runlength_encode.cpp +15 -6
  192. warp/native/runlength_encode.cu +15 -6
  193. warp/native/scan.cpp +15 -6
  194. warp/native/scan.cu +15 -6
  195. warp/native/scan.h +15 -6
  196. warp/native/solid_angle.h +17 -0
  197. warp/native/sort.cpp +15 -6
  198. warp/native/sort.cu +15 -6
  199. warp/native/sort.h +15 -6
  200. warp/native/sparse.cpp +15 -6
  201. warp/native/sparse.cu +15 -6
  202. warp/native/spatial.h +15 -6
  203. warp/native/svd.h +15 -6
  204. warp/native/temp_buffer.h +15 -6
  205. warp/native/tile.h +27 -14
  206. warp/native/tile_reduce.h +15 -6
  207. warp/native/vec.h +15 -6
  208. warp/native/volume.cpp +15 -6
  209. warp/native/volume.cu +15 -6
  210. warp/native/volume.h +15 -6
  211. warp/native/volume_builder.cu +15 -6
  212. warp/native/volume_builder.h +15 -6
  213. warp/native/volume_impl.h +15 -6
  214. warp/native/warp.cpp +15 -6
  215. warp/native/warp.cu +15 -6
  216. warp/native/warp.h +15 -6
  217. warp/optim/__init__.py +14 -6
  218. warp/optim/adam.py +14 -6
  219. warp/optim/linear.py +15 -0
  220. warp/optim/sgd.py +14 -6
  221. warp/paddle.py +14 -6
  222. warp/render/__init__.py +14 -6
  223. warp/render/render_opengl.py +37 -21
  224. warp/render/render_usd.py +24 -8
  225. warp/render/utils.py +14 -6
  226. warp/sim/__init__.py +14 -7
  227. warp/sim/articulation.py +14 -6
  228. warp/sim/collide.py +43 -22
  229. warp/sim/graph_coloring.py +14 -6
  230. warp/sim/import_mjcf.py +14 -7
  231. warp/sim/import_snu.py +14 -7
  232. warp/sim/import_urdf.py +34 -11
  233. warp/sim/import_usd.py +14 -7
  234. warp/sim/inertia.py +14 -6
  235. warp/sim/integrator.py +14 -6
  236. warp/sim/integrator_euler.py +14 -6
  237. warp/sim/integrator_featherstone.py +18 -17
  238. warp/sim/integrator_vbd.py +15 -6
  239. warp/sim/integrator_xpbd.py +14 -6
  240. warp/sim/model.py +76 -65
  241. warp/sim/particles.py +14 -6
  242. warp/sim/render.py +16 -8
  243. warp/sim/utils.py +15 -0
  244. warp/sparse.py +15 -0
  245. warp/stubs.py +16 -1
  246. warp/tape.py +14 -6
  247. warp/tests/__main__.py +15 -0
  248. warp/tests/aux_test_class_kernel.py +14 -6
  249. warp/tests/aux_test_compile_consts_dummy.py +14 -6
  250. warp/tests/aux_test_conditional_unequal_types_kernels.py +14 -6
  251. warp/tests/aux_test_dependent.py +14 -6
  252. warp/tests/aux_test_grad_customs.py +14 -6
  253. warp/tests/aux_test_instancing_gc.py +14 -6
  254. warp/tests/aux_test_module_unload.py +14 -6
  255. warp/tests/aux_test_name_clash1.py +14 -6
  256. warp/tests/aux_test_name_clash2.py +14 -6
  257. warp/tests/aux_test_unresolved_func.py +14 -6
  258. warp/tests/aux_test_unresolved_symbol.py +14 -6
  259. warp/tests/disabled_kinematics.py +14 -6
  260. warp/tests/flaky_test_sim_grad.py +14 -6
  261. warp/tests/run_coverage_serial.py +14 -6
  262. warp/tests/test_adam.py +14 -6
  263. warp/tests/test_arithmetic.py +14 -6
  264. warp/tests/test_array.py +40 -6
  265. warp/tests/test_array_reduce.py +14 -6
  266. warp/tests/test_assert.py +14 -6
  267. warp/tests/test_async.py +14 -6
  268. warp/tests/test_atomic.py +14 -6
  269. warp/tests/test_bool.py +14 -6
  270. warp/tests/test_builtins_resolution.py +14 -6
  271. warp/tests/test_bvh.py +14 -6
  272. warp/tests/test_closest_point_edge_edge.py +14 -6
  273. warp/tests/test_codegen.py +14 -6
  274. warp/tests/test_codegen_instancing.py +14 -6
  275. warp/tests/test_collision.py +20 -12
  276. warp/tests/test_coloring.py +14 -7
  277. warp/tests/test_compile_consts.py +14 -6
  278. warp/tests/test_conditional.py +14 -6
  279. warp/tests/test_context.py +14 -6
  280. warp/tests/test_copy.py +14 -6
  281. warp/tests/test_ctypes.py +14 -6
  282. warp/tests/test_dense.py +14 -6
  283. warp/tests/test_devices.py +14 -6
  284. warp/tests/test_dlpack.py +14 -6
  285. warp/tests/test_examples.py +21 -7
  286. warp/tests/test_fabricarray.py +14 -6
  287. warp/tests/test_fast_math.py +14 -6
  288. warp/tests/test_fem.py +14 -6
  289. warp/tests/test_fp16.py +14 -6
  290. warp/tests/test_func.py +14 -6
  291. warp/tests/test_future_annotations.py +14 -6
  292. warp/tests/test_generics.py +14 -6
  293. warp/tests/test_grad.py +14 -6
  294. warp/tests/test_grad_customs.py +14 -6
  295. warp/tests/test_grad_debug.py +14 -6
  296. warp/tests/test_hash_grid.py +14 -6
  297. warp/tests/test_implicit_init.py +14 -6
  298. warp/tests/test_import.py +14 -6
  299. warp/tests/test_indexedarray.py +14 -6
  300. warp/tests/test_intersect.py +14 -6
  301. warp/tests/test_ipc.py +14 -6
  302. warp/tests/test_iter.py +14 -6
  303. warp/tests/test_jax.py +14 -6
  304. warp/tests/test_large.py +14 -6
  305. warp/tests/test_launch.py +91 -32
  306. warp/tests/test_lerp.py +14 -6
  307. warp/tests/test_linear_solvers.py +15 -0
  308. warp/tests/test_lvalue.py +14 -6
  309. warp/tests/test_marching_cubes.py +14 -6
  310. warp/tests/test_mat.py +89 -7
  311. warp/tests/test_mat_lite.py +14 -6
  312. warp/tests/test_mat_scalar_ops.py +14 -6
  313. warp/tests/test_math.py +14 -6
  314. warp/tests/test_matmul.py +14 -6
  315. warp/tests/test_matmul_lite.py +14 -6
  316. warp/tests/test_mempool.py +14 -6
  317. warp/tests/test_mesh.py +14 -6
  318. warp/tests/test_mesh_query_aabb.py +14 -6
  319. warp/tests/test_mesh_query_point.py +14 -6
  320. warp/tests/test_mesh_query_ray.py +14 -6
  321. warp/tests/test_mlp.py +14 -6
  322. warp/tests/test_model.py +14 -6
  323. warp/tests/test_module_hashing.py +14 -6
  324. warp/tests/test_modules_lite.py +14 -6
  325. warp/tests/test_multigpu.py +14 -6
  326. warp/tests/test_noise.py +14 -6
  327. warp/tests/test_operators.py +14 -6
  328. warp/tests/test_options.py +14 -6
  329. warp/tests/test_overwrite.py +19 -3
  330. warp/tests/test_paddle.py +14 -6
  331. warp/tests/test_peer.py +14 -6
  332. warp/tests/test_pinned.py +14 -6
  333. warp/tests/test_print.py +14 -6
  334. warp/tests/test_quat.py +14 -6
  335. warp/tests/test_rand.py +14 -6
  336. warp/tests/test_reload.py +14 -6
  337. warp/tests/test_rounding.py +14 -6
  338. warp/tests/test_runlength_encode.py +14 -6
  339. warp/tests/test_scalar_ops.py +14 -6
  340. warp/tests/test_sim_grad_bounce_linear.py +14 -6
  341. warp/tests/test_sim_kinematics.py +14 -6
  342. warp/tests/test_smoothstep.py +14 -6
  343. warp/tests/test_snippet.py +15 -0
  344. warp/tests/test_sparse.py +14 -6
  345. warp/tests/test_spatial.py +14 -6
  346. warp/tests/test_special_values.py +14 -6
  347. warp/tests/test_static.py +14 -6
  348. warp/tests/test_streams.py +14 -6
  349. warp/tests/test_struct.py +14 -6
  350. warp/tests/test_tape.py +14 -6
  351. warp/tests/test_tile.py +14 -6
  352. warp/tests/test_tile_load.py +58 -7
  353. warp/tests/test_tile_mathdx.py +14 -6
  354. warp/tests/test_tile_mlp.py +14 -6
  355. warp/tests/test_tile_reduce.py +14 -6
  356. warp/tests/test_tile_shared_memory.py +14 -6
  357. warp/tests/test_tile_view.py +14 -6
  358. warp/tests/test_torch.py +14 -6
  359. warp/tests/test_transient_module.py +14 -6
  360. warp/tests/test_triangle_closest_point.py +14 -6
  361. warp/tests/test_types.py +14 -6
  362. warp/tests/test_utils.py +14 -6
  363. warp/tests/test_vbd.py +14 -6
  364. warp/tests/test_vec.py +14 -6
  365. warp/tests/test_vec_lite.py +14 -6
  366. warp/tests/test_vec_scalar_ops.py +14 -6
  367. warp/tests/test_verify_fp.py +14 -6
  368. warp/tests/test_volume.py +14 -6
  369. warp/tests/test_volume_write.py +14 -6
  370. warp/tests/unittest_serial.py +14 -6
  371. warp/tests/unittest_suites.py +14 -6
  372. warp/tests/unittest_utils.py +14 -6
  373. warp/tests/unused_test_misc.py +14 -6
  374. warp/tests/walkthrough_debug.py +14 -6
  375. warp/thirdparty/unittest_parallel.py +15 -7
  376. warp/torch.py +14 -6
  377. warp/types.py +80 -74
  378. warp/utils.py +14 -6
  379. warp_lang-1.6.2.dist-info/LICENSE.md +202 -0
  380. {warp_lang-1.6.0.dist-info → warp_lang-1.6.2.dist-info}/METADATA +44 -22
  381. warp_lang-1.6.2.dist-info/RECORD +419 -0
  382. {warp_lang-1.6.0.dist-info → warp_lang-1.6.2.dist-info}/WHEEL +1 -1
  383. warp_lang-1.6.0.dist-info/LICENSE.md +0 -126
  384. warp_lang-1.6.0.dist-info/RECORD +0 -419
  385. {warp_lang-1.6.0.dist-info → warp_lang-1.6.2.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  from __future__ import annotations
9
17
 
@@ -34,6 +42,7 @@ import warp
34
42
  import warp.build
35
43
  import warp.codegen
36
44
  import warp.config
45
+ from warp.types import launch_bounds_t
37
46
 
38
47
  # represents either a built-in or user-defined function
39
48
 
@@ -5187,8 +5196,23 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
5187
5196
  # represents all data required for a kernel launch
5188
5197
  # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
5189
5198
  class Launch:
5199
+ """Represents all data required for a kernel launch so that launches can be replayed quickly.
5200
+
5201
+ Users should not directly instantiate this class, instead use
5202
+ ``wp.launch(..., record_cmd=True)`` to record a launch.
5203
+ """
5204
+
5190
5205
  def __init__(
5191
- self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0, block_dim=256
5206
+ self,
5207
+ kernel,
5208
+ device: Device,
5209
+ hooks: Optional[KernelHooks] = None,
5210
+ params: Optional[Sequence[Any]] = None,
5211
+ params_addr: Optional[Sequence[ctypes.c_void_p]] = None,
5212
+ bounds: Optional[launch_bounds_t] = None,
5213
+ max_blocks: int = 0,
5214
+ block_dim: int = 256,
5215
+ adjoint: bool = False,
5192
5216
  ):
5193
5217
  # retain the module executable so it doesn't get unloaded
5194
5218
  self.module_exec = kernel.module.load(device)
@@ -5201,13 +5225,14 @@ class Launch:
5201
5225
 
5202
5226
  # if not specified set a zero bound
5203
5227
  if not bounds:
5204
- bounds = warp.types.launch_bounds_t(0)
5228
+ bounds = launch_bounds_t(0)
5205
5229
 
5206
5230
  # if not specified then build a list of default value params for args
5207
5231
  if not params:
5208
5232
  params = []
5209
5233
  params.append(bounds)
5210
5234
 
5235
+ # Pack forward parameters
5211
5236
  for a in kernel.adj.args:
5212
5237
  if isinstance(a.type, warp.types.array):
5213
5238
  params.append(a.type.__ctype__())
@@ -5216,6 +5241,18 @@ class Launch:
5216
5241
  else:
5217
5242
  params.append(pack_arg(kernel, a.type, a.label, 0, device, False))
5218
5243
 
5244
+ # Pack adjoint parameters if adjoint=True
5245
+ if adjoint:
5246
+ for a in kernel.adj.args:
5247
+ if isinstance(a.type, warp.types.array):
5248
+ params.append(a.type.__ctype__())
5249
+ elif isinstance(a.type, warp.codegen.Struct):
5250
+ params.append(a.type().__ctype__())
5251
+ else:
5252
+ # For primitive types in adjoint mode, initialize with 0
5253
+ params.append(pack_arg(kernel, a.type, a.label, 0, device, True))
5254
+
5255
+ # Create array of parameter addresses
5219
5256
  kernel_args = [ctypes.c_void_p(ctypes.addressof(x)) for x in params]
5220
5257
  kernel_params = (ctypes.c_void_p * len(kernel_args))(*kernel_args)
5221
5258
 
@@ -5225,13 +5262,30 @@ class Launch:
5225
5262
  self.hooks = hooks
5226
5263
  self.params = params
5227
5264
  self.params_addr = params_addr
5228
- self.device = device
5229
- self.bounds = bounds
5230
- self.max_blocks = max_blocks
5231
- self.block_dim = block_dim
5265
+ self.device: Device = device
5266
+ """The device to launch on.
5267
+ This should not be changed after the launch object is created.
5268
+ """
5269
+
5270
+ self.bounds: launch_bounds_t = bounds
5271
+ """The launch bounds. Update with :meth:`set_dim`."""
5272
+
5273
+ self.max_blocks: int = max_blocks
5274
+ """The maximum number of CUDA thread blocks to use."""
5275
+
5276
+ self.block_dim: int = block_dim
5277
+ """The number of threads per block."""
5232
5278
 
5233
- def set_dim(self, dim):
5234
- self.bounds = warp.types.launch_bounds_t(dim)
5279
+ self.adjoint: bool = adjoint
5280
+ """Whether to run the adjoint kernel instead of the forward kernel."""
5281
+
5282
+ def set_dim(self, dim: Union[int, List[int], Tuple[int, ...]]):
5283
+ """Set the launch dimensions.
5284
+
5285
+ Args:
5286
+ dim: The dimensions of the launch.
5287
+ """
5288
+ self.bounds = launch_bounds_t(dim)
5235
5289
 
5236
5290
  # launch bounds always at index 0
5237
5291
  self.params[0] = self.bounds
@@ -5240,22 +5294,36 @@ class Launch:
5240
5294
  if self.params_addr:
5241
5295
  self.params_addr[0] = ctypes.c_void_p(ctypes.addressof(self.bounds))
5242
5296
 
5243
- # set kernel param at an index, will convert to ctype as necessary
5244
- def set_param_at_index(self, index, value):
5297
+ def set_param_at_index(self, index: int, value: Any, adjoint: bool = False):
5298
+ """Set a kernel parameter at an index.
5299
+
5300
+ Args:
5301
+ index: The index of the param to set.
5302
+ value: The value to set the param to.
5303
+ """
5245
5304
  arg_type = self.kernel.adj.args[index].type
5246
5305
  arg_name = self.kernel.adj.args[index].label
5247
5306
 
5248
- carg = pack_arg(self.kernel, arg_type, arg_name, value, self.device, False)
5307
+ carg = pack_arg(self.kernel, arg_type, arg_name, value, self.device, adjoint)
5308
+
5309
+ if adjoint:
5310
+ params_index = index + len(self.kernel.adj.args) + 1
5311
+ else:
5312
+ params_index = index + 1
5249
5313
 
5250
- self.params[index + 1] = carg
5314
+ self.params[params_index] = carg
5251
5315
 
5252
5316
  # for CUDA kernels we need to update the address to each arg
5253
5317
  if self.params_addr:
5254
- self.params_addr[index + 1] = ctypes.c_void_p(ctypes.addressof(carg))
5318
+ self.params_addr[params_index] = ctypes.c_void_p(ctypes.addressof(carg))
5255
5319
 
5256
- # set kernel param at an index without any type conversion
5257
- # args must be passed as ctypes or basic int / float types
5258
- def set_param_at_index_from_ctype(self, index, value):
5320
+ def set_param_at_index_from_ctype(self, index: int, value: Union[ctypes.Structure, int, float]):
5321
+ """Set a kernel parameter at an index without any type conversion.
5322
+
5323
+ Args:
5324
+ index: The index of the param to set.
5325
+ value: The value to set the param to.
5326
+ """
5259
5327
  if isinstance(value, ctypes.Structure):
5260
5328
  # not sure how to directly assign struct->struct without reallocating using ctypes
5261
5329
  self.params[index + 1] = value
@@ -5267,32 +5335,62 @@ class Launch:
5267
5335
  else:
5268
5336
  self.params[index + 1].__init__(value)
5269
5337
 
5270
- # set kernel param by argument name
5271
- def set_param_by_name(self, name, value):
5338
+ def set_param_by_name(self, name: str, value: Any, adjoint: bool = False):
5339
+ """Set a kernel parameter by argument name.
5340
+
5341
+ Args:
5342
+ name: The name of the argument to set.
5343
+ value: The value to set the argument to.
5344
+ adjoint: If ``True``, set the adjoint of this parameter instead of the forward parameter.
5345
+ """
5272
5346
  for i, arg in enumerate(self.kernel.adj.args):
5273
5347
  if arg.label == name:
5274
- self.set_param_at_index(i, value)
5348
+ self.set_param_at_index(i, value, adjoint)
5349
+ return
5350
+
5351
+ raise ValueError(f"Argument '{name}' not found in kernel '{self.kernel.key}'")
5352
+
5353
+ def set_param_by_name_from_ctype(self, name: str, value: ctypes.Structure):
5354
+ """Set a kernel parameter by argument name with no type conversions.
5275
5355
 
5276
- # set kernel param by argument name with no type conversions
5277
- def set_param_by_name_from_ctype(self, name, value):
5356
+ Args:
5357
+ name: The name of the argument to set.
5358
+ value: The value to set the argument to.
5359
+ """
5278
5360
  # lookup argument index
5279
5361
  for i, arg in enumerate(self.kernel.adj.args):
5280
5362
  if arg.label == name:
5281
5363
  self.set_param_at_index_from_ctype(i, value)
5282
5364
 
5283
- # set all params
5284
- def set_params(self, values):
5365
+ def set_params(self, values: Sequence[Any]):
5366
+ """Set all parameters.
5367
+
5368
+ Args:
5369
+ values: A list of values to set the params to.
5370
+ """
5285
5371
  for i, v in enumerate(values):
5286
5372
  self.set_param_at_index(i, v)
5287
5373
 
5288
- # set all params without performing type-conversions
5289
- def set_params_from_ctypes(self, values):
5374
+ def set_params_from_ctypes(self, values: Sequence[ctypes.Structure]):
5375
+ """Set all parameters without performing type-conversions.
5376
+
5377
+ Args:
5378
+ values: A list of ctypes or basic int / float types.
5379
+ """
5290
5380
  for i, v in enumerate(values):
5291
5381
  self.set_param_at_index_from_ctype(i, v)
5292
5382
 
5293
- def launch(self, stream=None) -> Any:
5383
+ def launch(self, stream: Optional[Stream] = None) -> None:
5384
+ """Launch the kernel.
5385
+
5386
+ Args:
5387
+ stream: The stream to launch on.
5388
+ """
5294
5389
  if self.device.is_cpu:
5295
- self.hooks.forward(*self.params)
5390
+ if self.adjoint:
5391
+ self.hooks.backward(*self.params)
5392
+ else:
5393
+ self.hooks.forward(*self.params)
5296
5394
  else:
5297
5395
  if stream is None:
5298
5396
  stream = self.device.stream
@@ -5305,32 +5403,44 @@ class Launch:
5305
5403
  if graph is not None:
5306
5404
  graph.retain_module_exec(self.module_exec)
5307
5405
 
5308
- runtime.core.cuda_launch_kernel(
5309
- self.device.context,
5310
- self.hooks.forward,
5311
- self.bounds.size,
5312
- self.max_blocks,
5313
- self.block_dim,
5314
- self.hooks.forward_smem_bytes,
5315
- self.params_addr,
5316
- stream.cuda_stream,
5317
- )
5406
+ if self.adjoint:
5407
+ runtime.core.cuda_launch_kernel(
5408
+ self.device.context,
5409
+ self.hooks.backward,
5410
+ self.bounds.size,
5411
+ self.max_blocks,
5412
+ self.block_dim,
5413
+ self.hooks.backward_smem_bytes,
5414
+ self.params_addr,
5415
+ stream.cuda_stream,
5416
+ )
5417
+ else:
5418
+ runtime.core.cuda_launch_kernel(
5419
+ self.device.context,
5420
+ self.hooks.forward,
5421
+ self.bounds.size,
5422
+ self.max_blocks,
5423
+ self.block_dim,
5424
+ self.hooks.forward_smem_bytes,
5425
+ self.params_addr,
5426
+ stream.cuda_stream,
5427
+ )
5318
5428
 
5319
5429
 
5320
5430
  def launch(
5321
5431
  kernel,
5322
- dim: Tuple[int],
5432
+ dim: Union[int, Sequence[int]],
5323
5433
  inputs: Sequence = [],
5324
5434
  outputs: Sequence = [],
5325
5435
  adj_inputs: Sequence = [],
5326
5436
  adj_outputs: Sequence = [],
5327
5437
  device: Devicelike = None,
5328
- stream: Stream = None,
5329
- adjoint=False,
5330
- record_tape=True,
5331
- record_cmd=False,
5332
- max_blocks=0,
5333
- block_dim=256,
5438
+ stream: Optional[Stream] = None,
5439
+ adjoint: bool = False,
5440
+ record_tape: bool = True,
5441
+ record_cmd: bool = False,
5442
+ max_blocks: int = 0,
5443
+ block_dim: int = 256,
5334
5444
  ):
5335
5445
  """Launch a Warp kernel on the target device
5336
5446
 
@@ -5338,18 +5448,23 @@ def launch(
5338
5448
 
5339
5449
  Args:
5340
5450
  kernel: The name of a Warp kernel function, decorated with the ``@wp.kernel`` decorator
5341
- dim: The number of threads to launch the kernel, can be an integer, or a Tuple of ints with max of 4 dimensions
5451
+ dim: The number of threads to launch the kernel, can be an integer or a
5452
+ sequence of integers with a maximum of 4 dimensions.
5342
5453
  inputs: The input parameters to the kernel (optional)
5343
5454
  outputs: The output parameters (optional)
5344
5455
  adj_inputs: The adjoint inputs (optional)
5345
5456
  adj_outputs: The adjoint outputs (optional)
5346
- device: The device to launch on (optional)
5347
- stream: The stream to launch on (optional)
5348
- adjoint: Whether to run forward or backward pass (typically use False)
5349
- record_tape: When true the launch will be recorded the global wp.Tape() object when present
5350
- record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
5351
- max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
5352
- If negative or zero, the maximum hardware value will be used.
5457
+ device: The device to launch on.
5458
+ stream: The stream to launch on.
5459
+ adjoint: Whether to run forward or backward pass (typically use ``False``).
5460
+ record_tape: When ``True``, the launch will be recorded the global
5461
+ :class:`wp.Tape() <warp.Tape>` object when present.
5462
+ record_cmd: When ``True``, the launch will return a :class:`Launch`
5463
+ object. The launch will not occur until the user calls
5464
+ :meth:`Launch.launch()`.
5465
+ max_blocks: The maximum number of CUDA thread blocks to use.
5466
+ Only has an effect for CUDA kernel launches.
5467
+ If negative or zero, the maximum hardware value will be used.
5353
5468
  block_dim: The number of threads per block.
5354
5469
  """
5355
5470
 
@@ -5370,7 +5485,7 @@ def launch(
5370
5485
  print(f"kernel: {kernel.key} dim: {dim} inputs: {inputs} outputs: {outputs} device: {device}")
5371
5486
 
5372
5487
  # construct launch bounds
5373
- bounds = warp.types.launch_bounds_t(dim)
5488
+ bounds = launch_bounds_t(dim)
5374
5489
 
5375
5490
  if bounds.size > 0:
5376
5491
  # first param is the number of threads
@@ -5427,6 +5542,17 @@ def launch(
5427
5542
  f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
5428
5543
  )
5429
5544
 
5545
+ if record_cmd:
5546
+ launch = Launch(
5547
+ kernel=kernel,
5548
+ hooks=hooks,
5549
+ params=params,
5550
+ params_addr=None,
5551
+ bounds=bounds,
5552
+ device=device,
5553
+ adjoint=adjoint,
5554
+ )
5555
+ return launch
5430
5556
  hooks.backward(*params)
5431
5557
 
5432
5558
  else:
@@ -5437,7 +5563,13 @@ def launch(
5437
5563
 
5438
5564
  if record_cmd:
5439
5565
  launch = Launch(
5440
- kernel=kernel, hooks=hooks, params=params, params_addr=None, bounds=bounds, device=device
5566
+ kernel=kernel,
5567
+ hooks=hooks,
5568
+ params=params,
5569
+ params_addr=None,
5570
+ bounds=bounds,
5571
+ device=device,
5572
+ adjoint=adjoint,
5441
5573
  )
5442
5574
  return launch
5443
5575
  else:
@@ -5464,16 +5596,30 @@ def launch(
5464
5596
  f"Failed to find backward kernel '{kernel.key}' from module '{kernel.module.name}' for device '{device}'"
5465
5597
  )
5466
5598
 
5467
- runtime.core.cuda_launch_kernel(
5468
- device.context,
5469
- hooks.backward,
5470
- bounds.size,
5471
- max_blocks,
5472
- block_dim,
5473
- hooks.backward_smem_bytes,
5474
- kernel_params,
5475
- stream.cuda_stream,
5476
- )
5599
+ if record_cmd:
5600
+ launch = Launch(
5601
+ kernel=kernel,
5602
+ hooks=hooks,
5603
+ params=params,
5604
+ params_addr=kernel_params,
5605
+ bounds=bounds,
5606
+ device=device,
5607
+ max_blocks=max_blocks,
5608
+ block_dim=block_dim,
5609
+ adjoint=adjoint,
5610
+ )
5611
+ return launch
5612
+ else:
5613
+ runtime.core.cuda_launch_kernel(
5614
+ device.context,
5615
+ hooks.backward,
5616
+ bounds.size,
5617
+ max_blocks,
5618
+ block_dim,
5619
+ hooks.backward_smem_bytes,
5620
+ kernel_params,
5621
+ stream.cuda_stream,
5622
+ )
5477
5623
 
5478
5624
  else:
5479
5625
  if hooks.forward is None:
@@ -5493,7 +5639,6 @@ def launch(
5493
5639
  block_dim=block_dim,
5494
5640
  )
5495
5641
  return launch
5496
-
5497
5642
  else:
5498
5643
  # launch
5499
5644
  runtime.core.cuda_launch_kernel(
@@ -6286,6 +6431,26 @@ def export_functions_rst(file): # pragma: no cover
6286
6431
  def export_stubs(file): # pragma: no cover
6287
6432
  """Generates stub file for auto-complete of builtin functions"""
6288
6433
 
6434
+ # Add copyright notice
6435
+ print(
6436
+ """# SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
6437
+ # SPDX-License-Identifier: Apache-2.0
6438
+ #
6439
+ # Licensed under the Apache License, Version 2.0 (the "License");
6440
+ # you may not use this file except in compliance with the License.
6441
+ # You may obtain a copy of the License at
6442
+ #
6443
+ # http://www.apache.org/licenses/LICENSE-2.0
6444
+ #
6445
+ # Unless required by applicable law or agreed to in writing, software
6446
+ # distributed under the License is distributed on an "AS IS" BASIS,
6447
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
6448
+ # See the License for the specific language governing permissions and
6449
+ # limitations under the License.
6450
+ """,
6451
+ file=file,
6452
+ )
6453
+
6289
6454
  print(
6290
6455
  "# Autogenerated file, do not edit, this file provides stubs for builtins autocomplete in VSCode, PyCharm, etc",
6291
6456
  file=file,
warp/dlpack.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  # Python specification for DLpack:
9
17
  # https://dmlc.github.io/dlpack/latest/python_spec.html
warp/examples/__init__.py CHANGED
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2024 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import os
9
17
 
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import gc
9
17
  import statistics as stats
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  # include parent path
9
17
  import csv
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import cupy as cp
9
17
  import cupyx as cpx
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import jax.lax
9
17
  import jax.numpy as jnp
@@ -1,3 +1,18 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
1
16
  import math
2
17
 
3
18
  import cupy as cp
@@ -1,9 +1,17 @@
1
- # Copyright (c) 2022 NVIDIA CORPORATION. All rights reserved.
2
- # NVIDIA CORPORATION and its licensors retain all intellectual property
3
- # and proprietary rights in and to this software, related documentation
4
- # and any modifications thereto. Any use, reproduction, disclosure or
5
- # distribution of this software and related documentation without an express
6
- # license agreement from NVIDIA CORPORATION is strictly prohibited.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
7
15
 
8
16
  import numpy as np
9
17