warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; consult the registry's advisory page for this release for more details.

Files changed (271)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/fem/operator.py CHANGED
@@ -51,31 +51,46 @@ def operator(resolver: Callable):
51
51
 
52
52
 
53
53
  @operator(resolver=lambda dmn: dmn.element_position)
54
- def position(domain: Domain, x: Sample):
55
- """Evaluates the world position of the sample point x"""
54
+ def position(domain: Domain, s: Sample):
55
+ """Evaluates the world position of the sample point `s`"""
56
56
  pass
57
57
 
58
58
 
59
59
  @operator(resolver=lambda dmn: dmn.eval_normal)
60
- def normal(domain: Domain, x: Sample):
61
- """Evaluates the element normal at the sample point x. Null for interior points."""
60
+ def normal(domain: Domain, s: Sample):
61
+ """Evaluates the element normal at the sample point `s`. Null for interior points."""
62
+ pass
63
+
64
+
65
+ @operator(resolver=lambda dmn: dmn.element_deformation_gradient)
66
+ def deformation_gradient(domain: Domain, s: Sample):
67
+ """Evaluates the gradient of the domain position with respect to the element reference space at the sample point `s`"""
62
68
  pass
63
69
 
64
70
 
65
71
  @operator(resolver=lambda dmn: dmn.element_lookup)
66
- def lookup(domain: Domain, x: Any):
67
- """Look-ups a sample point from a world position, projecting to the closest point on the domain"""
72
+ def lookup(domain: Domain, x: Any) -> Sample:
73
+ """Looks-up the sample point corresponding to a world position `x`, projecting to the closest point on the domain.
74
+
75
+ Arg:
76
+ x: world position of the point to look-up in the geometry
77
+ guess: (optional) :class:`Sample` initial guess, may help perform the query
78
+
79
+ Notes:
80
+ Currently this operator is only fully supported for :class:`Grid2D` and :class:`Grid3D` geometries.
81
+ For :class:`TriangleMesh2D` and :class:`Tetmesh` geometries, the operator requires providing `guess`.
82
+ """
68
83
  pass
69
84
 
70
85
 
71
86
  @operator(resolver=lambda dmn: dmn.element_measure)
72
- def measure(domain: Domain, sample: Sample):
73
- """Returns the measure (volume, area, or length) of an element"""
87
+ def measure(domain: Domain, s: Sample) -> float:
88
+ """Returns the measure (volume, area, or length) determinant of an element at a sample point `s`"""
74
89
  pass
75
90
 
76
91
 
77
92
  @operator(resolver=lambda dmn: dmn.element_measure_ratio)
78
- def measure_ratio(domain: Domain, sample: Sample):
93
+ def measure_ratio(domain: Domain, s: Sample) -> float:
79
94
  """Returns the maximum ratio between the measure of this element and that of higher-dimensional neighbours."""
80
95
  pass
81
96
 
@@ -85,26 +100,38 @@ def measure_ratio(domain: Domain, sample: Sample):
85
100
 
86
101
 
87
102
  @operator(resolver=lambda f: f.eval_inner)
88
- def inner(f: Field, x: Sample):
89
- """Evaluates the field at a sample point x. On oriented sides, use the inner element"""
103
+ def inner(f: Field, s: Sample):
104
+ """Evaluates the field at a sample point `s`. On oriented sides, uses the inner element"""
90
105
  pass
91
106
 
92
107
 
93
108
  @operator(resolver=lambda f: f.eval_grad_inner)
94
- def grad(f: Field, x: Sample):
95
- """Evaluates the field gradient at a sample point x. On oriented sides, use the inner element"""
109
+ def grad(f: Field, s: Sample):
110
+ """Evaluates the field gradient at a sample point `s`. On oriented sides, uses the inner element"""
111
+ pass
112
+
113
+
114
+ @operator(resolver=lambda f: f.eval_div_inner)
115
+ def div(f: Field, s: Sample):
116
+ """Evaluates the field divergence at a sample point `s`. On oriented sides, uses the inner element"""
96
117
  pass
97
118
 
98
119
 
99
120
  @operator(resolver=lambda f: f.eval_outer)
100
- def outer(f: Field, x: Sample):
101
- """Evaluates the field at a sample point x. On oriented sides, use the outer element. On interior points and on domain boundaries, this is equivalent to inner."""
121
+ def outer(f: Field, s: Sample):
122
+ """Evaluates the field at a sample point `s`. On oriented sides, uses the outer element. On interior points and on domain boundaries, this is equivalent to :func:`inner`."""
123
+ pass
124
+
125
+
126
+ @operator(resolver=lambda f: f.eval_grad_outer)
127
+ def grad_outer(f: Field, s: Sample):
128
+ """Evaluates the field gradient at a sample point `s`. On oriented sides, uses the outer element. On interior points and on domain boundaries, this is equivalent to :func:`grad`."""
102
129
  pass
103
130
 
104
131
 
105
132
  @operator(resolver=lambda f: f.eval_grad_outer)
106
- def grad_outer(f: Field, x: Sample):
107
- """Evaluates the field gradient at a sample point x. On oriented sides, use the outer element. On interior points and on domain boundaries, this is equivalent to grad."""
133
+ def div_outer(f: Field, s: Sample):
134
+ """Evaluates the field divergence at a sample point `s`. On oriented sides, uses the outer element. On interior points and on domain boundaries, this is equivalent to :func:`div`."""
108
135
  pass
109
136
 
110
137
 
@@ -124,39 +151,39 @@ def at_node(f: Field, s: Sample):
124
151
 
125
152
 
126
153
  @integrand
127
- def D(f: Field, x: Sample):
128
- """Symmetric part of the (inner) gradient of the field at x"""
129
- return utils.symmetric_part(grad(f, x))
154
+ def D(f: Field, s: Sample):
155
+ """Symmetric part of the (inner) gradient of the field at `s`"""
156
+ return utils.symmetric_part(grad(f, s))
130
157
 
131
158
 
132
159
  @integrand
133
- def div(f: Field, x: Sample):
134
- """(Inner) divergence of the field at x"""
135
- return wp.trace(grad(f, x))
160
+ def curl(f: Field, s: Sample):
161
+ """Skew part of the (inner) gradient of the field at `s`, as a vector such that ``wp.cross(curl(u), v) = skew(grad(u)) v``"""
162
+ return utils.skew_part(grad(f, s))
136
163
 
137
164
 
138
165
  @integrand
139
- def jump(f: Field, x: Sample):
166
+ def jump(f: Field, s: Sample):
140
167
  """Jump between inner and outer element values on an interior side. Zero for interior points or domain boundaries"""
141
- return inner(f, x) - outer(f, x)
168
+ return inner(f, s) - outer(f, s)
142
169
 
143
170
 
144
171
  @integrand
145
- def average(f: Field, x: Sample):
172
+ def average(f: Field, s: Sample):
146
173
  """Average between inner and outer element values"""
147
- return 0.5 * (inner(f, x) + outer(f, x))
174
+ return 0.5 * (inner(f, s) + outer(f, s))
148
175
 
149
176
 
150
177
  @integrand
151
- def grad_jump(f: Field, x: Sample):
178
+ def grad_jump(f: Field, s: Sample):
152
179
  """Jump between inner and outer element gradients on an interior side. Zero for interior points or domain boundaries"""
153
- return grad(f, x) - grad_outer(f, x)
180
+ return grad(f, s) - grad_outer(f, s)
154
181
 
155
182
 
156
183
  @integrand
157
- def grad_average(f: Field, x: Sample):
184
+ def grad_average(f: Field, s: Sample):
158
185
  """Average between inner and outer element gradients"""
159
- return 0.5 * (grad(f, x) + grad_outer(f, x))
186
+ return 0.5 * (grad(f, s) + grad_outer(f, s))
160
187
 
161
188
 
162
189
  # Set default call operators for argument types, so that field(s) = inner(field, s) and domain(s) = position(domain, s)
warp/fem/polynomial.py CHANGED
@@ -5,11 +5,22 @@ import numpy as np
5
5
 
6
6
 
7
7
  class Polynomial(Enum):
8
+ """Polynomial family defining interpolation nodes over an interval"""
9
+
8
10
  GAUSS_LEGENDRE = 0
11
+ """Gauss--Legendre 1D polynomial family (does not include endpoints)"""
12
+
9
13
  LOBATTO_GAUSS_LEGENDRE = 1
14
+ """Lobatto--Gauss--Legendre 1D polynomial family (includes endpoints)"""
15
+
10
16
  EQUISPACED_CLOSED = 2
17
+ """Closed 1D polynomial family with uniformly distributed nodes (includes endpoints)"""
18
+
11
19
  EQUISPACED_OPEN = 3
20
+ """Open 1D polynomial family with uniformly distributed nodes (does not include endpoints)"""
12
21
 
22
+ def __str__(self):
23
+ return self.name
13
24
 
14
25
  def is_closed(family: Polynomial):
15
26
  """Whether the polynomial roots include interval endpoints"""
@@ -1,2 +1,2 @@
1
- from .quadrature import Quadrature, RegularQuadrature
1
+ from .quadrature import Quadrature, RegularQuadrature, NodalQuadrature, ExplicitQuadrature
2
2
  from .pic_quadrature import PicQuadrature
@@ -1,8 +1,11 @@
1
+ from typing import Union, Tuple, Any, Optional
2
+
1
3
  import warp as wp
2
4
 
3
5
  from warp.fem.domain import GeometryDomain
4
- from warp.fem.types import ElementIndex, Coords
6
+ from warp.fem.types import ElementIndex, Coords, make_free_sample
5
7
  from warp.fem.utils import compress_node_indices
8
+ from warp.fem.cache import cached_arg_value, TemporaryStore, borrow_temporary, dynamic_kernel
6
9
 
7
10
  from .quadrature import Quadrature
8
11
 
@@ -11,23 +14,36 @@ wp.set_module_options({"enable_backward": False})
11
14
 
12
15
 
13
16
  class PicQuadrature(Quadrature):
14
- """Particle-based quadrature formula, using a global set of points irregularely spread out over geometry elements.
17
+ """Particle-based quadrature formula, using a global set of points unevenly spread out over geometry elements.
15
18
 
16
19
  Useful for Particle-In-Cell and derived methods.
20
+
21
+ Args:
22
+ domain: Undelying domain for the qaudrature
23
+ positions: Either an array containing the world positions of all particles, or a tuple of arrays containing
24
+ the cell indices and coordinates for each particle. Note that the former requires the underlying geometry to
25
+ define a global :meth:`Geometry.cell_lookup` method; currently this is only available for :class:`Grid2D` and :class:`Grid3D`.
26
+ measures: Array containing the measure (area/volume) of each particle, used to defined the integration weights.
27
+ If ``None``, defaults to the cell measure divided by the number of particles in the cell.
28
+ temporary_store: shared pool from which to allocate temporary arrays
17
29
  """
18
30
 
19
31
  def __init__(
20
32
  self,
21
33
  domain: GeometryDomain,
22
- positions: "wp.array()",
23
- measures: "wp.array(dtype=float)",
34
+ positions: Union[
35
+ "wp.array(dtype=wp.vecXd)",
36
+ Tuple[
37
+ "wp.array(dtype=ElementIndex)",
38
+ "wp.array(dtype=Coords)",
39
+ ],
40
+ ],
41
+ measures: Optional["wp.array(dtype=float)"] = None,
42
+ temporary_store: TemporaryStore = None,
24
43
  ):
25
44
  super().__init__(domain)
26
45
 
27
- self.positions = positions
28
- self.measures = measures
29
-
30
- self._bin_particles()
46
+ self._bin_particles(positions, measures, temporary_store)
31
47
 
32
48
  @property
33
49
  def name(self):
@@ -52,34 +68,39 @@ class PicQuadrature(Quadrature):
52
68
  particle_fraction: wp.array(dtype=float)
53
69
  particle_coords: wp.array(dtype=Coords)
54
70
 
71
+ @cached_arg_value
55
72
  def arg_value(self, device) -> Arg:
56
73
  arg = PicQuadrature.Arg()
57
- arg.cell_particle_offsets = self._cell_particle_offsets.to(device)
58
- arg.cell_particle_indices = self._cell_particle_indices.to(device)
74
+ arg.cell_particle_offsets = self._cell_particle_offsets.array.to(device)
75
+ arg.cell_particle_indices = self._cell_particle_indices.array.to(device)
59
76
  arg.particle_fraction = self._particle_fraction.to(device)
60
77
  arg.particle_coords = self._particle_coords.to(device)
61
78
  return arg
62
79
 
63
80
  def total_point_count(self):
64
- return self.positions.shape[0]
81
+ return self._particle_coords.shape[0]
82
+
83
+ def active_cell_count(self):
84
+ """Number of cells containing at least one particle"""
85
+ return self._cell_count
65
86
 
66
87
  @wp.func
67
- def point_count(arg: Arg, element_index: ElementIndex):
68
- return arg.cell_particle_offsets[element_index + 1] - arg.cell_particle_offsets[element_index]
88
+ def point_count(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex):
89
+ return qp_arg.cell_particle_offsets[element_index + 1] - qp_arg.cell_particle_offsets[element_index]
69
90
 
70
91
  @wp.func
71
- def point_coords(arg: Arg, element_index: ElementIndex, index: int):
72
- particle_index = arg.cell_particle_indices[arg.cell_particle_offsets[element_index] + index]
73
- return arg.particle_coords[particle_index]
92
+ def point_coords(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, index: int):
93
+ particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
94
+ return qp_arg.particle_coords[particle_index]
74
95
 
75
96
  @wp.func
76
- def point_weight(arg: Arg, element_index: ElementIndex, index: int):
77
- particle_index = arg.cell_particle_indices[arg.cell_particle_offsets[element_index] + index]
78
- return arg.particle_fraction[particle_index]
97
+ def point_weight(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, index: int):
98
+ particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
99
+ return qp_arg.particle_fraction[particle_index]
79
100
 
80
101
  @wp.func
81
- def point_index(arg: Arg, element_index: ElementIndex, index: int):
82
- particle_index = arg.cell_particle_indices[arg.cell_particle_offsets[element_index] + index]
102
+ def point_index(elt_arg: Any, qp_arg: Arg, element_index: ElementIndex, index: int):
103
+ particle_index = qp_arg.cell_particle_indices[qp_arg.cell_particle_offsets[element_index] + index]
83
104
  return particle_index
84
105
 
85
106
  def fill_element_mask(self, mask: "wp.array(dtype=int)"):
@@ -93,7 +114,7 @@ class PicQuadrature(Quadrature):
93
114
  kernel=PicQuadrature._fill_mask_kernel,
94
115
  dim=self.domain.geometry_element_count(),
95
116
  device=mask.device,
96
- inputs=[self._cell_particle_offsets, mask],
117
+ inputs=[self._cell_particle_offsets.array, mask],
97
118
  )
98
119
 
99
120
  @wp.kernel
@@ -104,50 +125,121 @@ class PicQuadrature(Quadrature):
104
125
  i = wp.tid()
105
126
  element_mask[i] = wp.select(element_particle_offsets[i] == element_particle_offsets[i + 1], 1, 0)
106
127
 
107
- def _bin_particles(self):
108
- from warp.fem import cache
128
+ @wp.kernel
129
+ def _compute_uniform_fraction(
130
+ cell_index: wp.array(dtype=ElementIndex),
131
+ cell_particle_offsets: wp.array(dtype=int),
132
+ cell_fraction: wp.array(dtype=float),
133
+ ):
134
+ p = wp.tid()
109
135
 
110
- def bin_particles_fn(
111
- cell_arg_value: self.domain.ElementArg,
112
- positions: wp.array(dtype=self.positions.dtype),
113
- measures: wp.array(dtype=float),
114
- cell_index: wp.array(dtype=ElementIndex),
115
- cell_coords: wp.array(dtype=Coords),
116
- cell_fraction: wp.array(dtype=float),
117
- ):
118
- p = wp.tid()
119
- sample = self.domain.element_lookup(cell_arg_value, positions[p])
136
+ cell = cell_index[p]
137
+ cell_particle_count = cell_particle_offsets[cell + 1] - cell_particle_offsets[cell]
120
138
 
121
- cell_index[p] = sample.element_index
139
+ cell_fraction[p] = 1.0 / float(cell_particle_count)
122
140
 
123
- cell_coords[p] = sample.element_coords
124
- cell_fraction[p] = measures[p] / self.domain.element_measure(cell_arg_value, sample)
141
+ def _bin_particles(self, positions, measures, temporary_store: TemporaryStore):
142
+ if wp.types.is_array(positions):
143
+ # Initialize from positions
144
+ @dynamic_kernel(suffix=f"{self.domain.name}")
145
+ def bin_particles(
146
+ cell_arg_value: self.domain.ElementArg,
147
+ positions: wp.array(dtype=positions.dtype),
148
+ cell_index: wp.array(dtype=ElementIndex),
149
+ cell_coords: wp.array(dtype=Coords),
150
+ ):
151
+ p = wp.tid()
152
+ sample = self.domain.element_lookup(cell_arg_value, positions[p])
125
153
 
126
- bin_particles = cache.get_kernel(
127
- bin_particles_fn,
128
- suffix=f"{self.domain.name}",
129
- )
154
+ cell_index[p] = sample.element_index
155
+ cell_coords[p] = sample.element_coords
130
156
 
131
- device = self.positions.device
157
+ device = positions.device
132
158
 
133
- cell_index = wp.empty(shape=self.positions.shape, dtype=int, device=device)
134
- self._particle_coords = wp.empty(shape=self.positions.shape, dtype=Coords, device=device)
135
- self._particle_fraction = wp.empty(shape=self.positions.shape, dtype=float, device=device)
159
+ cell_index_temp = borrow_temporary(temporary_store, shape=positions.shape, dtype=int, device=device)
160
+ cell_index = cell_index_temp.array
136
161
 
137
- wp.launch(
138
- dim=self.positions.shape[0],
139
- kernel=bin_particles,
140
- inputs=[
141
- self.domain.element_arg_value(device),
142
- self.positions,
143
- self.measures,
144
- cell_index,
145
- self._particle_coords,
146
- self._particle_fraction,
147
- ],
148
- device=device,
149
- )
162
+ self._particle_coords_temp = borrow_temporary(
163
+ temporary_store, shape=positions.shape, dtype=Coords, device=device
164
+ )
165
+ self._particle_coords = self._particle_coords_temp.array
166
+
167
+ wp.launch(
168
+ dim=positions.shape[0],
169
+ kernel=bin_particles,
170
+ inputs=[
171
+ self.domain.element_arg_value(device),
172
+ positions,
173
+ cell_index,
174
+ self._particle_coords,
175
+ ],
176
+ device=device,
177
+ )
178
+
179
+ else:
180
+ cell_index, self._particle_coords = positions
181
+ if cell_index.shape != self._particle_coords.shape:
182
+ raise ValueError("Cell index and coordinates arrays must have the same shape")
183
+
184
+ cell_index_temp = None
185
+ self._particle_coords_temp = None
150
186
 
151
187
  self._cell_particle_offsets, self._cell_particle_indices, self._cell_count, _ = compress_node_indices(
152
188
  self.domain.geometry_element_count(), cell_index
153
189
  )
190
+
191
+ self._compute_fraction(cell_index, measures, temporary_store)
192
+
193
+ def _compute_fraction(self, cell_index, measures, temporary_store: TemporaryStore):
194
+ device = cell_index.device
195
+
196
+ self._particle_fraction_temp = borrow_temporary(
197
+ temporary_store, shape=cell_index.shape, dtype=float, device=device
198
+ )
199
+ self._particle_fraction = self._particle_fraction_temp.array
200
+
201
+ if measures is None:
202
+ # Split fraction uniformly over all particles in cell
203
+
204
+ wp.launch(
205
+ dim=cell_index.shape,
206
+ kernel=PicQuadrature._compute_uniform_fraction,
207
+ inputs=[
208
+ cell_index,
209
+ self._cell_particle_offsets.array,
210
+ self._particle_fraction,
211
+ ],
212
+ device=device,
213
+ )
214
+
215
+ else:
216
+ # Fraction from particle measure
217
+
218
+ if measures.shape != cell_index.shape:
219
+ raise ValueError("Measures should be an 1d array or length equal to particle count")
220
+
221
+ @dynamic_kernel(suffix=f"{self.domain.name}")
222
+ def compute_fraction(
223
+ cell_arg_value: self.domain.ElementArg,
224
+ measures: wp.array(dtype=float),
225
+ cell_index: wp.array(dtype=ElementIndex),
226
+ cell_coords: wp.array(dtype=Coords),
227
+ cell_fraction: wp.array(dtype=float),
228
+ ):
229
+ p = wp.tid()
230
+ sample = make_free_sample(cell_index[p], cell_coords[p])
231
+
232
+ cell_fraction[p] = measures[p] / self.domain.element_measure(cell_arg_value, sample)
233
+
234
+ wp.launch(
235
+ dim=measures.shape[0],
236
+ kernel=compute_fraction,
237
+ inputs=[
238
+ self.domain.element_arg_value(device),
239
+ measures,
240
+ cell_index,
241
+ self._particle_coords,
242
+ self._particle_fraction,
243
+ ],
244
+ device=device,
245
+ )