warp-lang 1.0.0b2__py3-none-win_amd64.whl → 1.0.0b6__py3-none-win_amd64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (271)
  1. docs/conf.py +17 -5
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/env/env_usd.py +4 -1
  6. examples/env/environment.py +8 -9
  7. examples/example_dem.py +34 -33
  8. examples/example_diffray.py +364 -337
  9. examples/example_fluid.py +32 -23
  10. examples/example_jacobian_ik.py +97 -93
  11. examples/example_marching_cubes.py +6 -16
  12. examples/example_mesh.py +6 -16
  13. examples/example_mesh_intersect.py +16 -14
  14. examples/example_nvdb.py +14 -16
  15. examples/example_raycast.py +14 -13
  16. examples/example_raymarch.py +16 -23
  17. examples/example_render_opengl.py +19 -10
  18. examples/example_sim_cartpole.py +82 -78
  19. examples/example_sim_cloth.py +45 -48
  20. examples/example_sim_fk_grad.py +51 -44
  21. examples/example_sim_fk_grad_torch.py +47 -40
  22. examples/example_sim_grad_bounce.py +108 -133
  23. examples/example_sim_grad_cloth.py +99 -113
  24. examples/example_sim_granular.py +5 -6
  25. examples/{example_sim_sdf_shape.py → example_sim_granular_collision_sdf.py} +37 -26
  26. examples/example_sim_neo_hookean.py +51 -55
  27. examples/example_sim_particle_chain.py +4 -4
  28. examples/example_sim_quadruped.py +126 -81
  29. examples/example_sim_rigid_chain.py +54 -61
  30. examples/example_sim_rigid_contact.py +66 -70
  31. examples/example_sim_rigid_fem.py +3 -3
  32. examples/example_sim_rigid_force.py +1 -1
  33. examples/example_sim_rigid_gyroscopic.py +3 -4
  34. examples/example_sim_rigid_kinematics.py +28 -39
  35. examples/example_sim_trajopt.py +112 -110
  36. examples/example_sph.py +9 -8
  37. examples/example_wave.py +7 -7
  38. examples/fem/bsr_utils.py +30 -17
  39. examples/fem/example_apic_fluid.py +85 -69
  40. examples/fem/example_convection_diffusion.py +97 -93
  41. examples/fem/example_convection_diffusion_dg.py +142 -149
  42. examples/fem/example_convection_diffusion_dg0.py +141 -136
  43. examples/fem/example_deformed_geometry.py +146 -0
  44. examples/fem/example_diffusion.py +115 -84
  45. examples/fem/example_diffusion_3d.py +116 -86
  46. examples/fem/example_diffusion_mgpu.py +102 -79
  47. examples/fem/example_mixed_elasticity.py +139 -100
  48. examples/fem/example_navier_stokes.py +175 -162
  49. examples/fem/example_stokes.py +143 -111
  50. examples/fem/example_stokes_transfer.py +186 -157
  51. examples/fem/mesh_utils.py +59 -97
  52. examples/fem/plot_utils.py +138 -17
  53. tools/ci/publishing/build_nodes_info.py +54 -0
  54. warp/__init__.py +4 -3
  55. warp/__init__.pyi +1 -0
  56. warp/bin/warp-clang.dll +0 -0
  57. warp/bin/warp.dll +0 -0
  58. warp/build.py +5 -3
  59. warp/build_dll.py +29 -9
  60. warp/builtins.py +836 -492
  61. warp/codegen.py +864 -553
  62. warp/config.py +3 -1
  63. warp/context.py +389 -172
  64. warp/fem/__init__.py +24 -6
  65. warp/fem/cache.py +318 -25
  66. warp/fem/dirichlet.py +7 -3
  67. warp/fem/domain.py +14 -0
  68. warp/fem/field/__init__.py +30 -38
  69. warp/fem/field/field.py +149 -0
  70. warp/fem/field/nodal_field.py +244 -138
  71. warp/fem/field/restriction.py +8 -6
  72. warp/fem/field/test.py +127 -59
  73. warp/fem/field/trial.py +117 -60
  74. warp/fem/geometry/__init__.py +5 -1
  75. warp/fem/geometry/deformed_geometry.py +271 -0
  76. warp/fem/geometry/element.py +24 -1
  77. warp/fem/geometry/geometry.py +86 -14
  78. warp/fem/geometry/grid_2d.py +112 -54
  79. warp/fem/geometry/grid_3d.py +134 -65
  80. warp/fem/geometry/hexmesh.py +953 -0
  81. warp/fem/geometry/partition.py +85 -33
  82. warp/fem/geometry/quadmesh_2d.py +532 -0
  83. warp/fem/geometry/tetmesh.py +451 -115
  84. warp/fem/geometry/trimesh_2d.py +197 -92
  85. warp/fem/integrate.py +534 -268
  86. warp/fem/operator.py +58 -31
  87. warp/fem/polynomial.py +11 -0
  88. warp/fem/quadrature/__init__.py +1 -1
  89. warp/fem/quadrature/pic_quadrature.py +150 -58
  90. warp/fem/quadrature/quadrature.py +209 -57
  91. warp/fem/space/__init__.py +230 -53
  92. warp/fem/space/basis_space.py +489 -0
  93. warp/fem/space/collocated_function_space.py +105 -0
  94. warp/fem/space/dof_mapper.py +49 -2
  95. warp/fem/space/function_space.py +90 -39
  96. warp/fem/space/grid_2d_function_space.py +149 -496
  97. warp/fem/space/grid_3d_function_space.py +173 -538
  98. warp/fem/space/hexmesh_function_space.py +352 -0
  99. warp/fem/space/partition.py +129 -76
  100. warp/fem/space/quadmesh_2d_function_space.py +369 -0
  101. warp/fem/space/restriction.py +46 -34
  102. warp/fem/space/shape/__init__.py +15 -0
  103. warp/fem/space/shape/cube_shape_function.py +738 -0
  104. warp/fem/space/shape/shape_function.py +103 -0
  105. warp/fem/space/shape/square_shape_function.py +611 -0
  106. warp/fem/space/shape/tet_shape_function.py +567 -0
  107. warp/fem/space/shape/triangle_shape_function.py +429 -0
  108. warp/fem/space/tetmesh_function_space.py +132 -1039
  109. warp/fem/space/topology.py +295 -0
  110. warp/fem/space/trimesh_2d_function_space.py +104 -742
  111. warp/fem/types.py +13 -11
  112. warp/fem/utils.py +335 -60
  113. warp/native/array.h +120 -34
  114. warp/native/builtin.h +101 -72
  115. warp/native/bvh.cpp +73 -325
  116. warp/native/bvh.cu +406 -23
  117. warp/native/bvh.h +22 -40
  118. warp/native/clang/clang.cpp +1 -0
  119. warp/native/crt.h +2 -0
  120. warp/native/cuda_util.cpp +8 -3
  121. warp/native/cuda_util.h +1 -0
  122. warp/native/exports.h +1522 -1243
  123. warp/native/intersect.h +19 -4
  124. warp/native/intersect_adj.h +8 -8
  125. warp/native/mat.h +76 -17
  126. warp/native/mesh.cpp +33 -108
  127. warp/native/mesh.cu +114 -18
  128. warp/native/mesh.h +395 -40
  129. warp/native/noise.h +272 -329
  130. warp/native/quat.h +51 -8
  131. warp/native/rand.h +44 -34
  132. warp/native/reduce.cpp +1 -1
  133. warp/native/sparse.cpp +4 -4
  134. warp/native/sparse.cu +163 -155
  135. warp/native/spatial.h +2 -2
  136. warp/native/temp_buffer.h +18 -14
  137. warp/native/vec.h +103 -21
  138. warp/native/warp.cpp +2 -1
  139. warp/native/warp.cu +28 -3
  140. warp/native/warp.h +4 -3
  141. warp/render/render_opengl.py +261 -109
  142. warp/sim/__init__.py +1 -2
  143. warp/sim/articulation.py +385 -185
  144. warp/sim/import_mjcf.py +59 -48
  145. warp/sim/import_urdf.py +15 -15
  146. warp/sim/import_usd.py +174 -102
  147. warp/sim/inertia.py +17 -18
  148. warp/sim/integrator_xpbd.py +4 -3
  149. warp/sim/model.py +330 -250
  150. warp/sim/render.py +1 -1
  151. warp/sparse.py +625 -152
  152. warp/stubs.py +341 -309
  153. warp/tape.py +9 -6
  154. warp/tests/__main__.py +3 -6
  155. warp/tests/assets/curlnoise_golden.npy +0 -0
  156. warp/tests/assets/pnoise_golden.npy +0 -0
  157. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  158. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  159. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  160. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  161. warp/tests/aux_test_unresolved_func.py +14 -0
  162. warp/tests/aux_test_unresolved_symbol.py +14 -0
  163. warp/tests/disabled_kinematics.py +239 -0
  164. warp/tests/run_coverage_serial.py +31 -0
  165. warp/tests/test_adam.py +103 -106
  166. warp/tests/test_arithmetic.py +94 -74
  167. warp/tests/test_array.py +82 -101
  168. warp/tests/test_array_reduce.py +57 -23
  169. warp/tests/test_atomic.py +64 -28
  170. warp/tests/test_bool.py +22 -12
  171. warp/tests/test_builtins_resolution.py +1292 -0
  172. warp/tests/test_bvh.py +18 -18
  173. warp/tests/test_closest_point_edge_edge.py +54 -57
  174. warp/tests/test_codegen.py +165 -134
  175. warp/tests/test_compile_consts.py +28 -20
  176. warp/tests/test_conditional.py +108 -24
  177. warp/tests/test_copy.py +10 -12
  178. warp/tests/test_ctypes.py +112 -88
  179. warp/tests/test_dense.py +21 -14
  180. warp/tests/test_devices.py +98 -0
  181. warp/tests/test_dlpack.py +75 -75
  182. warp/tests/test_examples.py +237 -0
  183. warp/tests/test_fabricarray.py +22 -24
  184. warp/tests/test_fast_math.py +15 -11
  185. warp/tests/test_fem.py +1034 -124
  186. warp/tests/test_fp16.py +23 -16
  187. warp/tests/test_func.py +187 -86
  188. warp/tests/test_generics.py +194 -49
  189. warp/tests/test_grad.py +123 -181
  190. warp/tests/test_grad_customs.py +176 -0
  191. warp/tests/test_hash_grid.py +35 -34
  192. warp/tests/test_import.py +10 -23
  193. warp/tests/test_indexedarray.py +24 -25
  194. warp/tests/test_intersect.py +18 -9
  195. warp/tests/test_large.py +141 -0
  196. warp/tests/test_launch.py +14 -41
  197. warp/tests/test_lerp.py +64 -65
  198. warp/tests/test_lvalue.py +493 -0
  199. warp/tests/test_marching_cubes.py +12 -13
  200. warp/tests/test_mat.py +517 -2898
  201. warp/tests/test_mat_lite.py +115 -0
  202. warp/tests/test_mat_scalar_ops.py +2889 -0
  203. warp/tests/test_math.py +103 -9
  204. warp/tests/test_matmul.py +304 -69
  205. warp/tests/test_matmul_lite.py +410 -0
  206. warp/tests/test_mesh.py +60 -22
  207. warp/tests/test_mesh_query_aabb.py +21 -25
  208. warp/tests/test_mesh_query_point.py +111 -22
  209. warp/tests/test_mesh_query_ray.py +12 -24
  210. warp/tests/test_mlp.py +30 -22
  211. warp/tests/test_model.py +92 -89
  212. warp/tests/test_modules_lite.py +39 -0
  213. warp/tests/test_multigpu.py +88 -114
  214. warp/tests/test_noise.py +12 -11
  215. warp/tests/test_operators.py +16 -20
  216. warp/tests/test_options.py +11 -11
  217. warp/tests/test_pinned.py +17 -18
  218. warp/tests/test_print.py +32 -11
  219. warp/tests/test_quat.py +275 -129
  220. warp/tests/test_rand.py +18 -16
  221. warp/tests/test_reload.py +38 -34
  222. warp/tests/test_rounding.py +50 -43
  223. warp/tests/test_runlength_encode.py +168 -20
  224. warp/tests/test_smoothstep.py +9 -11
  225. warp/tests/test_snippet.py +143 -0
  226. warp/tests/test_sparse.py +261 -63
  227. warp/tests/test_spatial.py +276 -243
  228. warp/tests/test_streams.py +110 -85
  229. warp/tests/test_struct.py +268 -63
  230. warp/tests/test_tape.py +39 -21
  231. warp/tests/test_torch.py +90 -86
  232. warp/tests/test_transient_module.py +10 -12
  233. warp/tests/test_types.py +363 -0
  234. warp/tests/test_utils.py +451 -0
  235. warp/tests/test_vec.py +354 -2050
  236. warp/tests/test_vec_lite.py +73 -0
  237. warp/tests/test_vec_scalar_ops.py +2099 -0
  238. warp/tests/test_volume.py +418 -376
  239. warp/tests/test_volume_write.py +124 -134
  240. warp/tests/unittest_serial.py +35 -0
  241. warp/tests/unittest_suites.py +291 -0
  242. warp/tests/unittest_utils.py +342 -0
  243. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  244. warp/tests/{test_debug.py → walkthough_debug.py} +3 -17
  245. warp/thirdparty/appdirs.py +36 -45
  246. warp/thirdparty/unittest_parallel.py +589 -0
  247. warp/types.py +622 -211
  248. warp/utils.py +54 -393
  249. warp_lang-1.0.0b6.dist-info/METADATA +238 -0
  250. warp_lang-1.0.0b6.dist-info/RECORD +409 -0
  251. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  252. examples/example_cache_management.py +0 -40
  253. examples/example_multigpu.py +0 -54
  254. examples/example_struct.py +0 -65
  255. examples/fem/example_stokes_transfer_3d.py +0 -210
  256. warp/bin/warp-clang.so +0 -0
  257. warp/bin/warp.so +0 -0
  258. warp/fem/field/discrete_field.py +0 -80
  259. warp/fem/space/nodal_function_space.py +0 -233
  260. warp/tests/test_all.py +0 -223
  261. warp/tests/test_array_scan.py +0 -60
  262. warp/tests/test_base.py +0 -208
  263. warp/tests/test_unresolved_func.py +0 -7
  264. warp/tests/test_unresolved_symbol.py +0 -7
  265. warp_lang-1.0.0b2.dist-info/METADATA +0 -26
  266. warp_lang-1.0.0b2.dist-info/RECORD +0 -380
  267. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  268. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  269. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  270. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  271. {warp_lang-1.0.0b2.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/fem/__init__.py CHANGED
@@ -1,9 +1,27 @@
1
+ from .geometry import Geometry, Grid2D, Trimesh2D, Quadmesh2D, Grid3D, Tetmesh, Hexmesh
2
+ from .geometry import GeometryPartition, LinearGeometryPartition, ExplicitGeometryPartition
1
3
 
4
+ from .space import FunctionSpace, make_polynomial_space, ElementBasis
5
+ from .space import BasisSpace, PointBasisSpace, make_polynomial_basis_space, make_collocated_function_space
6
+ from .space import DofMapper, SkewSymmetricTensorMapper, SymmetricTensorMapper
7
+ from .space import SpaceTopology, SpacePartition, SpaceRestriction, make_space_partition, make_space_restriction
2
8
 
3
-
9
+ from .domain import GeometryDomain, Cells, Sides, BoundarySides, FrontierSides
10
+ from .quadrature import Quadrature, RegularQuadrature, NodalQuadrature, ExplicitQuadrature, PicQuadrature
11
+ from .polynomial import Polynomial
4
12
 
5
- # Shared DOFs : vertices or edges (sides in 3D)
6
- # DG: Interior DOFs only
7
- # Unique identifier -> vertex or edge, side
8
- # Element -> DOF id
9
-
13
+ from .field import FieldLike, DiscreteField, make_test, make_trial, make_restriction
14
+
15
+ from .integrate import integrate, interpolate
16
+
17
+ from .operator import integrand
18
+ from .operator import position, normal, lookup, measure, measure_ratio, deformation_gradient
19
+ from .operator import inner, grad, div, outer, grad_outer, div_outer
20
+ from .operator import degree, at_node
21
+ from .operator import D, curl, jump, average, grad_jump, grad_average
22
+
23
+ from .types import Sample, Field, Domain, Coords, ElementIndex
24
+
25
+ from .dirichlet import project_linear_system, normalize_dirichlet_projector
26
+
27
+ from .cache import TemporaryStore, set_default_temporary_store, borrow_temporary, borrow_temporary_like
warp/fem/cache.py CHANGED
@@ -1,22 +1,26 @@
1
- from typing import Callable, Optional
1
+ from typing import Callable, Optional, Union, Tuple, Dict, Any
2
+ from copy import copy
3
+ import bisect
4
+ import re
2
5
 
3
- import warp as wp
4
6
 
5
- from warp.fem.operator import Integrand
7
+ import warp as wp
6
8
 
7
- import re
8
9
 
9
10
  _kernel_cache = dict()
10
11
  _struct_cache = dict()
11
12
  _func_cache = dict()
12
13
 
13
-
14
14
  _key_re = re.compile("[^0-9a-zA-Z_]+")
15
15
 
16
16
 
17
- def get_func(func, suffix=""):
18
- key = f"{func.__name__}_{suffix}"
19
- key = _key_re.sub("", key)
17
+ def _make_key(obj, suffix: str, use_qualified_name):
18
+ base_name = f"{obj.__module__}.{obj.__qualname__}" if use_qualified_name else obj.__name__
19
+ return _key_re.sub("", f"{base_name}_{suffix}")
20
+
21
+
22
+ def get_func(func, suffix: str, use_qualified_name: bool = False):
23
+ key = _make_key(func, suffix, use_qualified_name)
20
24
 
21
25
  if key not in _func_cache:
22
26
  _func_cache[key] = wp.Function(
@@ -31,23 +35,49 @@ def get_func(func, suffix=""):
31
35
  return _func_cache[key]
32
36
 
33
37
 
34
- def get_kernel(func, suffix=""):
35
- module = wp.get_module(func.__module__)
36
- key = func.__name__ + "_" + suffix
37
- key = _key_re.sub("", key)
38
+ def dynamic_func(suffix: str, use_qualified_name=False):
39
+ def wrap_func(func: Callable):
40
+ return get_func(func, suffix=suffix, use_qualified_name=use_qualified_name)
41
+
42
+ return wrap_func
43
+
44
+
45
+ def get_kernel(
46
+ func,
47
+ suffix: str,
48
+ use_qualified_name: bool = False,
49
+ kernel_options: Dict[str, Any] = {},
50
+ ):
51
+ key = _make_key(func, suffix, use_qualified_name)
38
52
 
39
53
  if key not in _kernel_cache:
54
+ # Avoid creating too long file names -- can lead to issues on Windows
55
+ # We could hash the key, but prefer to keep it human-readable
56
+ module_name = f"{func.__module__}.dyn.{key}"
57
+ module_name = module_name[:128] if len(module_name) > 128 else module_name
58
+ module = wp.get_module(module_name)
59
+ module.options = copy(wp.get_module(func.__module__).options)
60
+ module.options.update(kernel_options)
40
61
  _kernel_cache[key] = wp.Kernel(func=func, key=key, module=module)
41
62
  return _kernel_cache[key]
42
63
 
43
64
 
44
- def get_struct(Fields):
45
- module = wp.get_module(Fields.__module__)
46
- key = _key_re.sub("", Fields.__qualname__)
65
+ def dynamic_kernel(suffix: str, use_qualified_name=False, kernel_options: Dict[str, Any] = {}):
66
+ def wrap_kernel(func: Callable):
67
+ return get_kernel(func, suffix=suffix, use_qualified_name=use_qualified_name, kernel_options=kernel_options)
68
+
69
+ return wrap_kernel
70
+
71
+
72
+ def get_struct(struct: type, suffix: str, use_qualified_name: bool = False):
73
+ key = _make_key(struct, suffix, use_qualified_name)
74
+ # used in codegen
75
+ struct.__qualname__ = key
47
76
 
48
77
  if key not in _struct_cache:
78
+ module = wp.get_module(struct.__module__)
49
79
  _struct_cache[key] = wp.codegen.Struct(
50
- cls=Fields,
80
+ cls=struct,
51
81
  key=key,
52
82
  module=module,
53
83
  )
@@ -55,18 +85,25 @@ def get_struct(Fields):
55
85
  return _struct_cache[key]
56
86
 
57
87
 
88
+ def dynamic_struct(suffix: str, use_qualified_name=False):
89
+ def wrap_struct(struct: type):
90
+ return get_struct(struct, suffix=suffix, use_qualified_name=use_qualified_name)
91
+
92
+ return wrap_struct
93
+
94
+
58
95
  def get_integrand_function(
59
- integrand: Integrand,
96
+ integrand: "warp.fem.operator.Integrand",
60
97
  suffix: str,
98
+ func=None,
61
99
  annotations=None,
62
100
  code_transformers=[],
63
101
  ):
64
- key = integrand.name + suffix
65
- key = _key_re.sub("", key)
102
+ key = _make_key(integrand.func, suffix, use_qualified_name=True)
66
103
 
67
104
  if key not in _func_cache:
68
105
  _func_cache[key] = wp.Function(
69
- func=integrand.func,
106
+ func=integrand.func if func is None else func,
70
107
  key=key,
71
108
  namespace="",
72
109
  module=integrand.module,
@@ -78,19 +115,275 @@ def get_integrand_function(
78
115
 
79
116
 
80
117
  def get_integrand_kernel(
81
- integrand: Integrand,
118
+ integrand: "warp.fem.operator.Integrand",
82
119
  suffix: str,
83
120
  kernel_fn: Optional[Callable] = None,
121
+ kernel_options: Dict[str, Any] = {},
84
122
  code_transformers=[],
85
123
  ):
86
- module = wp.get_module(f"{integrand.module.name}.{integrand.name}")
87
- module.options = integrand.module.options
88
- key = integrand.name + "_" + suffix
89
- key = _key_re.sub("", key)
124
+ key = _make_key(integrand.func, suffix, use_qualified_name=True)
90
125
 
91
126
  if key not in _kernel_cache:
92
127
  if kernel_fn is None:
93
128
  return None
94
129
 
130
+ module = wp.get_module(f"{integrand.module.name}.{integrand.name}")
131
+ module.options = copy(integrand.module.options)
132
+ module.options.update(kernel_options)
133
+
95
134
  _kernel_cache[key] = wp.Kernel(func=kernel_fn, key=key, module=module, code_transformers=code_transformers)
96
135
  return _kernel_cache[key]
136
+
137
+
138
+ def cached_arg_value(func: Callable):
139
+ """Decorator to be applied to member methods assembling Arg structs, so that the result gets
140
+ automatically cached for the lifetime of the parent object
141
+ """
142
+
143
+ cache_attr = f"_{func.__name__}_cache"
144
+
145
+ def get_arg(obj, device):
146
+ if not hasattr(obj, cache_attr):
147
+ setattr(obj, cache_attr, {})
148
+
149
+ cache = getattr(obj, cache_attr, {})
150
+
151
+ device = wp.get_device(device)
152
+ if device.ordinal not in cache:
153
+ cache[device.ordinal] = func(obj, device)
154
+
155
+ return cache[device.ordinal]
156
+
157
+ return get_arg
158
+
159
+
160
+ _cached_vec_types = {}
161
+ _cached_mat_types = {}
162
+
163
+
164
+ def cached_vec_type(length, dtype):
165
+ key = (length, dtype)
166
+ if key not in _cached_vec_types:
167
+ _cached_vec_types[key] = wp.vec(length=length, dtype=dtype)
168
+
169
+ return _cached_vec_types[key]
170
+
171
+
172
+ def cached_mat_type(shape, dtype):
173
+ key = (*shape, dtype)
174
+ if key not in _cached_mat_types:
175
+ _cached_mat_types[key] = wp.mat(shape=shape, dtype=dtype)
176
+
177
+ return _cached_mat_types[key]
178
+
179
+
180
+ class Temporary:
181
+ """Handle over a temporary array from a :class:`TemporaryStore`.
182
+
183
+ The array will be automatically returned to the temporary pool for re-use upon destruction of this object, unless
184
+ the temporary is explicitly detached from the pool using :meth:`detach`.
185
+ The temporary may also be explicitly returned to the pool before destruction using :meth:`release`.
186
+ """
187
+
188
+ def __init__(self, array: wp.array, pool: Optional["TemporaryStore.Pool"] = None, shape=None, dtype=None):
189
+ self._raw_array = array
190
+ self._array_view = array
191
+ self._pool = pool
192
+
193
+ if shape is not None or dtype is not None:
194
+ self._view_as(shape=shape, dtype=dtype)
195
+
196
+ def detach(self) -> wp.array:
197
+ """Detaches the temporary so it is never returned to the pool"""
198
+ if self._pool is not None:
199
+ self._pool.detach(self._raw_array)
200
+
201
+ self._pool = None
202
+ return self._array_view
203
+
204
+ def release(self):
205
+ """Returns the temporary array to the pool"""
206
+ if self._pool is not None:
207
+ self._pool.redeem(self._raw_array)
208
+
209
+ self._pool = None
210
+
211
+ @property
212
+ def array(self) -> wp.array:
213
+ """View of the array with desired shape and data type."""
214
+ return self._array_view
215
+
216
+ def _view_as(self, shape, dtype) -> "Temporary":
217
+ def _view_reshaped_truncated(array):
218
+ return wp.types.array(
219
+ ptr=array.ptr,
220
+ dtype=dtype,
221
+ shape=shape,
222
+ device=array.device,
223
+ pinned=array.pinned,
224
+ capacity=array.capacity,
225
+ copy=False,
226
+ owner=False,
227
+ grad=None if array.grad is None else _view_reshaped_truncated(array.grad),
228
+ )
229
+
230
+ self._array_view = _view_reshaped_truncated(self._raw_array)
231
+ return self
232
+
233
+ def __del__(self):
234
+ self.release()
235
+
236
+
237
+ class TemporaryStore:
238
+ """
239
+ Shared pool of temporary arrays that will be persisted and reused across invocations of ``warp.fem`` functions.
240
+
241
+ A :class:`TemporaryStore` instance may either be passed explicitly to ``warp.fem`` functions that accept such an argument, for instance :func:`.integrate.integrate`,
242
+ or can be set globally as the default store using :func:`set_default_temporary_store`.
243
+
244
+ By default, there is no default temporary store, so that temporary allocations are not persisted.
245
+ """
246
+
247
+ _default_store: "TemporaryStore" = None
248
+
249
+ class Pool:
250
+ def __init__(self, dtype, device, pinned: bool):
251
+ self.dtype = dtype
252
+ self.device = device
253
+ self.pinned = pinned
254
+
255
+ self._pool = [] # Currently available arrays for borrowing, ordered by size
256
+ self._pool_sizes = [] # Sizes of available arrays for borrowing, ascending
257
+ self._allocs = {} # All allocated arrays, including borrowed ones
258
+
259
+ def borrow(self, shape, dtype, requires_grad: bool):
260
+ size = 1
261
+ if isinstance(shape, int):
262
+ shape = (shape,)
263
+ for d in shape:
264
+ size *= d
265
+
266
+ index = bisect.bisect_left(
267
+ a=self._pool_sizes,
268
+ x=size,
269
+ )
270
+ if index < len(self._pool):
271
+ # Big enough array found, remove from pool
272
+ array = self._pool.pop(index)
273
+ self._pool_sizes.pop(index)
274
+ if requires_grad and array.grad is None:
275
+ array.requires_grad = True
276
+ return Temporary(pool=self, array=array, shape=shape, dtype=dtype)
277
+
278
+ # No big enough array found, allocate new one
279
+ if len(self._pool) > 0:
280
+ grow_factor = 1.5
281
+ size = max(int(self._pool_sizes[-1] * grow_factor), size)
282
+
283
+ array = wp.empty(
284
+ shape=(size,), dtype=self.dtype, pinned=self.pinned, device=self.device, requires_grad=requires_grad
285
+ )
286
+ self._allocs[array.ptr] = array
287
+ return Temporary(pool=self, array=array, shape=shape, dtype=dtype)
288
+
289
+ def redeem(self, array):
290
+ # Insert back array into available pool
291
+ index = bisect.bisect_left(
292
+ a=self._pool_sizes,
293
+ x=array.size,
294
+ )
295
+ self._pool.insert(index, array)
296
+ self._pool_sizes.insert(index, array.size)
297
+
298
+ def detach(self, array):
299
+ del self._allocs[array.ptr]
300
+
301
+ def __init__(self):
302
+ self.clear()
303
+
304
+ def clear(self):
305
+ self._temporaries = {}
306
+
307
+ def borrow(self, shape, dtype, pinned: bool = False, device=None, requires_grad: bool = False) -> Temporary:
308
+ dtype = wp.types.type_to_warp(dtype)
309
+ device = wp.get_device(device)
310
+
311
+ type_length = wp.types.type_length(dtype)
312
+ key = (dtype._type_, type_length, pinned, device.ordinal)
313
+
314
+ pool = self._temporaries.get(key, None)
315
+ if pool is None:
316
+ value_type = (
317
+ cached_vec_type(length=type_length, dtype=wp.types.type_scalar_type(dtype))
318
+ if type_length > 1
319
+ else dtype
320
+ )
321
+ pool = TemporaryStore.Pool(value_type, device, pinned=pinned)
322
+ self._temporaries[key] = pool
323
+
324
+ return pool.borrow(dtype=dtype, shape=shape, requires_grad=requires_grad)
325
+
326
+
327
+ def set_default_temporary_store(temporary_store: Optional[TemporaryStore]):
328
+ """Globally sets the default :class:`TemporaryStore` instance to use for temporary allocations in ``warp.fem`` functions.
329
+
330
+ If the default temporary store is set to ``None``, temporary allocations are not persisted unless a :class:`TemporaryStore` is provided at a per-function granularity.
331
+ """
332
+
333
+ TemporaryStore._default_store = temporary_store
334
+
335
+
336
+ def borrow_temporary(
337
+ temporary_store: Optional[TemporaryStore],
338
+ shape: Union[int, Tuple[int]],
339
+ dtype: type,
340
+ pinned: bool = False,
341
+ requires_grad: bool = False,
342
+ device=None,
343
+ ) -> Temporary:
344
+ """
345
+ Borrows and returns a temporary array with specified attributes from a shared pool.
346
+
347
+ If an array with sufficient capacity and matching desired attributes is already available in the pool, it will be returned.
348
+ Otherwise, a new allocation wil be performed.
349
+
350
+ Args:
351
+ temporary_store: the shared pool to borrow the temporary from. If `temporary_store` is ``None``, the global default temporary store, if set, will be used.
352
+ shape: desired dimensions for the temporary array
353
+ dtype: desired data type for the temporary array
354
+ pinned: whether a pinned allocation is desired
355
+ device: device on which the momory should be allocated; if ``None``, the current device will be used.
356
+ """
357
+
358
+ if temporary_store is None:
359
+ temporary_store = TemporaryStore._default_store
360
+
361
+ if temporary_store is None:
362
+ return Temporary(
363
+ array=wp.empty(shape=shape, dtype=dtype, pinned=pinned, device=device, requires_grad=requires_grad)
364
+ )
365
+
366
+ return temporary_store.borrow(shape=shape, dtype=dtype, device=device, pinned=pinned, requires_grad=requires_grad)
367
+
368
+
369
+ def borrow_temporary_like(
370
+ array: Union[wp.array, Temporary],
371
+ temporary_store: Optional[TemporaryStore],
372
+ ) -> Temporary:
373
+ """
374
+ Borrows and returns a temporary array with the same attributes as another array or temporary.
375
+
376
+ Args:
377
+ array: Warp or temporary array to read the desired attributes from
378
+ temporary_store: the shared pool to borrow the temporary from. If `temporary_store` is ``None``, the global default temporary store, if set, will be used.
379
+ """
380
+ if isinstance(array, Temporary):
381
+ array = array.array
382
+ return borrow_temporary(
383
+ temporary_store=temporary_store,
384
+ shape=array.shape,
385
+ dtype=array.dtype,
386
+ pinned=array.pinned,
387
+ device=array.device,
388
+ requires_grad=array.requires_grad,
389
+ )
warp/fem/dirichlet.py CHANGED
@@ -62,7 +62,7 @@ def normalize_dirichlet_projector(projector_matrix: BsrMatrix, fixed_value: Opti
62
62
 
63
63
 
64
64
  def project_system_rhs(
65
- system_matrix: BsrMatrix, system_rhs: wp.array, projector_matrix: BsrMatrix, fixed_value: wp.array
65
+ system_matrix: BsrMatrix, system_rhs: wp.array, projector_matrix: BsrMatrix, fixed_value: Optional[wp.array] = None
66
66
  ):
67
67
  """Projects the right-hand-side of a linear system to enforce Dirichlet boundary conditions
68
68
 
@@ -72,7 +72,11 @@ def project_system_rhs(
72
72
  rhs_tmp = wp.empty_like(system_rhs)
73
73
  rhs_tmp.assign(system_rhs)
74
74
 
75
- bsr_mv(A=projector_matrix, x=fixed_value, y=system_rhs, alpha=1.0, beta=0.0)
75
+ if fixed_value is None:
76
+ system_rhs.zero_()
77
+ else:
78
+ bsr_mv(A=projector_matrix, x=fixed_value, y=system_rhs, alpha=1.0, beta=0.0)
79
+
76
80
  bsr_mv(A=system_matrix, x=system_rhs, y=rhs_tmp, alpha=-1.0, beta=1.0)
77
81
 
78
82
  # here rhs_tmp = system_rhs - system_matrix * projector * fixed_value
@@ -99,7 +103,7 @@ def project_linear_system(
99
103
  system_matrix: BsrMatrix,
100
104
  system_rhs: wp.array,
101
105
  projector_matrix: BsrMatrix,
102
- fixed_value: wp.array,
106
+ fixed_value: Optional[wp.array] = None,
103
107
  normalize_projector=True,
104
108
  ):
105
109
  """
warp/fem/domain.py CHANGED
@@ -47,6 +47,7 @@ class GeometryDomain:
47
47
  """Kind of elements that this domain contains (cells or sides)"""
48
48
  raise NotImplementedError
49
49
 
50
+ @property
50
51
  def dimension(self) -> int:
51
52
  """Dimension of the elements of the domain"""
52
53
  raise NotImplementedError
@@ -89,6 +90,9 @@ class GeometryDomain:
89
90
  element_position: wp.Function
90
91
  """Device function returning the element position at a sample point"""
91
92
 
93
+ element_deformation_gradient: wp.Function
94
+ """Device function returning the gradient of the position with respect to the element's reference space"""
95
+
92
96
  element_normal: wp.Function
93
97
  """Device function returning the element normal at a sample point"""
94
98
 
@@ -106,6 +110,7 @@ class Cells(GeometryDomain):
106
110
  def element_kind(self) -> GeometryDomain.ElementKind:
107
111
  return GeometryDomain.ElementKind.CELL
108
112
 
113
+ @property
109
114
  def dimension(self) -> int:
110
115
  return self.geometry.dimension
111
116
 
@@ -140,6 +145,10 @@ class Cells(GeometryDomain):
140
145
  def element_position(self) -> wp.Function:
141
146
  return self.geometry.cell_position
142
147
 
148
+ @property
149
+ def element_deformation_gradient(self) -> wp.Function:
150
+ return self.geometry.cell_deformation_gradient
151
+
143
152
  @property
144
153
  def element_measure(self) -> wp.Function:
145
154
  return self.geometry.cell_measure
@@ -168,6 +177,7 @@ class Sides(GeometryDomain):
168
177
  def element_kind(self) -> GeometryDomain.ElementKind:
169
178
  return GeometryDomain.ElementKind.SIDE
170
179
 
180
+ @property
171
181
  def dimension(self) -> int:
172
182
  return self.geometry.dimension - 1
173
183
 
@@ -202,6 +212,10 @@ class Sides(GeometryDomain):
202
212
  def element_position(self) -> wp.Function:
203
213
  return self.geometry.side_position
204
214
 
215
+ @property
216
+ def element_deformation_gradient(self) -> wp.Function:
217
+ return self.geometry.side_deformation_gradient
218
+
205
219
  @property
206
220
  def element_measure(self) -> wp.Function:
207
221
  return self.geometry.side_measure
@@ -3,15 +3,13 @@ from typing import Union, Optional
3
3
  from warp.fem.domain import GeometryDomain, Cells
4
4
  from warp.fem.space import FunctionSpace, SpaceRestriction, SpacePartition, make_space_partition, make_space_restriction
5
5
 
6
- from .discrete_field import DiscreteField
6
+ from .field import DiscreteField, SpaceField, FieldLike
7
7
  from .restriction import FieldRestriction
8
8
  from .test import TestField
9
9
  from .trial import TrialField
10
10
 
11
11
  from .nodal_field import NodalField
12
12
 
13
- FieldLike = Union[DiscreteField, FieldRestriction, TestField, TrialField]
14
-
15
13
 
16
14
  def make_restriction(
17
15
  field: DiscreteField,
@@ -33,77 +31,71 @@ def make_restriction(
33
31
  """
34
32
 
35
33
  if space_restriction is None:
36
- space_restriction = make_space_restriction(
37
- space=field.space, space_partition=field.space_partition, domain=domain, device=device
38
- )
34
+ space_restriction = make_space_restriction(space_partition=field.space_partition, domain=domain, device=device)
39
35
 
40
36
  return FieldRestriction(field=field, space_restriction=space_restriction)
41
37
 
42
38
 
43
39
  def make_test(
44
- space: Union[FunctionSpace, SpaceRestriction] = None,
45
- space_partition: SpacePartition = None,
46
- domain: GeometryDomain = None,
40
+ space: FunctionSpace,
41
+ space_restriction: Optional[SpaceRestriction] = None,
42
+ space_partition: Optional[SpacePartition] = None,
43
+ domain: Optional[GeometryDomain] = None,
47
44
  device=None,
48
45
  ) -> TestField:
49
46
  """
50
47
  Constructs a test field over a function space or its restriction
51
48
 
52
49
  Args:
53
- space: the function space or function space restriction
54
- space_partition: if ``space`` is a whole function space, the optional subset of node indices to consider
55
- domain: if ``space`` is a whole function space, the optional subset of elements to consider
50
+ space: the function space
51
+ space_restriction: restriction of the space topology to a domain
52
+ space_partition: if `space_restriction` is ``None``, the optional subset of node indices to consider
53
+ domain: if `space_restriction` is ``None``, optional subset of elements to consider
56
54
  device: Warp device on which to perform and store computations
57
55
 
58
56
  Returns:
59
57
  the test field
60
58
  """
61
59
 
62
- if not isinstance(space, SpaceRestriction):
63
- if space is None:
64
- space = space_partition.space
65
-
66
- if domain is None:
67
- domain = Cells(geometry=space.geometry)
68
-
69
- if space_partition is None:
70
- space_partition = make_space_partition(space, domain.geometry_partition)
71
-
72
- space = make_space_restriction(space=space, space_partition=space_partition, domain=domain, device=device)
60
+ if space_restriction is None:
61
+ space_restriction = make_space_restriction(
62
+ space_topology=space.topology, space_partition=space_partition, domain=domain, device=device
63
+ )
73
64
 
74
- return TestField(space)
65
+ return TestField(space_restriction=space_restriction, space=space)
75
66
 
76
67
 
77
68
  def make_trial(
78
- space: Union[FunctionSpace, SpaceRestriction] = None,
79
- space_partition: SpacePartition = None,
80
- domain: GeometryDomain = None,
69
+ space: FunctionSpace,
70
+ space_restriction: Optional[SpaceRestriction] = None,
71
+ space_partition: Optional[SpacePartition] = None,
72
+ domain: Optional[GeometryDomain] = None,
81
73
  ) -> TrialField:
82
74
  """
83
75
  Constructs a trial field over a function space or partition
84
76
 
85
77
  Args:
86
78
  space: the function space or function space restriction
87
- space_partition: if ``space`` is a whole function space, the optional subset of node indices to consider
88
- domain: if ``space`` is a whole function space, the optional subset of elements to consider
79
+ space_restriction: restriction of the space topology to a domain
80
+ space_partition: if `space_restriction` is ``None``, the optional subset of node indices to consider
81
+ domain: if `space_restriction` is ``None``, optional subset of elements to consider
89
82
  device: Warp device on which to perform and store computations
90
83
 
91
84
  Returns:
92
85
  the trial field
93
86
  """
94
87
 
95
- if isinstance(space, SpaceRestriction):
88
+ if space_restriction is not None:
96
89
  domain = space.domain
97
- space = space.space
98
90
  space_partition = space.space_partition
99
91
 
100
- if space is None:
101
- space = space_partition.space
102
-
103
- if domain is None:
104
- domain = Cells(geometry=space.geometry)
105
-
106
92
  if space_partition is None:
107
- space_partition = make_space_partition(space, domain.geometry_partition)
93
+ if domain is None:
94
+ domain = Cells(geometry=space.geometry)
95
+ space_partition = make_space_partition(
96
+ space_topology=space.topology, geometry_partition=domain.geometry_partition
97
+ )
98
+ elif domain is None:
99
+ domain = Cells(geometry=space_partition.geo_partition)
108
100
 
109
101
  return TrialField(space, space_partition, domain)