warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/native/array.h CHANGED
@@ -19,6 +19,12 @@ namespace wp
19
19
  printf(")\n"); \
20
20
  assert(0); \
21
21
 
22
+ #define FP_VERIFY_FWD(value) \
23
+ if (!isfinite(value)) { \
24
+ printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
25
+ FP_ASSERT_FWD(value) \
26
+ } \
27
+
22
28
  #define FP_VERIFY_FWD_1(value) \
23
29
  if (!isfinite(value)) { \
24
30
  printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
@@ -43,6 +49,13 @@ namespace wp
43
49
  FP_ASSERT_FWD(value) \
44
50
  } \
45
51
 
52
+ #define FP_VERIFY_ADJ(value, adj_value) \
53
+ if (!isfinite(value) || !isfinite(adj_value)) \
54
+ { \
55
+ printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
56
+ FP_ASSERT_ADJ(value, adj_value); \
57
+ } \
58
+
46
59
  #define FP_VERIFY_ADJ_1(value, adj_value) \
47
60
  if (!isfinite(value) || !isfinite(adj_value)) \
48
61
  { \
@@ -74,11 +87,13 @@ namespace wp
74
87
 
75
88
  #else
76
89
 
90
+ #define FP_VERIFY_FWD(value) {}
77
91
  #define FP_VERIFY_FWD_1(value) {}
78
92
  #define FP_VERIFY_FWD_2(value) {}
79
93
  #define FP_VERIFY_FWD_3(value) {}
80
94
  #define FP_VERIFY_FWD_4(value) {}
81
95
 
96
+ #define FP_VERIFY_ADJ(value, adj_value) {}
82
97
  #define FP_VERIFY_ADJ_1(value, adj_value) {}
83
98
  #define FP_VERIFY_ADJ_2(value, adj_value) {}
84
99
  #define FP_VERIFY_ADJ_3(value, adj_value) {}
@@ -88,14 +103,19 @@ namespace wp
88
103
 
89
104
  const int ARRAY_MAX_DIMS = 4; // must match constant in types.py
90
105
 
91
- const int ARRAY_TYPE_REGULAR = 0; // must match constant in types.py
92
- const int ARRAY_TYPE_INDEXED = 1; // must match constant in types.py
106
+ // must match constants in types.py
107
+ const int ARRAY_TYPE_REGULAR = 0;
108
+ const int ARRAY_TYPE_INDEXED = 1;
109
+ const int ARRAY_TYPE_FABRIC = 2;
110
+ const int ARRAY_TYPE_FABRIC_INDEXED = 3;
93
111
 
94
112
  struct shape_t
95
113
  {
96
114
  int dims[ARRAY_MAX_DIMS];
97
115
 
98
- CUDA_CALLABLE inline shape_t() : dims() {}
116
+ CUDA_CALLABLE inline shape_t()
117
+ : dims()
118
+ {}
99
119
 
100
120
  CUDA_CALLABLE inline int operator[](int i) const
101
121
  {
@@ -110,12 +130,12 @@ struct shape_t
110
130
  }
111
131
  };
112
132
 
113
- CUDA_CALLABLE inline int index(const shape_t& s, int i)
133
+ CUDA_CALLABLE inline int extract(const shape_t& s, int i)
114
134
  {
115
135
  return s.dims[i];
116
136
  }
117
137
 
118
- CUDA_CALLABLE inline void adj_index(const shape_t& s, int i, const shape_t& adj_s, int adj_i, int adj_ret) {}
138
+ CUDA_CALLABLE inline void adj_extract(const shape_t& s, int i, const shape_t& adj_s, int adj_i, int adj_ret) {}
119
139
 
120
140
  inline CUDA_CALLABLE void print(shape_t s)
121
141
  {
@@ -130,8 +150,13 @@ inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& shape_t) {}
130
150
  template <typename T>
131
151
  struct array_t
132
152
  {
133
- CUDA_CALLABLE inline array_t() {}
134
- CUDA_CALLABLE inline array_t(int) {} // for backward a = 0 initialization syntax
153
+ CUDA_CALLABLE inline array_t()
154
+ : data(nullptr),
155
+ grad(nullptr),
156
+ shape(),
157
+ strides(),
158
+ ndim(0)
159
+ {}
135
160
 
136
161
  CUDA_CALLABLE array_t(T* data, int size, T* grad=nullptr) : data(data), grad(grad) {
137
162
  // constructor for 1d array
@@ -184,8 +209,8 @@ struct array_t
184
209
 
185
210
  CUDA_CALLABLE inline bool empty() const { return !data; }
186
211
 
187
- T* data{nullptr};
188
- T* grad{nullptr};
212
+ T* data;
213
+ T* grad;
189
214
  shape_t shape;
190
215
  int strides[ARRAY_MAX_DIMS];
191
216
  int ndim;
@@ -200,8 +225,11 @@ struct array_t
200
225
  template <typename T>
201
226
  struct indexedarray_t
202
227
  {
203
- CUDA_CALLABLE inline indexedarray_t() {}
204
- CUDA_CALLABLE inline indexedarray_t(int) {} // for backward a = 0 initialization syntax
228
+ CUDA_CALLABLE inline indexedarray_t()
229
+ : arr(),
230
+ indices(),
231
+ shape()
232
+ {}
205
233
 
206
234
  CUDA_CALLABLE inline bool empty() const { return !arr.data; }
207
235
 
@@ -667,43 +695,60 @@ template<template<typename> class A, typename T>
667
695
  inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, int l, T value) { return atomic_max(&index(buf, i, j, k, l), value); }
668
696
 
669
697
  template<template<typename> class A, typename T>
670
- inline CUDA_CALLABLE T load(const A<T>& buf, int i) { return index(buf, i); }
698
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i) { return &index(buf, i); }
671
699
  template<template<typename> class A, typename T>
672
- inline CUDA_CALLABLE T load(const A<T>& buf, int i, int j) { return index(buf, i, j); }
700
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j) { return &index(buf, i, j); }
673
701
  template<template<typename> class A, typename T>
674
- inline CUDA_CALLABLE T load(const A<T>& buf, int i, int j, int k) { return index(buf, i, j, k); }
702
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k) { return &index(buf, i, j, k); }
675
703
  template<template<typename> class A, typename T>
676
- inline CUDA_CALLABLE T load(const A<T>& buf, int i, int j, int k, int l) { return index(buf, i, j, k, l); }
704
+ inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k, int l) { return &index(buf, i, j, k, l); }
677
705
 
678
706
  template<template<typename> class A, typename T>
679
- inline CUDA_CALLABLE void store(const A<T>& buf, int i, T value)
707
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, T value)
680
708
  {
681
709
  FP_VERIFY_FWD_1(value)
682
710
 
683
711
  index(buf, i) = value;
684
712
  }
685
713
  template<template<typename> class A, typename T>
686
- inline CUDA_CALLABLE void store(const A<T>& buf, int i, int j, T value)
714
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, T value)
687
715
  {
688
716
  FP_VERIFY_FWD_2(value)
689
717
 
690
718
  index(buf, i, j) = value;
691
719
  }
692
720
  template<template<typename> class A, typename T>
693
- inline CUDA_CALLABLE void store(const A<T>& buf, int i, int j, int k, T value)
721
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, T value)
694
722
  {
695
723
  FP_VERIFY_FWD_3(value)
696
724
 
697
725
  index(buf, i, j, k) = value;
698
726
  }
699
727
  template<template<typename> class A, typename T>
700
- inline CUDA_CALLABLE void store(const A<T>& buf, int i, int j, int k, int l, T value)
728
+ inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, int l, T value)
701
729
  {
702
730
  FP_VERIFY_FWD_4(value)
703
731
 
704
732
  index(buf, i, j, k, l) = value;
705
733
  }
706
734
 
735
+ template<typename T>
736
+ inline CUDA_CALLABLE void store(T* address, T value)
737
+ {
738
+ FP_VERIFY_FWD(value)
739
+
740
+ *address = value;
741
+ }
742
+
743
+ template<typename T>
744
+ inline CUDA_CALLABLE T load(T* address)
745
+ {
746
+ T value = *address;
747
+ FP_VERIFY_FWD(value)
748
+
749
+ return value;
750
+ }
751
+
707
752
  // select operator to check for array being null
708
753
  template <typename T1, typename T2>
709
754
  CUDA_CALLABLE inline T2 select(const array_t<T1>& arr, const T2& a, const T2& b) { return arr.data?b:a; }
@@ -737,34 +782,36 @@ CUDA_CALLABLE inline void adj_atomic_add(uint32* buf, uint32 value) { }
737
782
  CUDA_CALLABLE inline void adj_atomic_add(int64* buf, int64 value) { }
738
783
  CUDA_CALLABLE inline void adj_atomic_add(uint64* buf, uint64 value) { }
739
784
 
785
+ CUDA_CALLABLE inline void adj_atomic_add(bool* buf, bool value) { }
786
+
740
787
  // only generate gradients for T types
741
788
  template<typename T>
742
- inline CUDA_CALLABLE void adj_load(const array_t<T>& buf, int i, const array_t<T>& adj_buf, int& adj_i, const T& adj_output)
789
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, const array_t<T>& adj_buf, int& adj_i, const T& adj_output)
743
790
  {
744
791
  if (buf.grad)
745
792
  adj_atomic_add(&index_grad(buf, i), adj_output);
746
793
  }
747
794
  template<typename T>
748
- inline CUDA_CALLABLE void adj_load(const array_t<T>& buf, int i, int j, const array_t<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output)
795
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, const array_t<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output)
749
796
  {
750
797
  if (buf.grad)
751
798
  adj_atomic_add(&index_grad(buf, i, j), adj_output);
752
799
  }
753
800
  template<typename T>
754
- inline CUDA_CALLABLE void adj_load(const array_t<T>& buf, int i, int j, int k, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output)
801
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output)
755
802
  {
756
803
  if (buf.grad)
757
804
  adj_atomic_add(&index_grad(buf, i, j, k), adj_output);
758
805
  }
759
806
  template<typename T>
760
- inline CUDA_CALLABLE void adj_load(const array_t<T>& buf, int i, int j, int k, int l, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output)
807
+ inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, int l, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output)
761
808
  {
762
809
  if (buf.grad)
763
810
  adj_atomic_add(&index_grad(buf, i, j, k, l), adj_output);
764
811
  }
765
812
 
766
813
  template<typename T>
767
- inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value)
814
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value)
768
815
  {
769
816
  if (buf.grad)
770
817
  adj_value += index_grad(buf, i);
@@ -772,7 +819,7 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, T value, const
772
819
  FP_VERIFY_ADJ_1(value, adj_value)
773
820
  }
774
821
  template<typename T>
775
- inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value)
822
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value)
776
823
  {
777
824
  if (buf.grad)
778
825
  adj_value += index_grad(buf, i, j);
@@ -781,7 +828,7 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, T value
781
828
 
782
829
  }
783
830
  template<typename T>
784
- inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value)
831
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value)
785
832
  {
786
833
  if (buf.grad)
787
834
  adj_value += index_grad(buf, i, j, k);
@@ -789,7 +836,7 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k,
789
836
  FP_VERIFY_ADJ_3(value, adj_value)
790
837
  }
791
838
  template<typename T>
792
- inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value)
839
+ inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value)
793
840
  {
794
841
  if (buf.grad)
795
842
  adj_value += index_grad(buf, i, j, k, l);
@@ -797,6 +844,19 @@ inline CUDA_CALLABLE void adj_store(const array_t<T>& buf, int i, int j, int k,
797
844
  FP_VERIFY_ADJ_4(value, adj_value)
798
845
  }
799
846
 
847
+ template<typename T>
848
+ inline CUDA_CALLABLE void adj_store(const T* address, T value, const T& adj_address, T& adj_value)
849
+ {
850
+ // nop; generic store() operations are not differentiable, only array_store() is
851
+ FP_VERIFY_ADJ(value, adj_value)
852
+ }
853
+
854
+ template<typename T>
855
+ inline CUDA_CALLABLE void adj_load(const T* address, const T& adj_address, T& adj_value)
856
+ {
857
+ // nop; generic load() operations are not differentiable
858
+ }
859
+
800
860
  template<typename T>
801
861
  inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret)
802
862
  {
@@ -866,22 +926,22 @@ inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, in
866
926
 
867
927
  // generic array types that do not support gradient computation (indexedarray, etc.)
868
928
  template<template<typename> class A1, template<typename> class A2, typename T>
869
- inline CUDA_CALLABLE void adj_load(const A1<T>& buf, int i, const A2<T>& adj_buf, int& adj_i, const T& adj_output) {}
929
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, const A2<T>& adj_buf, int& adj_i, const T& adj_output) {}
870
930
  template<template<typename> class A1, template<typename> class A2, typename T>
871
- inline CUDA_CALLABLE void adj_load(const A1<T>& buf, int i, int j, const A2<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output) {}
931
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, const A2<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output) {}
872
932
  template<template<typename> class A1, template<typename> class A2, typename T>
873
- inline CUDA_CALLABLE void adj_load(const A1<T>& buf, int i, int j, int k, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output) {}
933
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output) {}
874
934
  template<template<typename> class A1, template<typename> class A2, typename T>
875
- inline CUDA_CALLABLE void adj_load(const A1<T>& buf, int i, int j, int k, int l, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output) {}
935
+ inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, int l, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output) {}
876
936
 
877
937
  template<template<typename> class A1, template<typename> class A2, typename T>
878
- inline CUDA_CALLABLE void adj_store(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value) {}
938
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value) {}
879
939
  template<template<typename> class A1, template<typename> class A2, typename T>
880
- inline CUDA_CALLABLE void adj_store(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value) {}
940
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value) {}
881
941
  template<template<typename> class A1, template<typename> class A2, typename T>
882
- inline CUDA_CALLABLE void adj_store(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value) {}
942
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value) {}
883
943
  template<template<typename> class A1, template<typename> class A2, typename T>
884
- inline CUDA_CALLABLE void adj_store(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value) {}
944
+ inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value) {}
885
945
 
886
946
  template<template<typename> class A1, template<typename> class A2, typename T>
887
947
  inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
@@ -901,22 +961,65 @@ inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k,
901
961
  template<template<typename> class A1, template<typename> class A2, typename T>
902
962
  inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
903
963
 
964
+ // generic handler for scalar values
904
965
  template<template<typename> class A1, template<typename> class A2, typename T>
905
- inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
966
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
967
+ if (buf.grad)
968
+ adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
969
+
970
+ FP_VERIFY_ADJ_1(value, adj_value)
971
+ }
906
972
  template<template<typename> class A1, template<typename> class A2, typename T>
907
- inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {}
973
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
974
+ if (buf.grad)
975
+ adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
976
+
977
+ FP_VERIFY_ADJ_2(value, adj_value)
978
+ }
908
979
  template<template<typename> class A1, template<typename> class A2, typename T>
909
- inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {}
980
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
981
+ if (buf.grad)
982
+ adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
983
+
984
+ FP_VERIFY_ADJ_3(value, adj_value)
985
+ }
910
986
  template<template<typename> class A1, template<typename> class A2, typename T>
911
- inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
987
+ inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
988
+ if (buf.grad)
989
+ adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
990
+
991
+ FP_VERIFY_ADJ_4(value, adj_value)
992
+ }
912
993
 
913
994
  template<template<typename> class A1, template<typename> class A2, typename T>
914
- inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
995
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
996
+ if (buf.grad)
997
+ adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
998
+
999
+ FP_VERIFY_ADJ_1(value, adj_value)
1000
+ }
915
1001
  template<template<typename> class A1, template<typename> class A2, typename T>
916
- inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {}
1002
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
1003
+ if (buf.grad)
1004
+ adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
1005
+
1006
+ FP_VERIFY_ADJ_2(value, adj_value)
1007
+ }
917
1008
  template<template<typename> class A1, template<typename> class A2, typename T>
918
- inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {}
1009
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
1010
+ if (buf.grad)
1011
+ adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
1012
+
1013
+ FP_VERIFY_ADJ_3(value, adj_value)
1014
+ }
919
1015
  template<template<typename> class A1, template<typename> class A2, typename T>
920
- inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
1016
+ inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
1017
+ if (buf.grad)
1018
+ adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
1019
+
1020
+ FP_VERIFY_ADJ_4(value, adj_value)
1021
+ }
1022
+
1023
+ } // namespace wp
921
1024
 
922
- } // namespace wp
1025
+ #include "fabric.h"