warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/builtin.h CHANGED
@@ -46,7 +46,6 @@ __device__ void __debugbreak() {}
46
46
  namespace wp
47
47
  {
48
48
 
49
-
50
49
  // numeric types (used from generated kernels)
51
50
  typedef float float32;
52
51
  typedef double float64;
@@ -141,7 +140,7 @@ static_assert(sizeof(half) == 2, "Size of half / float16 type must be 2-bytes");
141
140
 
142
141
  typedef half float16;
143
142
 
144
- #if __CUDA_ARCH__
143
+ #if defined(__CUDA_ARCH__)
145
144
 
146
145
  CUDA_CALLABLE inline half float_to_half(float x)
147
146
  {
@@ -157,95 +156,38 @@ CUDA_CALLABLE inline float half_to_float(half x)
157
156
  return val;
158
157
  }
159
158
 
160
- #else
159
+ #elif defined(__clang__)
161
160
 
162
- // adapted from Fabien Giesen's post: https://gist.github.com/rygorous/2156668
161
+ // _Float16 is Clang's native half-precision floating-point type
163
162
  inline half float_to_half(float x)
164
163
  {
165
- union fp32
166
- {
167
- uint32 u;
168
- float f;
169
-
170
- struct
171
- {
172
- unsigned int mantissa : 23;
173
- unsigned int exponent : 8;
174
- unsigned int sign : 1;
175
- };
176
- };
177
-
178
- fp32 f;
179
- f.f = x;
180
-
181
- fp32 f32infty = { 255 << 23 };
182
- fp32 f16infty = { 31 << 23 };
183
- fp32 magic = { 15 << 23 };
184
- uint32 sign_mask = 0x80000000u;
185
- uint32 round_mask = ~0xfffu;
186
- half o;
187
-
188
- uint32 sign = f.u & sign_mask;
189
- f.u ^= sign;
190
-
191
- // NOTE all the integer compares in this function can be safely
192
- // compiled into signed compares since all operands are below
193
- // 0x80000000. Important if you want fast straight SSE2 code
194
- // (since there's no unsigned PCMPGTD).
195
-
196
- if (f.u >= f32infty.u) // Inf or NaN (all exponent bits set)
197
- o.u = (f.u > f32infty.u) ? 0x7e00 : 0x7c00; // NaN->qNaN and Inf->Inf
198
- else // (De)normalized number or zero
199
- {
200
- f.u &= round_mask;
201
- f.f *= magic.f;
202
- f.u -= round_mask;
203
- if (f.u > f16infty.u) f.u = f16infty.u; // Clamp to signed infinity if overflowed
204
164
 
205
- o.u = f.u >> 13; // Take the bits!
206
- }
207
-
208
- o.u |= sign >> 16;
209
- return o;
165
+ _Float16 f16 = static_cast<_Float16>(x);
166
+ return *reinterpret_cast<half*>(&f16);
210
167
  }
211
168
 
212
-
213
169
  inline float half_to_float(half h)
214
170
  {
215
- union fp32
216
- {
217
- uint32 u;
218
- float f;
171
+ _Float16 f16 = *reinterpret_cast<_Float16*>(&h);
172
+ return static_cast<float>(f16);
173
+ }
219
174
 
220
- struct
221
- {
222
- unsigned int mantissa : 23;
223
- unsigned int exponent : 8;
224
- unsigned int sign : 1;
225
- };
226
- };
227
-
228
- static const fp32 magic = { 113 << 23 };
229
- static const uint32 shifted_exp = 0x7c00 << 13; // exponent mask after shift
230
- fp32 o;
231
-
232
- o.u = (h.u & 0x7fff) << 13; // exponent/mantissa bits
233
- uint32 exp = shifted_exp & o.u; // just the exponent
234
- o.u += (127 - 15) << 23; // exponent adjust
235
-
236
- // handle exponent special cases
237
- if (exp == shifted_exp) // Inf/NaN?
238
- o.u += (128 - 16) << 23; // extra exp adjust
239
- else if (exp == 0) // Zero/Denormal?
240
- {
241
- o.u += 1 << 23; // extra exp adjust
242
- o.f -= magic.f; // renormalize
243
- }
175
+ #else // Native C++ for Warp builtins outside of kernels
244
176
 
245
- o.u |= (h.u & 0x8000) << 16; // sign bit
246
- return o.f;
177
+ extern "C" WP_API uint16_t float_to_half_bits(float x);
178
+ extern "C" WP_API float half_bits_to_float(uint16_t u);
179
+
180
+ inline half float_to_half(float x)
181
+ {
182
+ half h;
183
+ h.u = float_to_half_bits(x);
184
+ return h;
247
185
  }
248
186
 
187
+ inline float half_to_float(half h)
188
+ {
189
+ return half_bits_to_float(h.u);
190
+ }
249
191
 
250
192
  #endif
251
193
 
@@ -353,7 +295,7 @@ inline CUDA_CALLABLE T rshift(T a, T b) { return a>>b; } \
353
295
  inline CUDA_CALLABLE T invert(T x) { return ~x; } \
354
296
  inline CUDA_CALLABLE bool isfinite(T x) { return true; } \
355
297
  inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
356
- inline CUDA_CALLABLE void adj_div(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
298
+ inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret) { } \
357
299
  inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
358
300
  inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
359
301
  inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
@@ -491,11 +433,6 @@ inline CUDA_CALLABLE void adj_clamp(T x, T a, T b, T& adj_x, T& adj_a, T& adj_b,
491
433
  else\
492
434
  adj_x += adj_ret;\
493
435
  }\
494
- inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
495
- inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
496
- inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
497
- inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
498
- inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
499
436
  inline CUDA_CALLABLE T div(T a, T b)\
500
437
  {\
501
438
  DO_IF_FPCHECK(\
@@ -506,10 +443,10 @@ inline CUDA_CALLABLE T div(T a, T b)\
506
443
  })\
507
444
  return a/b;\
508
445
  }\
509
- inline CUDA_CALLABLE void adj_div(T a, T b, T& adj_a, T& adj_b, T adj_ret)\
446
+ inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
510
447
  {\
511
448
  adj_a += adj_ret/b;\
512
- adj_b -= adj_ret*(a/b)/b;\
449
+ adj_b -= adj_ret*(ret)/b;\
513
450
  DO_IF_FPCHECK(\
514
451
  if (!isfinite(adj_a) || !isfinite(adj_b))\
515
452
  {\
@@ -848,6 +785,10 @@ inline CUDA_CALLABLE half sqrt(half x)
848
785
  return ::sqrtf(float(x));
849
786
  }
850
787
 
788
+ inline CUDA_CALLABLE float cbrt(float x) { return ::cbrtf(x); }
789
+ inline CUDA_CALLABLE double cbrt(double x) { return ::cbrt(x); }
790
+ inline CUDA_CALLABLE half cbrt(half x) { return ::cbrtf(float(x)); }
791
+
851
792
  inline CUDA_CALLABLE float tan(float x) { return ::tanf(x); }
852
793
  inline CUDA_CALLABLE float sinh(float x) { return ::sinhf(x);}
853
794
  inline CUDA_CALLABLE float cosh(float x) { return ::coshf(x);}
@@ -874,6 +815,21 @@ inline CUDA_CALLABLE float rint(float x) { return ::rintf(x); }
874
815
  inline CUDA_CALLABLE float trunc(float x) { return ::truncf(x); }
875
816
  inline CUDA_CALLABLE float floor(float x) { return ::floorf(x); }
876
817
  inline CUDA_CALLABLE float ceil(float x) { return ::ceilf(x); }
818
+ inline CUDA_CALLABLE float frac(float x) { return x - trunc(x); }
819
+
820
+ inline CUDA_CALLABLE double round(double x) { return ::round(x); }
821
+ inline CUDA_CALLABLE double rint(double x) { return ::rint(x); }
822
+ inline CUDA_CALLABLE double trunc(double x) { return ::trunc(x); }
823
+ inline CUDA_CALLABLE double floor(double x) { return ::floor(x); }
824
+ inline CUDA_CALLABLE double ceil(double x) { return ::ceil(x); }
825
+ inline CUDA_CALLABLE double frac(double x) { return x - trunc(x); }
826
+
827
+ inline CUDA_CALLABLE half round(half x) { return ::roundf(float(x)); }
828
+ inline CUDA_CALLABLE half rint(half x) { return ::rintf(float(x)); }
829
+ inline CUDA_CALLABLE half trunc(half x) { return ::truncf(float(x)); }
830
+ inline CUDA_CALLABLE half floor(half x) { return ::floorf(float(x)); }
831
+ inline CUDA_CALLABLE half ceil(half x) { return ::ceilf(float(x)); }
832
+ inline CUDA_CALLABLE half frac(half x) { return float(x) - trunc(float(x)); }
877
833
 
878
834
  #define DECLARE_ADJOINTS(T)\
879
835
  inline CUDA_CALLABLE void adj_log(T a, T& adj_a, T adj_ret)\
@@ -903,11 +859,11 @@ inline CUDA_CALLABLE void adj_log10(T a, T& adj_a, T adj_ret)\
903
859
  assert(0);\
904
860
  })\
905
861
  }\
906
- inline CUDA_CALLABLE void adj_exp(T a, T& adj_a, T adj_ret) { adj_a += exp(a)*adj_ret; }\
907
- inline CUDA_CALLABLE void adj_pow(T a, T b, T& adj_a, T& adj_b, T adj_ret)\
862
+ inline CUDA_CALLABLE void adj_exp(T a, T ret, T& adj_a, T adj_ret) { adj_a += ret*adj_ret; }\
863
+ inline CUDA_CALLABLE void adj_pow(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
908
864
  { \
909
865
  adj_a += b*pow(a, b-T(1))*adj_ret;\
910
- adj_b += log(a)*pow(a, b)*adj_ret;\
866
+ adj_b += log(a)*ret*adj_ret;\
911
867
  DO_IF_FPCHECK(if (!isfinite(adj_a) || !isfinite(adj_b))\
912
868
  {\
913
869
  printf("%s:%d - adj_pow(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
@@ -1006,20 +962,28 @@ inline CUDA_CALLABLE void adj_cosh(T x, T& adj_x, T adj_ret)\
1006
962
  {\
1007
963
  adj_x += sinh(x)*adj_ret;\
1008
964
  }\
1009
- inline CUDA_CALLABLE void adj_tanh(T x, T& adj_x, T adj_ret)\
965
+ inline CUDA_CALLABLE void adj_tanh(T x, T ret, T& adj_x, T adj_ret)\
1010
966
  {\
1011
- T tanh_x = tanh(x);\
1012
- adj_x += (T(1) - tanh_x*tanh_x)*adj_ret;\
967
+ adj_x += (T(1) - ret*ret)*adj_ret;\
1013
968
  }\
1014
- inline CUDA_CALLABLE void adj_sqrt(T x, T& adj_x, T adj_ret)\
969
+ inline CUDA_CALLABLE void adj_sqrt(T x, T ret, T& adj_x, T adj_ret)\
1015
970
  {\
1016
- adj_x += T(0.5)*(T(1)/sqrt(x))*adj_ret;\
971
+ adj_x += T(0.5)*(T(1)/ret)*adj_ret;\
1017
972
  DO_IF_FPCHECK(if (!isfinite(adj_x))\
1018
973
  {\
1019
974
  printf("%s:%d - adj_sqrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
1020
975
  assert(0);\
1021
976
  })\
1022
977
  }\
978
+ inline CUDA_CALLABLE void adj_cbrt(T x, T ret, T& adj_x, T adj_ret)\
979
+ {\
980
+ adj_x += (T(1)/T(3))*(T(1)/(ret*ret))*adj_ret;\
981
+ DO_IF_FPCHECK(if (!isfinite(adj_x))\
982
+ {\
983
+ printf("%s:%d - adj_cbrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
984
+ assert(0);\
985
+ })\
986
+ }\
1023
987
  inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
1024
988
  {\
1025
989
  adj_x += RAD_TO_DEG * adj_ret;\
@@ -1027,7 +991,13 @@ inline CUDA_CALLABLE void adj_degrees(T x, T& adj_x, T adj_ret)\
1027
991
  inline CUDA_CALLABLE void adj_radians(T x, T& adj_x, T adj_ret)\
1028
992
  {\
1029
993
  adj_x += DEG_TO_RAD * adj_ret;\
1030
- }
994
+ }\
995
+ inline CUDA_CALLABLE void adj_round(T x, T& adj_x, T adj_ret){ }\
996
+ inline CUDA_CALLABLE void adj_rint(T x, T& adj_x, T adj_ret){ }\
997
+ inline CUDA_CALLABLE void adj_trunc(T x, T& adj_x, T adj_ret){ }\
998
+ inline CUDA_CALLABLE void adj_floor(T x, T& adj_x, T adj_ret){ }\
999
+ inline CUDA_CALLABLE void adj_ceil(T x, T& adj_x, T adj_ret){ }\
1000
+ inline CUDA_CALLABLE void adj_frac(T x, T& adj_x, T adj_ret){ }
1031
1001
 
1032
1002
  DECLARE_ADJOINTS(float16)
1033
1003
  DECLARE_ADJOINTS(float32)
@@ -1051,17 +1021,31 @@ CUDA_CALLABLE inline void adj_select(const C& cond, const T& a, const T& b, C& a
1051
1021
  }
1052
1022
 
1053
1023
  template <typename T>
1054
- CUDA_CALLABLE inline void copy(T& dest, const T& src)
1024
+ CUDA_CALLABLE inline T copy(const T& src)
1025
+ {
1026
+ return src;
1027
+ }
1028
+
1029
+ template <typename T>
1030
+ CUDA_CALLABLE inline void adj_copy(const T& src, T& adj_src, T& adj_dest)
1031
+ {
1032
+ adj_src = adj_dest;
1033
+ adj_dest = T{};
1034
+ }
1035
+
1036
+ template <typename T>
1037
+ CUDA_CALLABLE inline void assign(T& dest, const T& src)
1055
1038
  {
1056
1039
  dest = src;
1057
1040
  }
1058
1041
 
1059
1042
  template <typename T>
1060
- CUDA_CALLABLE inline void adj_copy(T& dest, const T& src, T& adj_dest, T& adj_src)
1043
+ CUDA_CALLABLE inline void adj_assign(T& dest, const T& src, T& adj_dest, T& adj_src)
1061
1044
  {
1062
- // nop, this is non-differentiable operation since it violates SSA
1045
+ // this is generally a non-differentiable operation since it violates SSA,
1046
+ // except in read-modify-write statements which are reversible through backpropagation
1063
1047
  adj_src = adj_dest;
1064
- adj_dest = T(0);
1048
+ adj_dest = T{};
1065
1049
  }
1066
1050
 
1067
1051
 
@@ -1106,34 +1090,8 @@ struct launch_bounds_t
1106
1090
  size_t size; // total number of threads
1107
1091
  };
1108
1092
 
1109
- #ifdef __CUDACC__
1110
-
1111
- // store launch bounds in shared memory so
1112
- // we can access them from any user func
1113
- // this is to avoid having to explicitly
1114
- // set another piece of __constant__ memory
1115
- // from the host
1116
- __shared__ launch_bounds_t s_launchBounds;
1117
-
1118
- __device__ inline void set_launch_bounds(const launch_bounds_t& b)
1119
- {
1120
- if (threadIdx.x == 0)
1121
- s_launchBounds = b;
1122
-
1123
- __syncthreads();
1124
- }
1125
-
1126
- #else
1127
-
1128
- // for single-threaded CPU we store launch
1129
- // bounds in static memory to share globally
1130
- static launch_bounds_t s_launchBounds;
1093
+ #ifndef __CUDACC__
1131
1094
  static size_t s_threadIdx;
1132
-
1133
- inline void set_launch_bounds(const launch_bounds_t& b)
1134
- {
1135
- s_launchBounds = b;
1136
- }
1137
1095
  #endif
1138
1096
 
1139
1097
  inline CUDA_CALLABLE size_t grid_index()
@@ -1147,10 +1105,8 @@ inline CUDA_CALLABLE size_t grid_index()
1147
1105
  #endif
1148
1106
  }
1149
1107
 
1150
- inline CUDA_CALLABLE int tid()
1108
+ inline CUDA_CALLABLE int tid(size_t index)
1151
1109
  {
1152
- const size_t index = grid_index();
1153
-
1154
1110
  // For the 1-D tid() we need to warn the user if we're about to provide a truncated index
1155
1111
  // Only do this in _DEBUG when called from device to avoid excessive register allocation
1156
1112
  #if defined(_DEBUG) || !defined(__CUDA_ARCH__)
@@ -1161,23 +1117,19 @@ inline CUDA_CALLABLE int tid()
1161
1117
  return static_cast<int>(index);
1162
1118
  }
1163
1119
 
1164
- inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j)
1120
+ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, size_t index, const launch_bounds_t& launch_bounds)
1165
1121
  {
1166
- const size_t index = grid_index();
1167
-
1168
- const size_t n = s_launchBounds.shape[1];
1122
+ const size_t n = launch_bounds.shape[1];
1169
1123
 
1170
1124
  // convert to work item
1171
1125
  i = index/n;
1172
1126
  j = index%n;
1173
1127
  }
1174
1128
 
1175
- inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k)
1129
+ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, size_t index, const launch_bounds_t& launch_bounds)
1176
1130
  {
1177
- const size_t index = grid_index();
1178
-
1179
- const size_t n = s_launchBounds.shape[1];
1180
- const size_t o = s_launchBounds.shape[2];
1131
+ const size_t n = launch_bounds.shape[1];
1132
+ const size_t o = launch_bounds.shape[2];
1181
1133
 
1182
1134
  // convert to work item
1183
1135
  i = index/(n*o);
@@ -1185,13 +1137,11 @@ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k)
1185
1137
  k = index%o;
1186
1138
  }
1187
1139
 
1188
- inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l)
1140
+ inline CUDA_CALLABLE_DEVICE void tid(int& i, int& j, int& k, int& l, size_t index, const launch_bounds_t& launch_bounds)
1189
1141
  {
1190
- const size_t index = grid_index();
1191
-
1192
- const size_t n = s_launchBounds.shape[1];
1193
- const size_t o = s_launchBounds.shape[2];
1194
- const size_t p = s_launchBounds.shape[3];
1142
+ const size_t n = launch_bounds.shape[1];
1143
+ const size_t o = launch_bounds.shape[2];
1144
+ const size_t p = launch_bounds.shape[3];
1195
1145
 
1196
1146
  // convert to work item
1197
1147
  i = index/(n*o*p);
@@ -1321,9 +1271,36 @@ inline CUDA_CALLABLE int atomic_min(int* address, int val)
1321
1271
  #endif
1322
1272
  }
1323
1273
 
1274
+ // default behavior for adjoint of atomic min/max operation that accumulates gradients for all elements matching the min/max value
1275
+ template <typename T>
1276
+ CUDA_CALLABLE inline void adj_atomic_minmax(T *addr, T *adj_addr, const T &value, T &adj_value)
1277
+ {
1278
+ if (value == *addr)
1279
+ adj_value += *adj_addr;
1280
+ }
1281
+
1282
+ // for integral types we do not accumulate gradients
1283
+ CUDA_CALLABLE inline void adj_atomic_minmax(int8* buf, int8* adj_buf, const int8 &value, int8 &adj_value) { }
1284
+ CUDA_CALLABLE inline void adj_atomic_minmax(uint8* buf, uint8* adj_buf, const uint8 &value, uint8 &adj_value) { }
1285
+ CUDA_CALLABLE inline void adj_atomic_minmax(int16* buf, int16* adj_buf, const int16 &value, int16 &adj_value) { }
1286
+ CUDA_CALLABLE inline void adj_atomic_minmax(uint16* buf, uint16* adj_buf, const uint16 &value, uint16 &adj_value) { }
1287
+ CUDA_CALLABLE inline void adj_atomic_minmax(int32* buf, int32* adj_buf, const int32 &value, int32 &adj_value) { }
1288
+ CUDA_CALLABLE inline void adj_atomic_minmax(uint32* buf, uint32* adj_buf, const uint32 &value, uint32 &adj_value) { }
1289
+ CUDA_CALLABLE inline void adj_atomic_minmax(int64* buf, int64* adj_buf, const int64 &value, int64 &adj_value) { }
1290
+ CUDA_CALLABLE inline void adj_atomic_minmax(uint64* buf, uint64* adj_buf, const uint64 &value, uint64 &adj_value) { }
1291
+ CUDA_CALLABLE inline void adj_atomic_minmax(bool* buf, bool* adj_buf, const bool &value, bool &adj_value) { }
1292
+
1324
1293
 
1325
1294
  } // namespace wp
1326
1295
 
1296
+
1297
+ // bool and printf are defined outside of the wp namespace in crt.h, hence
1298
+ // their adjoint counterparts are also defined in the global namespace.
1299
+ template <typename T>
1300
+ CUDA_CALLABLE inline void adj_bool(T, T&, bool) {}
1301
+ inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
1302
+
1303
+
1327
1304
  #include "vec.h"
1328
1305
  #include "mat.h"
1329
1306
  #include "quat.h"
@@ -1488,10 +1465,6 @@ inline CUDA_CALLABLE void adj_print(transform_t<Type> t, transform_t<Type>& adj_
1488
1465
  inline CUDA_CALLABLE void adj_print(str t, str& adj_t) {}
1489
1466
 
1490
1467
 
1491
- // printf defined globally in crt.h
1492
- inline CUDA_CALLABLE void adj_printf(const char* fmt, ...) {}
1493
-
1494
-
1495
1468
  template <typename T>
1496
1469
  inline CUDA_CALLABLE void expect_eq(const T& actual, const T& expected)
1497
1470
  {