warp-lang 0.10.1__py3-none-win_amd64.whl → 0.11.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (300)
  1. warp/__init__.py +10 -4
  2. warp/__init__.pyi +1 -0
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +5 -3
  6. warp/build_dll.py +29 -9
  7. warp/builtins.py +868 -507
  8. warp/codegen.py +1074 -638
  9. warp/config.py +3 -3
  10. warp/constants.py +6 -0
  11. warp/context.py +715 -222
  12. warp/fabric.py +326 -0
  13. warp/fem/__init__.py +27 -0
  14. warp/fem/cache.py +389 -0
  15. warp/fem/dirichlet.py +181 -0
  16. warp/fem/domain.py +263 -0
  17. warp/fem/field/__init__.py +101 -0
  18. warp/fem/field/field.py +149 -0
  19. warp/fem/field/nodal_field.py +299 -0
  20. warp/fem/field/restriction.py +21 -0
  21. warp/fem/field/test.py +181 -0
  22. warp/fem/field/trial.py +183 -0
  23. warp/fem/geometry/__init__.py +19 -0
  24. warp/fem/geometry/closest_point.py +70 -0
  25. warp/fem/geometry/deformed_geometry.py +271 -0
  26. warp/fem/geometry/element.py +744 -0
  27. warp/fem/geometry/geometry.py +186 -0
  28. warp/fem/geometry/grid_2d.py +373 -0
  29. warp/fem/geometry/grid_3d.py +435 -0
  30. warp/fem/geometry/hexmesh.py +953 -0
  31. warp/fem/geometry/partition.py +376 -0
  32. warp/fem/geometry/quadmesh_2d.py +532 -0
  33. warp/fem/geometry/tetmesh.py +840 -0
  34. warp/fem/geometry/trimesh_2d.py +577 -0
  35. warp/fem/integrate.py +1616 -0
  36. warp/fem/operator.py +191 -0
  37. warp/fem/polynomial.py +213 -0
  38. warp/fem/quadrature/__init__.py +2 -0
  39. warp/fem/quadrature/pic_quadrature.py +245 -0
  40. warp/fem/quadrature/quadrature.py +294 -0
  41. warp/fem/space/__init__.py +292 -0
  42. warp/fem/space/basis_space.py +489 -0
  43. warp/fem/space/collocated_function_space.py +105 -0
  44. warp/fem/space/dof_mapper.py +236 -0
  45. warp/fem/space/function_space.py +145 -0
  46. warp/fem/space/grid_2d_function_space.py +267 -0
  47. warp/fem/space/grid_3d_function_space.py +306 -0
  48. warp/fem/space/hexmesh_function_space.py +352 -0
  49. warp/fem/space/partition.py +350 -0
  50. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  51. warp/fem/space/restriction.py +160 -0
  52. warp/fem/space/shape/__init__.py +15 -0
  53. warp/fem/space/shape/cube_shape_function.py +738 -0
  54. warp/fem/space/shape/shape_function.py +103 -0
  55. warp/fem/space/shape/square_shape_function.py +611 -0
  56. warp/fem/space/shape/tet_shape_function.py +567 -0
  57. warp/fem/space/shape/triangle_shape_function.py +429 -0
  58. warp/fem/space/tetmesh_function_space.py +292 -0
  59. warp/fem/space/topology.py +295 -0
  60. warp/fem/space/trimesh_2d_function_space.py +221 -0
  61. warp/fem/types.py +77 -0
  62. warp/fem/utils.py +495 -0
  63. warp/native/array.h +147 -44
  64. warp/native/builtin.h +122 -149
  65. warp/native/bvh.cpp +73 -325
  66. warp/native/bvh.cu +406 -23
  67. warp/native/bvh.h +34 -43
  68. warp/native/clang/clang.cpp +13 -8
  69. warp/native/crt.h +2 -0
  70. warp/native/cuda_crt.h +5 -0
  71. warp/native/cuda_util.cpp +15 -3
  72. warp/native/cuda_util.h +3 -1
  73. warp/native/cutlass/tools/library/scripts/conv2d_operation.py +463 -0
  74. warp/native/cutlass/tools/library/scripts/conv3d_operation.py +321 -0
  75. warp/native/cutlass/tools/library/scripts/gemm_operation.py +988 -0
  76. warp/native/cutlass/tools/library/scripts/generator.py +4625 -0
  77. warp/native/cutlass/tools/library/scripts/library.py +799 -0
  78. warp/native/cutlass/tools/library/scripts/manifest.py +402 -0
  79. warp/native/cutlass/tools/library/scripts/pycutlass/docs/source/conf.py +96 -0
  80. warp/native/cutlass/tools/library/scripts/pycutlass/profile/conv/conv2d_f16_sm80.py +106 -0
  81. warp/native/cutlass/tools/library/scripts/pycutlass/profile/gemm/gemm_f32_sm80.py +91 -0
  82. warp/native/cutlass/tools/library/scripts/pycutlass/setup.py +80 -0
  83. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/__init__.py +48 -0
  84. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/arguments.py +118 -0
  85. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/c_types.py +241 -0
  86. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/compiler.py +432 -0
  87. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/conv2d_operation.py +631 -0
  88. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/epilogue.py +1026 -0
  89. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/frontend.py +104 -0
  90. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/gemm_operation.py +1276 -0
  91. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/library.py +744 -0
  92. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/memory_manager.py +74 -0
  93. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/operation.py +110 -0
  94. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/parser.py +619 -0
  95. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/reduction_operation.py +398 -0
  96. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/tensor_ref.py +70 -0
  97. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/__init__.py +4 -0
  98. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/conv2d_testbed.py +646 -0
  99. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_grouped_testbed.py +235 -0
  100. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/gemm_testbed.py +557 -0
  101. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/test/profiler.py +70 -0
  102. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/type_hint.py +39 -0
  103. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/__init__.py +1 -0
  104. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/device.py +76 -0
  105. warp/native/cutlass/tools/library/scripts/pycutlass/src/pycutlass/utils/reference_model.py +255 -0
  106. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/__init__.py +0 -0
  107. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +201 -0
  108. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +177 -0
  109. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +98 -0
  110. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_dgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +95 -0
  111. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +163 -0
  112. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.py +187 -0
  113. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +309 -0
  114. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +54 -0
  115. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  116. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_fprop_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  117. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +253 -0
  118. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.py +97 -0
  119. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.py +242 -0
  120. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_f32nhwc_f32nhwc_f32nhwc_simt_f32_sm80.py +96 -0
  121. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/conv2d_wgrad_implicit_gemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_sm80.py +107 -0
  122. warp/native/cutlass/tools/library/scripts/pycutlass/test/conv/run_all_tests.py +10 -0
  123. warp/native/cutlass/tools/library/scripts/pycutlass/test/frontend/test_frontend.py +146 -0
  124. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/__init__.py +0 -0
  125. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_bf16_sm80.py +96 -0
  126. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f16_sm80.py +447 -0
  127. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f32_sm80.py +146 -0
  128. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_f64_sm80.py +102 -0
  129. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_grouped_sm80.py +203 -0
  130. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/gemm_s8_sm80.py +229 -0
  131. warp/native/cutlass/tools/library/scripts/pycutlass/test/gemm/run_all_tests.py +9 -0
  132. warp/native/cutlass/tools/library/scripts/pycutlass/test/unit/test_sm80.py +453 -0
  133. warp/native/cutlass/tools/library/scripts/rank_2k_operation.py +398 -0
  134. warp/native/cutlass/tools/library/scripts/rank_k_operation.py +387 -0
  135. warp/native/cutlass/tools/library/scripts/rt.py +796 -0
  136. warp/native/cutlass/tools/library/scripts/symm_operation.py +400 -0
  137. warp/native/cutlass/tools/library/scripts/trmm_operation.py +407 -0
  138. warp/native/cutlass_gemm.cu +5 -3
  139. warp/native/exports.h +1240 -952
  140. warp/native/fabric.h +228 -0
  141. warp/native/hashgrid.cpp +4 -4
  142. warp/native/hashgrid.h +22 -2
  143. warp/native/intersect.h +22 -7
  144. warp/native/intersect_adj.h +8 -8
  145. warp/native/intersect_tri.h +1 -1
  146. warp/native/marching.cu +157 -161
  147. warp/native/mat.h +80 -19
  148. warp/native/matnn.h +2 -2
  149. warp/native/mesh.cpp +33 -108
  150. warp/native/mesh.cu +114 -23
  151. warp/native/mesh.h +446 -46
  152. warp/native/noise.h +272 -329
  153. warp/native/quat.h +51 -8
  154. warp/native/rand.h +45 -35
  155. warp/native/range.h +6 -2
  156. warp/native/reduce.cpp +1 -1
  157. warp/native/reduce.cu +10 -12
  158. warp/native/runlength_encode.cu +6 -10
  159. warp/native/scan.cu +8 -11
  160. warp/native/sparse.cpp +4 -4
  161. warp/native/sparse.cu +164 -154
  162. warp/native/spatial.h +2 -2
  163. warp/native/temp_buffer.h +14 -30
  164. warp/native/vec.h +107 -23
  165. warp/native/volume.h +120 -0
  166. warp/native/warp.cpp +560 -30
  167. warp/native/warp.cu +431 -44
  168. warp/native/warp.h +13 -4
  169. warp/optim/__init__.py +1 -0
  170. warp/optim/linear.py +922 -0
  171. warp/optim/sgd.py +92 -0
  172. warp/render/render_opengl.py +335 -119
  173. warp/render/render_usd.py +11 -11
  174. warp/sim/__init__.py +2 -2
  175. warp/sim/articulation.py +385 -185
  176. warp/sim/collide.py +8 -0
  177. warp/sim/import_mjcf.py +297 -106
  178. warp/sim/import_urdf.py +389 -210
  179. warp/sim/import_usd.py +198 -97
  180. warp/sim/inertia.py +17 -18
  181. warp/sim/integrator_euler.py +14 -8
  182. warp/sim/integrator_xpbd.py +158 -16
  183. warp/sim/model.py +795 -291
  184. warp/sim/render.py +3 -3
  185. warp/sim/utils.py +3 -0
  186. warp/sparse.py +640 -150
  187. warp/stubs.py +606 -267
  188. warp/tape.py +61 -10
  189. warp/tests/__main__.py +3 -6
  190. warp/tests/assets/curlnoise_golden.npy +0 -0
  191. warp/tests/assets/pnoise_golden.npy +0 -0
  192. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  193. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  194. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  195. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  196. warp/tests/aux_test_unresolved_func.py +14 -0
  197. warp/tests/aux_test_unresolved_symbol.py +14 -0
  198. warp/tests/disabled_kinematics.py +239 -0
  199. warp/tests/run_coverage_serial.py +31 -0
  200. warp/tests/test_adam.py +103 -106
  201. warp/tests/test_arithmetic.py +128 -74
  202. warp/tests/test_array.py +212 -97
  203. warp/tests/test_array_reduce.py +57 -23
  204. warp/tests/test_atomic.py +64 -28
  205. warp/tests/test_bool.py +99 -0
  206. warp/tests/test_builtins_resolution.py +1292 -0
  207. warp/tests/test_bvh.py +42 -18
  208. warp/tests/test_closest_point_edge_edge.py +54 -57
  209. warp/tests/test_codegen.py +208 -130
  210. warp/tests/test_compile_consts.py +28 -20
  211. warp/tests/test_conditional.py +108 -24
  212. warp/tests/test_copy.py +10 -12
  213. warp/tests/test_ctypes.py +112 -88
  214. warp/tests/test_dense.py +21 -14
  215. warp/tests/test_devices.py +98 -0
  216. warp/tests/test_dlpack.py +75 -75
  217. warp/tests/test_examples.py +277 -0
  218. warp/tests/test_fabricarray.py +955 -0
  219. warp/tests/test_fast_math.py +15 -11
  220. warp/tests/test_fem.py +1271 -0
  221. warp/tests/test_fp16.py +53 -19
  222. warp/tests/test_func.py +187 -86
  223. warp/tests/test_generics.py +194 -49
  224. warp/tests/test_grad.py +178 -109
  225. warp/tests/test_grad_customs.py +176 -0
  226. warp/tests/test_hash_grid.py +52 -37
  227. warp/tests/test_import.py +10 -23
  228. warp/tests/test_indexedarray.py +32 -31
  229. warp/tests/test_intersect.py +18 -9
  230. warp/tests/test_large.py +141 -0
  231. warp/tests/test_launch.py +14 -41
  232. warp/tests/test_lerp.py +64 -65
  233. warp/tests/test_linear_solvers.py +154 -0
  234. warp/tests/test_lvalue.py +493 -0
  235. warp/tests/test_marching_cubes.py +12 -13
  236. warp/tests/test_mat.py +517 -2898
  237. warp/tests/test_mat_lite.py +115 -0
  238. warp/tests/test_mat_scalar_ops.py +2889 -0
  239. warp/tests/test_math.py +103 -9
  240. warp/tests/test_matmul.py +305 -69
  241. warp/tests/test_matmul_lite.py +410 -0
  242. warp/tests/test_mesh.py +71 -14
  243. warp/tests/test_mesh_query_aabb.py +41 -25
  244. warp/tests/test_mesh_query_point.py +140 -22
  245. warp/tests/test_mesh_query_ray.py +39 -22
  246. warp/tests/test_mlp.py +30 -22
  247. warp/tests/test_model.py +92 -89
  248. warp/tests/test_modules_lite.py +39 -0
  249. warp/tests/test_multigpu.py +88 -114
  250. warp/tests/test_noise.py +12 -11
  251. warp/tests/test_operators.py +16 -20
  252. warp/tests/test_options.py +11 -11
  253. warp/tests/test_pinned.py +17 -18
  254. warp/tests/test_print.py +32 -11
  255. warp/tests/test_quat.py +275 -129
  256. warp/tests/test_rand.py +18 -16
  257. warp/tests/test_reload.py +38 -34
  258. warp/tests/test_rounding.py +50 -43
  259. warp/tests/test_runlength_encode.py +168 -20
  260. warp/tests/test_smoothstep.py +9 -11
  261. warp/tests/test_snippet.py +143 -0
  262. warp/tests/test_sparse.py +261 -63
  263. warp/tests/test_spatial.py +276 -243
  264. warp/tests/test_streams.py +110 -85
  265. warp/tests/test_struct.py +268 -63
  266. warp/tests/test_tape.py +39 -21
  267. warp/tests/test_torch.py +118 -89
  268. warp/tests/test_transient_module.py +12 -13
  269. warp/tests/test_types.py +614 -0
  270. warp/tests/test_utils.py +494 -0
  271. warp/tests/test_vec.py +354 -2050
  272. warp/tests/test_vec_lite.py +73 -0
  273. warp/tests/test_vec_scalar_ops.py +2099 -0
  274. warp/tests/test_volume.py +457 -293
  275. warp/tests/test_volume_write.py +124 -134
  276. warp/tests/unittest_serial.py +35 -0
  277. warp/tests/unittest_suites.py +341 -0
  278. warp/tests/unittest_utils.py +568 -0
  279. warp/tests/unused_test_misc.py +71 -0
  280. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  281. warp/thirdparty/appdirs.py +36 -45
  282. warp/thirdparty/unittest_parallel.py +549 -0
  283. warp/torch.py +9 -6
  284. warp/types.py +1089 -366
  285. warp/utils.py +93 -387
  286. warp_lang-0.11.0.dist-info/METADATA +238 -0
  287. warp_lang-0.11.0.dist-info/RECORD +332 -0
  288. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/WHEEL +1 -1
  289. warp/tests/test_all.py +0 -219
  290. warp/tests/test_array_scan.py +0 -60
  291. warp/tests/test_base.py +0 -208
  292. warp/tests/test_unresolved_func.py +0 -7
  293. warp/tests/test_unresolved_symbol.py +0 -7
  294. warp_lang-0.10.1.dist-info/METADATA +0 -21
  295. warp_lang-0.10.1.dist-info/RECORD +0 -188
  296. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  297. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  298. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  299. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/LICENSE.md +0 -0
  300. {warp_lang-0.10.1.dist-info → warp_lang-0.11.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,435 @@
1
+ from typing import Any
2
+ import warp as wp
3
+
4
+ from warp.fem.types import ElementIndex, Coords, Sample, make_free_sample, OUTSIDE
5
+ from warp.fem.cache import cached_arg_value
6
+
7
+ from .geometry import Geometry
8
+ from .element import Square, Cube
9
+
10
+
11
@wp.struct
class Grid3DCellArg:
    """Device-side description of a Grid3D's cells, passed to the Grid3D cell device functions."""

    res: wp.vec3i  # number of cells along each axis (Grid3D.res)
    cell_size: wp.vec3  # size of a single cell along each axis (Grid3D.cell_size)
    origin: wp.vec3  # world-space position of the lower grid corner (Grid3D.bounds_lo)
16
+
17
+
18
# 3x2 float matrix type: the return type of Grid3D.side_deformation_gradient, which relates
# the two side-local coordinates to the three world-space axes
_mat32 = wp.mat(dtype=float, shape=(3, 2))
19
+
20
+
21
class Grid3D(Geometry):
    """Three-dimensional regular grid geometry

    The grid spans the axis-aligned box ``[bounds_lo, bounds_hi]`` and is split into
    ``res[0] * res[1] * res[2]`` identical cells. Cells are numbered in row-major order,
    with the last (z) coordinate varying fastest.

    Sides are handled in a side-local frame ``(altitude, longitude, latitude)``: for a side
    whose normal is world axis ``a``, the local axes are the world axes
    ``(a, (a + 1) % 3, (a + 2) % 3)`` (rows of ``LOC_TO_WORLD``).

    Side numbering: indices ``[0, 3 * cell_count)`` are the sides lying on the lower face of
    each cell, grouped by normal axis (``axis * cell_count + cell_index``). The remaining
    indices are the sides lying on the upper boundary of the grid, grouped by normal axis,
    each group starting at ``SideArg.axis_offsets[axis]``.
    """

    dimension = 3

    # 3x3 integer matrix type used to store the axis permutations below
    Permutation = wp.types.matrix(shape=(3, 3), dtype=int)
    # Row ``a``: world axis of each local axis (altitude, longitude, latitude) for sides with normal ``a``
    LOC_TO_WORLD = wp.constant(Permutation(0, 1, 2, 1, 2, 0, 2, 0, 1))
    # Row ``a``: local axis of each world axis for sides with normal ``a`` (inverse of LOC_TO_WORLD)
    WORLD_TO_LOC = wp.constant(Permutation(0, 1, 2, 2, 0, 1, 1, 2, 0))

    def __init__(self, res: wp.vec3i, bounds_lo: wp.vec3 = wp.vec3(0.0), bounds_hi: wp.vec3 = wp.vec3(1.0)):
        """Constructs a dense 3D grid

        Args:
            res: Resolution of the grid along each dimension (number of cells per axis)
            bounds_lo: Position of the lower bound of the axis-aligned grid
            bounds_hi: Position of the upper bound of the axis-aligned grid
        """

        # Store private copies of the bounds. The default values are single vec3 instances
        # created once at function definition time; keeping a reference to them (or to the
        # caller's vector) would let an in-place edit of one grid's bounds leak elsewhere.
        self.bounds_lo = wp.vec3(bounds_lo)
        self.bounds_hi = wp.vec3(bounds_hi)

        self._res = res

    @property
    def extents(self) -> wp.vec3:
        """Size of the grid's bounding box along each axis (``bounds_hi - bounds_lo``)"""
        # Avoid using native sub due to higher overhead of calling builtins from Python
        return wp.vec3(
            self.bounds_hi[0] - self.bounds_lo[0],
            self.bounds_hi[1] - self.bounds_lo[1],
            self.bounds_hi[2] - self.bounds_lo[2],
        )

    @property
    def cell_size(self) -> wp.vec3:
        """Size of a single cell along each axis (``extents / res``)"""
        ex = self.extents
        return wp.vec3(
            ex[0] / self.res[0],
            ex[1] / self.res[1],
            ex[2] / self.res[2],
        )

    def cell_count(self):
        """Total number of cells, ``res[0] * res[1] * res[2]``"""
        return self.res[0] * self.res[1] * self.res[2]

    def vertex_count(self):
        """Total number of grid vertices, ``(res[0] + 1) * (res[1] + 1) * (res[2] + 1)``"""
        return (self.res[0] + 1) * (self.res[1] + 1) * (self.res[2] + 1)

    def side_count(self):
        """Total number of sides (interior and boundary), summed over the three normal axes"""
        return (
            (self.res[0] + 1) * (self.res[1]) * (self.res[2])
            + (self.res[0]) * (self.res[1] + 1) * (self.res[2])
            + (self.res[0]) * (self.res[1]) * (self.res[2] + 1)
        )

    def edge_count(self):
        """Total number of edges, summed over the three edge directions"""
        return (
            (self.res[0] + 1) * (self.res[1] + 1) * (self.res[2])
            + (self.res[0]) * (self.res[1] + 1) * (self.res[2] + 1)
            + (self.res[0] + 1) * (self.res[1]) * (self.res[2] + 1)
        )

    def boundary_side_count(self):
        """Number of sides on the grid boundary (two opposite faces per normal axis)"""
        return 2 * (self.res[1]) * (self.res[2]) + (self.res[0]) * 2 * (self.res[2]) + (self.res[0]) * (self.res[1]) * 2

    def reference_cell(self) -> Cube:
        """Reference element of the grid cells"""
        return Cube()

    def reference_side(self) -> Square:
        """Reference element of the grid sides"""
        return Square()

    @property
    def res(self):
        """Number of cells along each axis, as passed to the constructor"""
        return self._res

    @property
    def origin(self):
        """World-space position of the lower grid corner (same object as ``bounds_lo``)"""
        return self.bounds_lo

    @property
    def strides(self):
        """Row-major strides mapping a 3D cell coordinate to a linear cell index"""
        return wp.vec3i(self.res[1] * self.res[2], self.res[2], 1)

    # Utility device functions

    CellArg = Grid3DCellArg
    Cell = wp.vec3i

    # Decomposes a linear row-major index into (x, y, z); ``strides`` holds the strides of
    # the first two coordinates, the last coordinate has stride 1
    @wp.func
    def _to_3d_index(strides: wp.vec2i, index: int):
        x = index // strides[0]
        y = (index - strides[0] * x) // strides[1]
        z = index - strides[0] * x - strides[1] * y
        return wp.vec3i(x, y, z)

    # Inverse of _to_3d_index: linear row-major index of a 3D coordinate
    @wp.func
    def _from_3d_index(strides: wp.vec2i, index: wp.vec3i):
        return strides[0] * index[0] + strides[1] * index[1] + index[2]

    # Linear index of the cell with 3D coordinates ``cell`` in a grid of resolution ``res``
    @wp.func
    def cell_index(res: wp.vec3i, cell: Cell):
        strides = wp.vec2i(res[1] * res[2], res[2])
        return Grid3D._from_3d_index(strides, cell)

    # 3D coordinates of the cell with linear index ``cell_index``
    @wp.func
    def get_cell(res: wp.vec3i, cell_index: ElementIndex):
        strides = wp.vec2i(res[1] * res[2], res[2])
        return Grid3D._to_3d_index(strides, cell_index)

    @wp.struct
    class Side:
        axis: int  # normal (world axis)
        origin: wp.vec3i  # index of vertex at corner (0,0,0), in the side-local (altitude, longitude, latitude) frame

    @wp.struct
    class SideArg:
        cell_count: int  # number of cells in the grid
        axis_offsets: wp.vec3i  # start of each normal axis' group within the upper-boundary side numbering
        cell_arg: Grid3DCellArg  # arguments of the underlying cells

    SideIndexArg = SideArg

    # Permutes a world-space vector into the side-local frame of a side with normal ``axis``
    @wp.func
    def _world_to_local(axis: int, vec: Any):
        return type(vec)(
            vec[Grid3D.LOC_TO_WORLD[axis, 0]],
            vec[Grid3D.LOC_TO_WORLD[axis, 1]],
            vec[Grid3D.LOC_TO_WORLD[axis, 2]],
        )

    # Permutes a side-local vector back to world space (inverse of _world_to_local)
    @wp.func
    def _local_to_world(axis: int, vec: Any):
        return type(vec)(
            vec[Grid3D.WORLD_TO_LOC[axis, 0]],
            vec[Grid3D.WORLD_TO_LOC[axis, 1]],
            vec[Grid3D.WORLD_TO_LOC[axis, 2]],
        )

    # Linear index of ``side`` (see the numbering scheme in the class docstring)
    @wp.func
    def side_index(arg: SideArg, side: Side):
        alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]
        if side.origin[0] == arg.cell_arg.res[alt_axis]:
            # Upper-boundary side: no cell has it as its lower face, so it is numbered
            # after the 3 * cell_count cell-based sides
            longitude = side.origin[1]
            latitude = side.origin[2]

            latitude_res = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[side.axis, 2]]
            lat_long = latitude_res * longitude + latitude

            return 3 * arg.cell_count + arg.axis_offsets[side.axis] + lat_long

        # Any other side is the lower face (along its normal) of the cell sharing its origin
        cell_index = Grid3D.cell_index(arg.cell_arg.res, Grid3D._local_to_world(side.axis, side.origin))
        return side.axis * arg.cell_count + cell_index

    # Inverse of side_index: recovers the normal axis and local origin of a side
    @wp.func
    def get_side(arg: SideArg, side_index: ElementIndex):
        if side_index < 3 * arg.cell_count:
            # Lower face of a cell
            axis = side_index // arg.cell_count
            cell_index = side_index - axis * arg.cell_count
            origin = Grid3D._world_to_local(axis, Grid3D.get_cell(arg.cell_arg.res, cell_index))
            return Grid3D.Side(axis, origin)

        # Upper-boundary side: find the axis group it belongs to
        axis_side_index = side_index - 3 * arg.cell_count
        if axis_side_index < arg.axis_offsets[1]:
            axis = 0
        elif axis_side_index < arg.axis_offsets[2]:
            axis = 1
        else:
            axis = 2

        altitude = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 0]]

        lat_long = axis_side_index - arg.axis_offsets[axis]
        latitude_res = arg.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 2]]

        longitude = lat_long // latitude_res
        latitude = lat_long - longitude * latitude_res

        origin_loc = wp.vec3i(altitude, longitude, latitude)

        return Grid3D.Side(axis, origin_loc)

    # Geometry device interface

    @cached_arg_value
    def cell_arg_value(self, device) -> CellArg:
        """Builds the device arguments for the cell device functions (memoized by ``cached_arg_value``)"""
        args = self.CellArg()
        args.res = self.res
        args.origin = self.bounds_lo
        args.cell_size = self.cell_size
        return args

    # World-space position of sample ``s``
    @wp.func
    def cell_position(args: CellArg, s: Sample):
        cell = Grid3D.get_cell(args.res, s.element_index)
        return (
            wp.vec3(
                (float(cell[0]) + s.element_coords[0]) * args.cell_size[0],
                (float(cell[1]) + s.element_coords[1]) * args.cell_size[1],
                (float(cell[2]) + s.element_coords[2]) * args.cell_size[2],
            )
            + args.origin
        )

    # Jacobian of the reference-to-world cell mapping: constant diagonal matrix of cell sizes
    @wp.func
    def cell_deformation_gradient(args: CellArg, s: Sample):
        return wp.diag(args.cell_size)

    # Inverse of cell_deformation_gradient
    @wp.func
    def cell_inverse_deformation_gradient(args: CellArg, s: Sample):
        return wp.diag(wp.cw_div(wp.vec3(1.0), args.cell_size))

    # Sample of the cell containing ``pos``; positions outside the grid are clamped to the
    # nearest cell, and points on the upper boundary belong to the last cell of each axis
    @wp.func
    def cell_lookup(args: CellArg, pos: wp.vec3):
        loc_pos = wp.cw_div(pos - args.origin, args.cell_size)
        x = wp.clamp(loc_pos[0], 0.0, float(args.res[0]))
        y = wp.clamp(loc_pos[1], 0.0, float(args.res[1]))
        z = wp.clamp(loc_pos[2], 0.0, float(args.res[2]))

        x_cell = wp.min(wp.floor(x), float(args.res[0]) - 1.0)
        y_cell = wp.min(wp.floor(y), float(args.res[1]) - 1.0)
        z_cell = wp.min(wp.floor(z), float(args.res[2]) - 1.0)

        coords = Coords(x - x_cell, y - y_cell, z - z_cell)
        cell_index = Grid3D.cell_index(args.res, Grid3D.Cell(int(x_cell), int(y_cell), int(z_cell)))

        return make_free_sample(cell_index, coords)

    # Overload taking an initial guess; the guess is unused since lookup in a regular grid is direct
    @wp.func
    def cell_lookup(args: CellArg, pos: wp.vec3, guess: Sample):
        return Grid3D.cell_lookup(args, pos)

    # Volume of a cell
    @wp.func
    def cell_measure(args: CellArg, s: Sample):
        return args.cell_size[0] * args.cell_size[1] * args.cell_size[2]

    # Cells are volumes and have no normal; returns the zero vector
    @wp.func
    def cell_normal(args: CellArg, s: Sample):
        return wp.vec3(0.0)

    # Converts a gradient w.r.t. reference cell coordinates into a world-space gradient
    @wp.func
    def cell_transform_reference_gradient(args: CellArg, cell_index: ElementIndex, coords: Coords, ref_grad: wp.vec3):
        return wp.cw_div(ref_grad, args.cell_size)

    @cached_arg_value
    def side_arg_value(self, device) -> SideArg:
        """Builds the device arguments for the side device functions (memoized by ``cached_arg_value``)"""
        args = self.SideArg()

        # Number of upper-boundary sides for each normal axis
        axis_dims = wp.vec3i(
            self.res[1] * self.res[2],
            self.res[2] * self.res[0],
            self.res[0] * self.res[1],
        )
        # Exclusive prefix sum of axis_dims: start of each axis group
        args.axis_offsets = wp.vec3i(
            0,
            axis_dims[0],
            axis_dims[0] + axis_dims[1],
        )
        args.cell_count = self.cell_count()
        args.cell_arg = self.cell_arg_value(device)
        return args

    def side_index_arg_value(self, device) -> SideIndexArg:
        """Device arguments for side indexing; identical to ``side_arg_value``"""
        return self.side_arg_value(device)

    @wp.func
    def boundary_side_index(args: SideArg, boundary_side_index: int):
        """Converts a boundary side index into a global side index"""

        # Boundary sides come in pairs: even indices lie on the lower face (altitude 0),
        # odd indices on the upper face (altitude res[axis]) of the grid
        axis_side_index = boundary_side_index // 2
        border = boundary_side_index - 2 * axis_side_index

        if axis_side_index < args.axis_offsets[1]:
            axis = 0
        elif axis_side_index < args.axis_offsets[2]:
            axis = 1
        else:
            axis = 2

        lat_long = axis_side_index - args.axis_offsets[axis]
        latitude_res = args.cell_arg.res[Grid3D.LOC_TO_WORLD[axis, 2]]

        longitude = lat_long // latitude_res
        latitude = lat_long - longitude * latitude_res

        altitude = border * args.cell_arg.res[axis]

        side = Grid3D.Side(axis, wp.vec3i(altitude, longitude, latitude))
        return Grid3D.side_index(args, side)

    # World-space position of a side sample. On lower-boundary sides (altitude 0) the first
    # side coordinate is flipped, reversing the side orientation there
    @wp.func
    def side_position(args: SideArg, s: Sample):
        side = Grid3D.get_side(args, s.element_index)

        coord0 = wp.select(side.origin[0] == 0, s.element_coords[0], 1.0 - s.element_coords[0])

        local_pos = wp.vec3(
            float(side.origin[0]),
            float(side.origin[1]) + coord0,
            float(side.origin[2]) + s.element_coords[1],
        )

        pos = args.cell_arg.origin + wp.cw_mul(Grid3D._local_to_world(side.axis, local_pos), args.cell_arg.cell_size)

        return pos

    # Derivatives of side_position w.r.t. the two side coordinates; the first one is negated
    # on lower-boundary sides, matching the coordinate flip in side_position
    @wp.func
    def side_deformation_gradient(args: SideArg, s: Sample):
        side = Grid3D.get_side(args, s.element_index)

        sign = wp.select(side.origin[0] == 0, 1.0, -1.0)

        return _mat32(
            wp.cw_mul(Grid3D._local_to_world(side.axis, wp.vec3(0.0, sign, 0.0)), args.cell_arg.cell_size),
            wp.cw_mul(Grid3D._local_to_world(side.axis, wp.vec3(0.0, 0.0, 1.0)), args.cell_arg.cell_size),
        )

    # All cells share the same inverse deformation gradient
    @wp.func
    def side_inner_inverse_deformation_gradient(args: SideArg, s: Sample):
        return Grid3D.cell_inverse_deformation_gradient(args.cell_arg, s)

    @wp.func
    def side_outer_inverse_deformation_gradient(args: SideArg, s: Sample):
        return Grid3D.cell_inverse_deformation_gradient(args.cell_arg, s)

    # Area of a side: product of the cell sizes along its two tangent axes
    @wp.func
    def side_measure(args: SideArg, s: Sample):
        side = Grid3D.get_side(args, s.element_index)
        long_axis = Grid3D.LOC_TO_WORLD[side.axis, 1]
        lat_axis = Grid3D.LOC_TO_WORLD[side.axis, 2]
        return args.cell_arg.cell_size[long_axis] * args.cell_arg.cell_size[lat_axis]

    # Side area divided by cell volume, i.e. the inverse cell size along the side normal
    @wp.func
    def side_measure_ratio(args: SideArg, s: Sample):
        side = Grid3D.get_side(args, s.element_index)
        alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]
        return 1.0 / args.cell_arg.cell_size[alt_axis]

    # Unit side normal: along -axis on lower-boundary sides (altitude 0), along +axis otherwise
    @wp.func
    def side_normal(args: SideArg, s: Sample):
        side = Grid3D.get_side(args, s.element_index)

        sign = wp.select(side.origin[0] == 0, 1.0, -1.0)

        local_n = wp.vec3(sign, 0.0, 0.0)
        return Grid3D._local_to_world(side.axis, local_n)

    # Cell below the side along its normal (altitude - 1); on the lower boundary, the cell above
    @wp.func
    def side_inner_cell_index(arg: SideArg, side_index: ElementIndex):
        side = Grid3D.get_side(arg, side_index)

        inner_alt = wp.select(side.origin[0] == 0, side.origin[0] - 1, 0)

        inner_origin = wp.vec3i(inner_alt, side.origin[1], side.origin[2])

        cell = Grid3D._local_to_world(side.axis, inner_origin)
        return Grid3D.cell_index(arg.cell_arg.res, cell)

    # Cell above the side along its normal (same altitude); on the upper boundary, the cell below
    @wp.func
    def side_outer_cell_index(arg: SideArg, side_index: ElementIndex):
        side = Grid3D.get_side(arg, side_index)

        alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]

        outer_alt = wp.select(
            side.origin[0] == arg.cell_arg.res[alt_axis], side.origin[0], arg.cell_arg.res[alt_axis] - 1
        )

        outer_origin = wp.vec3i(outer_alt, side.origin[1], side.origin[2])

        cell = Grid3D._local_to_world(side.axis, outer_origin)
        return Grid3D.cell_index(arg.cell_arg.res, cell)

    # Coordinates, in the inner cell, of a point given by side coordinates
    @wp.func
    def side_inner_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
        side = Grid3D.get_side(args, side_index)

        inner_alt = wp.select(side.origin[0] == 0, 1.0, 0.0)

        side_coord0 = wp.select(side.origin[0] == 0, side_coords[0], 1.0 - side_coords[0])

        return Grid3D._local_to_world(side.axis, wp.vec3(inner_alt, side_coord0, side_coords[1]))

    # Coordinates, in the outer cell, of a point given by side coordinates
    @wp.func
    def side_outer_cell_coords(args: SideArg, side_index: ElementIndex, side_coords: Coords):
        side = Grid3D.get_side(args, side_index)

        alt_axis = Grid3D.LOC_TO_WORLD[side.axis, 0]
        outer_alt = wp.select(side.origin[0] == args.cell_arg.res[alt_axis], 0.0, 1.0)

        side_coord0 = wp.select(side.origin[0] == 0, side_coords[0], 1.0 - side_coords[0])

        return Grid3D._local_to_world(side.axis, wp.vec3(outer_alt, side_coord0, side_coords[1]))

    # Converts coordinates in cell ``element_index`` into side coordinates. Only the altitude
    # is checked against the side plane; returns Coords(OUTSIDE) when the point is off that plane
    @wp.func
    def side_from_cell_coords(
        args: SideArg,
        side_index: ElementIndex,
        element_index: ElementIndex,
        element_coords: Coords,
    ):
        side = Grid3D.get_side(args, side_index)
        cell = Grid3D.get_cell(args.cell_arg.res, element_index)

        if float(side.origin[0] - cell[side.axis]) == element_coords[side.axis]:
            long_axis = Grid3D.LOC_TO_WORLD[side.axis, 1]
            lat_axis = Grid3D.LOC_TO_WORLD[side.axis, 2]
            long_coord = element_coords[long_axis]
            # Undo the first-coordinate flip used on lower-boundary sides
            long_coord = wp.select(side.origin[0] == 0, long_coord, 1.0 - long_coord)
            return Coords(long_coord, element_coords[lat_axis], 0.0)

        return Coords(OUTSIDE)

    # Extracts the cell arguments nested in the side arguments
    @wp.func
    def side_to_cell_arg(side_arg: SideArg):
        return side_arg.cell_arg