warp-lang 0.9.0__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (315)
  1. warp/__init__.py +15 -7
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +22 -443
  6. warp/build_dll.py +384 -0
  7. warp/builtins.py +998 -488
  8. warp/codegen.py +1307 -739
  9. warp/config.py +5 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +1291 -548
  12. warp/dlpack.py +31 -31
  13. warp/fabric.py +326 -0
  14. warp/fem/__init__.py +27 -0
  15. warp/fem/cache.py +389 -0
  16. warp/fem/dirichlet.py +181 -0
  17. warp/fem/domain.py +263 -0
  18. warp/fem/field/__init__.py +101 -0
  19. warp/fem/field/field.py +149 -0
  20. warp/fem/field/nodal_field.py +299 -0
  21. warp/fem/field/restriction.py +21 -0
  22. warp/fem/field/test.py +181 -0
  23. warp/fem/field/trial.py +183 -0
  24. warp/fem/geometry/__init__.py +19 -0
  25. warp/fem/geometry/closest_point.py +70 -0
  26. warp/fem/geometry/deformed_geometry.py +271 -0
  27. warp/fem/geometry/element.py +744 -0
  28. warp/fem/geometry/geometry.py +186 -0
  29. warp/fem/geometry/grid_2d.py +373 -0
  30. warp/fem/geometry/grid_3d.py +435 -0
  31. warp/fem/geometry/hexmesh.py +953 -0
  32. warp/fem/geometry/partition.py +376 -0
  33. warp/fem/geometry/quadmesh_2d.py +532 -0
  34. warp/fem/geometry/tetmesh.py +840 -0
  35. warp/fem/geometry/trimesh_2d.py +577 -0
  36. warp/fem/integrate.py +1616 -0
  37. warp/fem/operator.py +191 -0
  38. warp/fem/polynomial.py +213 -0
  39. warp/fem/quadrature/__init__.py +2 -0
  40. warp/fem/quadrature/pic_quadrature.py +245 -0
  41. warp/fem/quadrature/quadrature.py +294 -0
  42. warp/fem/space/__init__.py +292 -0
  43. warp/fem/space/basis_space.py +489 -0
  44. warp/fem/space/collocated_function_space.py +105 -0
  45. warp/fem/space/dof_mapper.py +236 -0
  46. warp/fem/space/function_space.py +145 -0
  47. warp/fem/space/grid_2d_function_space.py +267 -0
  48. warp/fem/space/grid_3d_function_space.py +306 -0
  49. warp/fem/space/hexmesh_function_space.py +352 -0
  50. warp/fem/space/partition.py +350 -0
  51. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  52. warp/fem/space/restriction.py +160 -0
  53. warp/fem/space/shape/__init__.py +15 -0
  54. warp/fem/space/shape/cube_shape_function.py +738 -0
  55. warp/fem/space/shape/shape_function.py +103 -0
  56. warp/fem/space/shape/square_shape_function.py +611 -0
  57. warp/fem/space/shape/tet_shape_function.py +567 -0
  58. warp/fem/space/shape/triangle_shape_function.py +429 -0
  59. warp/fem/space/tetmesh_function_space.py +292 -0
  60. warp/fem/space/topology.py +295 -0
  61. warp/fem/space/trimesh_2d_function_space.py +221 -0
  62. warp/fem/types.py +77 -0
  63. warp/fem/utils.py +495 -0
  64. warp/native/array.h +164 -55
  65. warp/native/builtin.h +150 -174
  66. warp/native/bvh.cpp +75 -328
  67. warp/native/bvh.cu +406 -23
  68. warp/native/bvh.h +37 -45
  69. warp/native/clang/clang.cpp +136 -24
  70. warp/native/crt.cpp +1 -76
  71. warp/native/crt.h +111 -104
  72. warp/native/cuda_crt.h +1049 -0
  73. warp/native/cuda_util.cpp +15 -3
  74. warp/native/cuda_util.h +3 -1
  75. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  76. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  77. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  78. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  79. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  80. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  133. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  134. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  135. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  136. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  137. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  138. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  139. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  140. warp/native/cutlass_gemm.cu +5 -3
  141. warp/native/exports.h +1240 -949
  142. warp/native/fabric.h +228 -0
  143. warp/native/hashgrid.cpp +4 -4
  144. warp/native/hashgrid.h +22 -2
  145. warp/native/initializer_array.h +2 -2
  146. warp/native/intersect.h +22 -7
  147. warp/native/intersect_adj.h +8 -8
  148. warp/native/intersect_tri.h +13 -16
  149. warp/native/marching.cu +157 -161
  150. warp/native/mat.h +119 -19
  151. warp/native/matnn.h +2 -2
  152. warp/native/mesh.cpp +108 -83
  153. warp/native/mesh.cu +243 -6
  154. warp/native/mesh.h +1547 -458
  155. warp/native/nanovdb/NanoVDB.h +1 -1
  156. warp/native/noise.h +272 -329
  157. warp/native/quat.h +51 -8
  158. warp/native/rand.h +45 -35
  159. warp/native/range.h +6 -2
  160. warp/native/reduce.cpp +157 -0
  161. warp/native/reduce.cu +348 -0
  162. warp/native/runlength_encode.cpp +62 -0
  163. warp/native/runlength_encode.cu +46 -0
  164. warp/native/scan.cu +11 -13
  165. warp/native/scan.h +1 -0
  166. warp/native/solid_angle.h +442 -0
  167. warp/native/sort.cpp +13 -0
  168. warp/native/sort.cu +9 -1
  169. warp/native/sparse.cpp +338 -0
  170. warp/native/sparse.cu +545 -0
  171. warp/native/spatial.h +2 -2
  172. warp/native/temp_buffer.h +30 -0
  173. warp/native/vec.h +126 -24
  174. warp/native/volume.h +120 -0
  175. warp/native/warp.cpp +658 -53
  176. warp/native/warp.cu +660 -68
  177. warp/native/warp.h +112 -12
  178. warp/optim/__init__.py +1 -0
  179. warp/optim/linear.py +922 -0
  180. warp/optim/sgd.py +92 -0
  181. warp/render/render_opengl.py +392 -152
  182. warp/render/render_usd.py +11 -11
  183. warp/sim/__init__.py +2 -2
  184. warp/sim/articulation.py +385 -185
  185. warp/sim/collide.py +21 -8
  186. warp/sim/import_mjcf.py +297 -106
  187. warp/sim/import_urdf.py +389 -210
  188. warp/sim/import_usd.py +198 -97
  189. warp/sim/inertia.py +17 -18
  190. warp/sim/integrator_euler.py +14 -8
  191. warp/sim/integrator_xpbd.py +161 -19
  192. warp/sim/model.py +795 -291
  193. warp/sim/optimizer.py +2 -6
  194. warp/sim/render.py +65 -3
  195. warp/sim/utils.py +3 -0
  196. warp/sparse.py +1227 -0
  197. warp/stubs.py +665 -223
  198. warp/tape.py +66 -15
  199. warp/tests/__main__.py +3 -6
  200. warp/tests/assets/curlnoise_golden.npy +0 -0
  201. warp/tests/assets/pnoise_golden.npy +0 -0
  202. warp/tests/assets/torus.usda +105 -105
  203. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  204. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  205. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  206. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  207. warp/tests/aux_test_unresolved_func.py +14 -0
  208. warp/tests/aux_test_unresolved_symbol.py +14 -0
  209. warp/tests/disabled_kinematics.py +239 -0
  210. warp/tests/run_coverage_serial.py +31 -0
  211. warp/tests/test_adam.py +103 -106
  212. warp/tests/test_arithmetic.py +128 -74
  213. warp/tests/test_array.py +1497 -211
  214. warp/tests/test_array_reduce.py +150 -0
  215. warp/tests/test_atomic.py +64 -28
  216. warp/tests/test_bool.py +99 -0
  217. warp/tests/test_builtins_resolution.py +1292 -0
  218. warp/tests/test_bvh.py +75 -43
  219. warp/tests/test_closest_point_edge_edge.py +54 -57
  220. warp/tests/test_codegen.py +233 -128
  221. warp/tests/test_compile_consts.py +28 -20
  222. warp/tests/test_conditional.py +108 -24
  223. warp/tests/test_copy.py +10 -12
  224. warp/tests/test_ctypes.py +112 -88
  225. warp/tests/test_dense.py +21 -14
  226. warp/tests/test_devices.py +98 -0
  227. warp/tests/test_dlpack.py +136 -108
  228. warp/tests/test_examples.py +277 -0
  229. warp/tests/test_fabricarray.py +955 -0
  230. warp/tests/test_fast_math.py +15 -11
  231. warp/tests/test_fem.py +1271 -0
  232. warp/tests/test_fp16.py +53 -19
  233. warp/tests/test_func.py +187 -74
  234. warp/tests/test_generics.py +194 -49
  235. warp/tests/test_grad.py +180 -116
  236. warp/tests/test_grad_customs.py +176 -0
  237. warp/tests/test_hash_grid.py +52 -37
  238. warp/tests/test_import.py +10 -23
  239. warp/tests/test_indexedarray.py +577 -24
  240. warp/tests/test_intersect.py +18 -9
  241. warp/tests/test_large.py +141 -0
  242. warp/tests/test_launch.py +251 -15
  243. warp/tests/test_lerp.py +64 -65
  244. warp/tests/test_linear_solvers.py +154 -0
  245. warp/tests/test_lvalue.py +493 -0
  246. warp/tests/test_marching_cubes.py +12 -13
  247. warp/tests/test_mat.py +508 -2778
  248. warp/tests/test_mat_lite.py +115 -0
  249. warp/tests/test_mat_scalar_ops.py +2889 -0
  250. warp/tests/test_math.py +103 -9
  251. warp/tests/test_matmul.py +305 -69
  252. warp/tests/test_matmul_lite.py +410 -0
  253. warp/tests/test_mesh.py +71 -14
  254. warp/tests/test_mesh_query_aabb.py +41 -25
  255. warp/tests/test_mesh_query_point.py +325 -34
  256. warp/tests/test_mesh_query_ray.py +39 -22
  257. warp/tests/test_mlp.py +30 -22
  258. warp/tests/test_model.py +92 -89
  259. warp/tests/test_modules_lite.py +39 -0
  260. warp/tests/test_multigpu.py +88 -114
  261. warp/tests/test_noise.py +12 -11
  262. warp/tests/test_operators.py +16 -20
  263. warp/tests/test_options.py +11 -11
  264. warp/tests/test_pinned.py +17 -18
  265. warp/tests/test_print.py +32 -11
  266. warp/tests/test_quat.py +275 -129
  267. warp/tests/test_rand.py +18 -16
  268. warp/tests/test_reload.py +38 -34
  269. warp/tests/test_rounding.py +50 -43
  270. warp/tests/test_runlength_encode.py +190 -0
  271. warp/tests/test_smoothstep.py +9 -11
  272. warp/tests/test_snippet.py +143 -0
  273. warp/tests/test_sparse.py +460 -0
  274. warp/tests/test_spatial.py +276 -243
  275. warp/tests/test_streams.py +110 -85
  276. warp/tests/test_struct.py +331 -85
  277. warp/tests/test_tape.py +39 -21
  278. warp/tests/test_torch.py +118 -89
  279. warp/tests/test_transient_module.py +12 -13
  280. warp/tests/test_types.py +614 -0
  281. warp/tests/test_utils.py +494 -0
  282. warp/tests/test_vec.py +354 -1987
  283. warp/tests/test_vec_lite.py +73 -0
  284. warp/tests/test_vec_scalar_ops.py +2099 -0
  285. warp/tests/test_volume.py +457 -293
  286. warp/tests/test_volume_write.py +124 -134
  287. warp/tests/unittest_serial.py +35 -0
  288. warp/tests/unittest_suites.py +341 -0
  289. warp/tests/unittest_utils.py +568 -0
  290. warp/tests/unused_test_misc.py +71 -0
  291. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  292. warp/thirdparty/appdirs.py +36 -45
  293. warp/thirdparty/unittest_parallel.py +549 -0
  294. warp/torch.py +72 -30
  295. warp/types.py +1744 -713
  296. warp/utils.py +360 -350
  297. warp_lang-0.11.0.dist-info/LICENSE.md +36 -0
  298. warp_lang-0.11.0.dist-info/METADATA +238 -0
  299. warp_lang-0.11.0.dist-info/RECORD +332 -0
  300. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  301. warp/bin/warp-clang.exp +0 -0
  302. warp/bin/warp-clang.lib +0 -0
  303. warp/bin/warp.exp +0 -0
  304. warp/bin/warp.lib +0 -0
  305. warp/tests/test_all.py +0 -215
  306. warp/tests/test_array_scan.py +0 -60
  307. warp/tests/test_base.py +0 -208
  308. warp/tests/test_unresolved_func.py +0 -7
  309. warp/tests/test_unresolved_symbol.py +0 -7
  310. warp_lang-0.9.0.dist-info/METADATA +0 -20
  311. warp_lang-0.9.0.dist-info/RECORD +0 -177
  312. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  313. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  314. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  315. {warp_lang-0.9.0.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
warp/dlpack.py CHANGED
@@ -108,9 +108,6 @@ def dtype_to_dlpack(wp_dtype) -> DLDataType:
108
108
  return (DLDataTypeCode.kDLFloat, 32, 1)
109
109
  elif wp_dtype == warp.float64:
110
110
  return (DLDataTypeCode.kDLFloat, 64, 1)
111
- elif wp_dtype in warp.types.vector_types:
112
- # treat vector/matrix arrays as regular nd-arrays with one dtype lane
113
- return (DLDataTypeCode.kDLFloat, 32, 1)
114
111
  else:
115
112
  raise RuntimeError(f"No conversion from Warp type {wp_dtype} to DLPack type")
116
113
 
@@ -196,7 +193,7 @@ def to_dlpack(wp_array: warp.array):
196
193
  # DLPack does not support structured arrays
197
194
  if isinstance(wp_array.dtype, warp.codegen.Struct):
198
195
  raise RuntimeError("Cannot convert structured Warp arrays to DLPack.")
199
-
196
+
200
197
  holder = _Holder(wp_array)
201
198
 
202
199
  # allocate DLManagedTensor
@@ -204,12 +201,12 @@ def to_dlpack(wp_array: warp.array):
204
201
  dl_managed_tensor = DLManagedTensor.from_address(ctypes.pythonapi.PyMem_RawMalloc(size))
205
202
 
206
203
  # handle vector types
207
- if wp_array.dtype in warp.types.vector_types:
204
+ if hasattr(wp_array.dtype, "_wp_scalar_type_"):
208
205
  # vector type, flatten the dimensions into one tuple
209
- target_dtype = warp.float32
206
+ target_dtype = wp_array.dtype._wp_scalar_type_
210
207
  target_ndim = wp_array.ndim + len(wp_array.dtype._shape_)
211
208
  target_shape = (*wp_array.shape, *wp_array.dtype._shape_)
212
- dtype_strides = warp.types.strides_from_shape(wp_array.dtype._shape_, warp.float32)
209
+ dtype_strides = warp.types.strides_from_shape(wp_array.dtype._shape_, wp_array.dtype._wp_scalar_type_)
213
210
  target_strides = (*wp_array.strides, *dtype_strides)
214
211
  else:
215
212
  # scalar type
@@ -255,7 +252,7 @@ def dtype_is_compatible(dl_dtype, wp_dtype):
255
252
  if dl_dtype.bits == 16:
256
253
  return wp_dtype == warp.float16
257
254
  elif dl_dtype.bits == 32:
258
- return wp_dtype == warp.float32 or wp_dtype in warp.types.vector_types
255
+ return wp_dtype == warp.float32
259
256
  elif dl_dtype.bits == 64:
260
257
  return wp_dtype == warp.float64
261
258
  elif dl_dtype.type_code.value == DLDataTypeCode.kDLInt or dl_dtype.type_code.value == DLDataTypeCode.kDLUInt:
@@ -320,30 +317,33 @@ def from_dlpack(pycapsule, dtype=None) -> warp.array:
320
317
  # automatically detect dtype
321
318
  dtype = dtype_from_dlpack(dlt.dtype)
322
319
 
323
- elif dtype_is_compatible(dlt.dtype, dtype):
324
- # handle vector types
325
- if dtype in warp.types.vector_types:
326
- dtype_shape = dtype._shape_
327
- dtype_dims = len(dtype._shape_)
328
- if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:
329
- raise RuntimeError(
330
- f"Could not convert DLPack tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}"
331
- )
332
-
333
- if strides is not None:
334
- # ensure the inner strides are contiguous
335
- stride = 4
336
- for i in range(dtype_dims):
337
- if strides[-i - 1] != stride:
338
- raise RuntimeError(
339
- f"Could not convert DLPack tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous"
340
- )
341
- stride *= dtype_shape[-i - 1]
342
- strides = tuple(strides[:-dtype_dims])
343
-
344
- shape = tuple(shape[:-dtype_dims])
320
+ elif hasattr(dtype, "_wp_scalar_type_"):
321
+ # handle vector/matrix types
345
322
 
346
- else:
323
+ if not dtype_is_compatible(dlt.dtype, dtype._wp_scalar_type_):
324
+ raise RuntimeError(f"Incompatible data types: {dlt.dtype} and {dtype}")
325
+
326
+ dtype_shape = dtype._shape_
327
+ dtype_dims = len(dtype._shape_)
328
+ if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:
329
+ raise RuntimeError(
330
+ f"Could not convert DLPack tensor with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}"
331
+ )
332
+
333
+ if strides is not None:
334
+ # ensure the inner strides are contiguous
335
+ stride = itemsize
336
+ for i in range(dtype_dims):
337
+ if strides[-i - 1] != stride:
338
+ raise RuntimeError(
339
+ f"Could not convert DLPack tensor with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous"
340
+ )
341
+ stride *= dtype_shape[-i - 1]
342
+ strides = tuple(strides[:-dtype_dims]) or (itemsize,)
343
+
344
+ shape = tuple(shape[:-dtype_dims]) or (1,)
345
+
346
+ elif not dtype_is_compatible(dlt.dtype, dtype):
347
347
  # incompatible dtype requested
348
348
  raise RuntimeError(f"Incompatible data types: {dlt.dtype} and {dtype}")
349
349
 
warp/fabric.py ADDED
@@ -0,0 +1,326 @@
1
+ import ctypes
2
+ import math
3
+ from typing import Any
4
+
5
+ import warp
6
+ from warp.types import *
7
+
8
+
9
class fabricbucket_t(ctypes.Structure):
    """ctypes record describing one Fabric bucket.

    Fields:
        index_start / index_end: the global element index range covered by the
            bucket (``fabricarray`` fills them as a running sum of bucket counts,
            so the range is half-open).
        ptr: raw pointer to the bucket's attribute data.
        lengths: raw pointer to per-element array lengths; only set for arrays
            of arrays, NULL otherwise.

    NOTE(review): field order/types presumably must match the native bucket
    struct used by the Warp runtime -- confirm before changing.
    """

    _fields_ = [
        ("index_start", ctypes.c_size_t),
        ("index_end", ctypes.c_size_t),
        ("ptr", ctypes.c_void_p),
        ("lengths", ctypes.c_void_p),
    ]

    def __init__(self, index_start=0, index_end=0, ptr=None, lengths=None):
        # index range first, then the two pointers (None maps to NULL)
        self.index_start, self.index_end = index_start, index_end
        self.ptr, self.lengths = ctypes.c_void_p(ptr), ctypes.c_void_p(lengths)
22
+
23
+
24
class fabricarray_t(ctypes.Structure):
    """ctypes descriptor of a Fabric array.

    ``buckets`` points to a table of ``nbuckets`` ``fabricbucket_t`` records on
    the array's device, and ``size`` is the total number of elements summed over
    all buckets.
    """

    _fields_ = [
        ("buckets", ctypes.c_void_p),  # array of fabricbucket_t on the correct device
        ("nbuckets", ctypes.c_size_t),
        ("size", ctypes.c_size_t),
    ]

    def __init__(self, buckets=None, nbuckets=0, size=0):
        # a None bucket pointer is stored as NULL
        self.buckets, self.nbuckets, self.size = ctypes.c_void_p(buckets), nbuckets, size
35
+
36
+
37
class indexedfabricarray_t(ctypes.Structure):
    """ctypes descriptor pairing a Fabric array descriptor with an index buffer.

    ``indices`` holds the raw pointer of the index array and ``size`` its
    element count; both are NULL/0 when no index array is supplied.
    """

    _fields_ = [
        ("fa", fabricarray_t),
        ("indices", ctypes.c_void_p),
        ("size", ctypes.c_size_t),
    ]

    def __init__(self, fa=None, indices=None):
        # ``fa`` must expose ``__ctype__()`` (as ``fabricarray`` does); when
        # omitted, an empty descriptor is embedded instead
        self.fa = fabricarray_t() if fa is None else fa.__ctype__()

        if indices is not None:
            index_ptr, index_count = indices.ptr, indices.size
        else:
            index_ptr, index_count = None, 0
        self.indices = ctypes.c_void_p(index_ptr)
        self.size = index_count
56
+
57
+
58
def fabric_to_warp_dtype(type_info, attrib_name):
    """Map a Fabric attribute type description to a Warp dtype.

    Args:
        type_info: Sequence describing the attribute type. Entries read here:
            ``[0]`` truthy if the attribute may be used from Warp,
            ``[1]`` base type code (e.g. ``"f4"``, ``"i8"``, ``"b"``),
            ``[2]`` number of scalar components per element,
            ``[4]`` role string (e.g. ``"quat"``, ``"matrix"``, ``"text"``).
            Entry ``[3]`` (array depth) is not used by this function.
        attrib_name: Attribute name, used only in error messages.

    Returns:
        A Warp scalar type, or a quaternion/matrix/vector type built on it.

    Raises:
        RuntimeError: If the attribute cannot be used in Warp, its base type
            or role is unsupported, or a matrix-like role has a component
            count that is not a perfect square.
    """
    if not type_info[0]:
        raise RuntimeError(f"Attribute '{attrib_name}' cannot be used in Warp")

    base_type_dict = {
        "b": warp.bool,  # boolean
        "i1": warp.int8,
        "i2": warp.int16,
        "i4": warp.int32,
        "i8": warp.int64,
        "u1": warp.uint8,
        "u2": warp.uint16,
        "u4": warp.uint32,
        "u8": warp.uint64,
        "f2": warp.float16,
        "f4": warp.float32,
        "f8": warp.float64,
    }

    base_dtype = base_type_dict.get(type_info[1])
    if base_dtype is None:
        raise RuntimeError(f"Attribute '{attrib_name}' base data type '{type_info[1]}' is not supported in Warp")

    elem_count = type_info[2]
    role = type_info[4]

    if role in ("text", "path"):
        raise RuntimeError(f"Attribute '{attrib_name}' role '{role}' is not supported in Warp")

    if elem_count <= 1:
        # scalar type
        return base_dtype

    # vector or matrix type
    if role == "quat" and elem_count == 4:
        return quaternion(base_dtype)

    if role in ("matrix", "transform", "frame"):
        # only square matrices are currently supported; validate explicitly
        # because an assert would be stripped under `python -O`
        mat_size = int(math.sqrt(elem_count))
        if mat_size * mat_size != elem_count:
            raise RuntimeError(
                f"Attribute '{attrib_name}' role '{role}' has {elem_count} components, "
                "which is not supported in Warp (only square matrices are supported)"
            )
        return matrix((mat_size, mat_size), base_dtype)

    return vector(elem_count, base_dtype)
101
+
102
+
103
class fabricarray(noncontiguous_array_base[T]):
    """Warp array view of one attribute stored in Fabric buckets.

    The data is not contiguous: it is split across buckets, each described by
    a ``fabricbucket_t`` holding a data pointer and the global index range it
    covers. For CUDA devices the bucket table is copied to device memory and
    released again in ``__del__``.

    When ``data`` is None the instance is an empty array or a type annotation
    described only by ``dtype`` and ``ndim``.
    """

    # member attributes available during code-gen (e.g.: d = arr.shape[0])
    # (initialized when needed)
    _vars = None

    def __init__(self, data=None, attrib=None, dtype=Any, ndim=None):
        """Create a Fabric array view, or an empty array / type annotation.

        Args:
            data: Dict following the Fabric arrays interface (version 1), or an
                object exposing it as ``__fabric_arrays_interface__``. None
                creates an empty array / type annotation.
            attrib: Name of the attribute to view; required when ``data`` is given.
            dtype: Element type, used only when ``data`` is None.
            ndim: Number of dimensions, used only when ``data`` is None (default 1).

        Raises:
            ValueError: If the interface, its version, device, attributes,
                type info or array depth is invalid.
            RuntimeError: If pointers/counts/array_lengths are inconsistent or
                the attribute type is not supported in Warp.
        """
        super().__init__(ARRAY_TYPE_FABRIC)

        if data is not None:
            from .context import runtime

            # ensure the attribute name was also specified
            if not isinstance(attrib, str):
                raise ValueError(f"Invalid attribute name: {attrib}")

            # get the fabric interface dictionary
            if isinstance(data, dict):
                iface = data
            elif hasattr(data, "__fabric_arrays_interface__"):
                iface = data.__fabric_arrays_interface__
            else:
                raise ValueError(
                    "Invalid data argument for fabricarray: expected dict or object with __fabric_arrays_interface__"
                )

            version = iface.get("version")
            if version != 1:
                raise ValueError(f"Unsupported Fabric interface version: {version}")

            device = iface.get("device")
            if not isinstance(device, str):
                raise ValueError(f"Invalid Fabric interface device: {device}")

            self.device = runtime.get_device(device)

            attribs = iface.get("attribs")
            if not isinstance(attribs, dict):
                raise ValueError("Failed to get Fabric interface attributes")

            # look up attribute info by name
            attrib_info = attribs.get(attrib)
            if not isinstance(attrib_info, dict):
                raise ValueError(f"Failed to get attribute '{attrib}'")

            type_info = attrib_info["type"]
            # validate explicitly: an assert would be stripped under `python -O`
            if len(type_info) != 5:
                raise ValueError(
                    f"Invalid type info for attribute '{attrib}': expected 5 entries, got {len(type_info)}"
                )

            self.dtype = fabric_to_warp_dtype(type_info, attrib)

            self.access = attrib_info["access"]

            pointers = attrib_info["pointers"]
            counts = attrib_info["counts"]

            if not (hasattr(pointers, "__len__") and hasattr(counts, "__len__") and len(pointers) == len(counts)):
                raise RuntimeError("Attribute pointers and counts must be lists of the same size")

            # check whether it's an array (depth 1 means each element is a
            # jagged array with its own length)
            array_depth = type_info[3]
            if array_depth == 0:
                self.ndim = 1
                array_lengths = None
            elif array_depth == 1:
                self.ndim = 2
                array_lengths = attrib_info["array_lengths"]
                if not hasattr(array_lengths, "__len__") or len(array_lengths) != len(pointers):
                    raise RuntimeError(
                        "Attribute `array_lengths` must be a list of the same size as `pointers` and `counts`"
                    )
            else:
                raise ValueError(f"Invalid attribute array depth: {array_depth}")

            num_buckets = len(pointers)
            size = 0

            # each bucket covers the global index range [index_start, index_end)
            buckets = (fabricbucket_t * num_buckets)()
            for i in range(num_buckets):
                buckets[i].index_start = size
                buckets[i].index_end = size + counts[i]
                buckets[i].ptr = pointers[i]
                # explicit None check: the truth value of array-like sequences
                # (e.g. numpy arrays) is ambiguous and would raise
                if array_lengths is not None:
                    buckets[i].lengths = array_lengths[i]
                size += counts[i]

            if self.device.is_cuda:
                # copy bucket info to device
                with warp.ScopedStream(self.device.null_stream):
                    buckets_size = ctypes.sizeof(buckets)
                    buckets_ptr = self.device.allocator.alloc(buckets_size)
                    runtime.core.memcpy_h2d(self.device.context, buckets_ptr, ctypes.addressof(buckets), buckets_size)
            else:
                # the host table is used directly, so self.buckets keeps it alive
                buckets_ptr = ctypes.addressof(buckets)

            self.buckets = buckets
            self.size = size
            self.shape = (size,)

            self.ctype = fabricarray_t(buckets_ptr, num_buckets, size)

        else:
            # empty array or type annotation
            self.dtype = dtype
            self.ndim = ndim or 1
            self.device = None
            self.access = None
            self.buckets = None
            self.size = 0
            self.shape = (0,)
            self.ctype = fabricarray_t()

    def __del__(self):
        # release the GPU copy of bucket info; getattr is used because __del__
        # also runs when __init__ raised before these attributes were assigned
        buckets = getattr(self, "buckets", None)
        device = getattr(self, "device", None)
        if buckets is not None and device is not None and device.is_cuda:
            buckets_size = ctypes.sizeof(buckets)
            with device.context_guard:
                device.allocator.free(self.ctype.buckets, buckets_size)

    def __ctype__(self):
        return self.ctype

    def __len__(self):
        return self.size

    def __str__(self):
        if self.device is None:
            # type annotation
            return f"fabricarray{self.dtype}"
        else:
            return str(self.numpy())

    def __getitem__(self, key):
        """Index with a Warp index array, returning an ``indexedfabricarray`` view.

        Raises:
            ValueError: For any key that is not a Warp ``array``.
        """
        if isinstance(key, array):
            return indexedfabricarray(fa=self, indices=key)
        else:
            raise ValueError(f"Fabric arrays only support indexing using index arrays, got key of type {type(key)}")

    @property
    def vars(self):
        # member attributes available during code-gen (e.g.: d = arr.shape[0])
        # Note: we use a shared dict for all fabricarray instances
        if fabricarray._vars is None:
            fabricarray._vars = {"size": warp.codegen.Var("size", uint64)}
        return fabricarray._vars

    def fill_(self, value):
        # TODO?
        # filling Fabric arrays of arrays is not supported, because they are jagged arrays of arbitrary lengths
        if self.ndim > 1:
            raise RuntimeError("Filling Fabric arrays of arrays is not supported")

        super().fill_(value)
254
+
255
+
256
+ # special case for fabric array of arrays
257
+ # equivalent to calling fabricarray(..., ndim=2)
258
def fabricarrayarray(**kwargs):
    """Construct a Fabric array of arrays.

    Shorthand for ``fabricarray(..., ndim=2)``; any ``ndim`` supplied by the
    caller is overridden with 2.
    """
    return fabricarray(**{**kwargs, "ndim": 2})
261
+
262
+
263
class indexedfabricarray(noncontiguous_array_base[T]):
    """View of a ``fabricarray`` addressed through a Warp index array.

    Instances are produced by ``fabricarray.__getitem__`` when it is given an
    index array; the view's size equals the number of indices. When ``fa`` is
    None the instance is an empty array or type annotation described only by
    ``dtype`` and ``ndim``.
    """

    # code-gen member attributes (e.g. ``d = arr.size``); built lazily and
    # shared by every instance
    _vars = None

    def __init__(self, fa=None, indices=None, dtype=None, ndim=None):
        super().__init__(ARRAY_TYPE_FABRIC_INDEXED)

        if fa is None:
            # allow empty indexedarrays in type annotations
            self.fa = None
            self.indices = None
            self.dtype = dtype
            self.ndim = ndim or 1
            self.device = None
            self.size = 0
            self.shape = (0,)
            self.ctype = indexedfabricarray_t()
            return

        # validate the index array against the source array's device first
        check_index_array(indices, fa.device)
        self.fa = fa
        self.indices = indices
        # element type, dimensionality and device are inherited from the source
        self.dtype = fa.dtype
        self.ndim = fa.ndim
        self.device = fa.device
        self.size = indices.size
        self.shape = (indices.size,)
        self.ctype = indexedfabricarray_t(fa, indices)

    def __ctype__(self):
        return self.ctype

    def __len__(self):
        return self.size

    def __str__(self):
        if self.device is not None:
            return str(self.numpy())
        # no device means this is a type annotation
        return f"indexedfabricarray{self.dtype}"

    @property
    def vars(self):
        cached = indexedfabricarray._vars
        if cached is None:
            cached = indexedfabricarray._vars = {"size": warp.codegen.Var("size", uint64)}
        return cached

    def fill_(self, value):
        # arrays of arrays are jagged (arbitrary per-element lengths), so a
        # uniform fill is not supported for them
        if self.ndim > 1:
            raise RuntimeError("Filling indexed Fabric arrays of arrays is not supported")
        super().fill_(value)
320
+
321
+
322
+ # special case for indexed fabric array of arrays
323
+ # equivalent to calling indexedfabricarray(..., ndim=2)
324
def indexedfabricarrayarray(**kwargs):
    """Construct an indexed Fabric array of arrays.

    Shorthand for ``indexedfabricarray(..., ndim=2)``; any ``ndim`` supplied
    by the caller is overridden with 2.
    """
    return indexedfabricarray(**{**kwargs, "ndim": 2})
warp/fem/__init__.py ADDED
@@ -0,0 +1,27 @@
1
+ from .geometry import Geometry, Grid2D, Trimesh2D, Quadmesh2D, Grid3D, Tetmesh, Hexmesh
2
+ from .geometry import GeometryPartition, LinearGeometryPartition, ExplicitGeometryPartition
3
+
4
+ from .space import FunctionSpace, make_polynomial_space, ElementBasis
5
+ from .space import BasisSpace, PointBasisSpace, make_polynomial_basis_space, make_collocated_function_space
6
+ from .space import DofMapper, SkewSymmetricTensorMapper, SymmetricTensorMapper
7
+ from .space import SpaceTopology, SpacePartition, SpaceRestriction, make_space_partition, make_space_restriction
8
+
9
+ from .domain import GeometryDomain, Cells, Sides, BoundarySides, FrontierSides
10
+ from .quadrature import Quadrature, RegularQuadrature, NodalQuadrature, ExplicitQuadrature, PicQuadrature
11
+ from .polynomial import Polynomial
12
+
13
+ from .field import FieldLike, DiscreteField, make_test, make_trial, make_restriction
14
+
15
+ from .integrate import integrate, interpolate
16
+
17
+ from .operator import integrand
18
+ from .operator import position, normal, lookup, measure, measure_ratio, deformation_gradient
19
+ from .operator import inner, grad, div, outer, grad_outer, div_outer
20
+ from .operator import degree, at_node
21
+ from .operator import D, curl, jump, average, grad_jump, grad_average
22
+
23
+ from .types import Sample, Field, Domain, Coords, ElementIndex
24
+
25
+ from .dirichlet import project_linear_system, normalize_dirichlet_projector
26
+
27
+ from .cache import TemporaryStore, set_default_temporary_store, borrow_temporary, borrow_temporary_like