warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.1__py3-none-macosx_10_13_universal2.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package's registry page for more details.

Files changed (192)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +130 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +272 -104
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +770 -238
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_callable.py +34 -4
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/interop/example_jax_kernel.py +27 -1
  37. warp/examples/optim/example_drone.py +1 -1
  38. warp/examples/sim/example_cloth.py +1 -1
  39. warp/examples/sim/example_cloth_self_contact.py +48 -54
  40. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  41. warp/examples/tile/example_tile_cholesky.py +2 -1
  42. warp/examples/tile/example_tile_convolution.py +1 -1
  43. warp/examples/tile/example_tile_filtering.py +1 -1
  44. warp/examples/tile/example_tile_matmul.py +1 -1
  45. warp/examples/tile/example_tile_mlp.py +2 -0
  46. warp/fabric.py +7 -7
  47. warp/fem/__init__.py +5 -0
  48. warp/fem/adaptivity.py +1 -1
  49. warp/fem/cache.py +152 -63
  50. warp/fem/dirichlet.py +2 -2
  51. warp/fem/domain.py +136 -6
  52. warp/fem/field/field.py +141 -99
  53. warp/fem/field/nodal_field.py +85 -39
  54. warp/fem/field/virtual.py +99 -52
  55. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  56. warp/fem/geometry/closest_point.py +13 -0
  57. warp/fem/geometry/deformed_geometry.py +102 -40
  58. warp/fem/geometry/element.py +56 -2
  59. warp/fem/geometry/geometry.py +323 -22
  60. warp/fem/geometry/grid_2d.py +157 -62
  61. warp/fem/geometry/grid_3d.py +116 -20
  62. warp/fem/geometry/hexmesh.py +86 -20
  63. warp/fem/geometry/nanogrid.py +166 -86
  64. warp/fem/geometry/partition.py +59 -25
  65. warp/fem/geometry/quadmesh.py +86 -135
  66. warp/fem/geometry/tetmesh.py +47 -119
  67. warp/fem/geometry/trimesh.py +77 -270
  68. warp/fem/integrate.py +181 -95
  69. warp/fem/linalg.py +25 -58
  70. warp/fem/operator.py +124 -27
  71. warp/fem/quadrature/pic_quadrature.py +36 -14
  72. warp/fem/quadrature/quadrature.py +40 -16
  73. warp/fem/space/__init__.py +1 -1
  74. warp/fem/space/basis_function_space.py +66 -46
  75. warp/fem/space/basis_space.py +17 -4
  76. warp/fem/space/dof_mapper.py +1 -1
  77. warp/fem/space/function_space.py +2 -2
  78. warp/fem/space/grid_2d_function_space.py +4 -1
  79. warp/fem/space/hexmesh_function_space.py +4 -2
  80. warp/fem/space/nanogrid_function_space.py +3 -1
  81. warp/fem/space/partition.py +11 -2
  82. warp/fem/space/quadmesh_function_space.py +4 -1
  83. warp/fem/space/restriction.py +5 -2
  84. warp/fem/space/shape/__init__.py +10 -8
  85. warp/fem/space/tetmesh_function_space.py +4 -1
  86. warp/fem/space/topology.py +52 -21
  87. warp/fem/space/trimesh_function_space.py +4 -1
  88. warp/fem/utils.py +53 -8
  89. warp/jax.py +1 -2
  90. warp/jax_experimental/ffi.py +210 -67
  91. warp/jax_experimental/xla_ffi.py +37 -24
  92. warp/math.py +171 -1
  93. warp/native/array.h +103 -4
  94. warp/native/builtin.h +182 -35
  95. warp/native/coloring.cpp +6 -2
  96. warp/native/cuda_util.cpp +1 -1
  97. warp/native/exports.h +118 -63
  98. warp/native/intersect.h +5 -5
  99. warp/native/mat.h +8 -13
  100. warp/native/mathdx.cpp +11 -5
  101. warp/native/matnn.h +1 -123
  102. warp/native/mesh.h +1 -1
  103. warp/native/quat.h +34 -6
  104. warp/native/rand.h +7 -7
  105. warp/native/sparse.cpp +121 -258
  106. warp/native/sparse.cu +181 -274
  107. warp/native/spatial.h +305 -17
  108. warp/native/svd.h +23 -8
  109. warp/native/tile.h +603 -73
  110. warp/native/tile_radix_sort.h +1112 -0
  111. warp/native/tile_reduce.h +239 -13
  112. warp/native/tile_scan.h +240 -0
  113. warp/native/tuple.h +189 -0
  114. warp/native/vec.h +10 -20
  115. warp/native/warp.cpp +36 -4
  116. warp/native/warp.cu +588 -52
  117. warp/native/warp.h +47 -74
  118. warp/optim/linear.py +5 -1
  119. warp/paddle.py +7 -8
  120. warp/py.typed +0 -0
  121. warp/render/render_opengl.py +110 -80
  122. warp/render/render_usd.py +124 -62
  123. warp/sim/__init__.py +9 -0
  124. warp/sim/collide.py +253 -80
  125. warp/sim/graph_coloring.py +8 -1
  126. warp/sim/import_mjcf.py +4 -3
  127. warp/sim/import_usd.py +11 -7
  128. warp/sim/integrator.py +5 -2
  129. warp/sim/integrator_euler.py +1 -1
  130. warp/sim/integrator_featherstone.py +1 -1
  131. warp/sim/integrator_vbd.py +761 -322
  132. warp/sim/integrator_xpbd.py +1 -1
  133. warp/sim/model.py +265 -260
  134. warp/sim/utils.py +10 -7
  135. warp/sparse.py +303 -166
  136. warp/tape.py +54 -51
  137. warp/tests/cuda/test_conditional_captures.py +1046 -0
  138. warp/tests/cuda/test_streams.py +1 -1
  139. warp/tests/geometry/test_volume.py +2 -2
  140. warp/tests/interop/test_dlpack.py +9 -9
  141. warp/tests/interop/test_jax.py +0 -1
  142. warp/tests/run_coverage_serial.py +1 -1
  143. warp/tests/sim/disabled_kinematics.py +2 -2
  144. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  145. warp/tests/sim/test_collision.py +159 -51
  146. warp/tests/sim/test_coloring.py +91 -2
  147. warp/tests/test_array.py +254 -2
  148. warp/tests/test_array_reduce.py +2 -2
  149. warp/tests/test_assert.py +53 -0
  150. warp/tests/test_atomic_cas.py +312 -0
  151. warp/tests/test_codegen.py +142 -19
  152. warp/tests/test_conditional.py +47 -1
  153. warp/tests/test_ctypes.py +0 -20
  154. warp/tests/test_devices.py +8 -0
  155. warp/tests/test_fabricarray.py +4 -2
  156. warp/tests/test_fem.py +58 -25
  157. warp/tests/test_func.py +42 -1
  158. warp/tests/test_grad.py +1 -1
  159. warp/tests/test_lerp.py +1 -3
  160. warp/tests/test_map.py +481 -0
  161. warp/tests/test_mat.py +23 -24
  162. warp/tests/test_quat.py +28 -15
  163. warp/tests/test_rounding.py +10 -38
  164. warp/tests/test_runlength_encode.py +7 -7
  165. warp/tests/test_smoothstep.py +1 -1
  166. warp/tests/test_sparse.py +83 -2
  167. warp/tests/test_spatial.py +507 -1
  168. warp/tests/test_static.py +48 -0
  169. warp/tests/test_struct.py +2 -2
  170. warp/tests/test_tape.py +38 -0
  171. warp/tests/test_tuple.py +265 -0
  172. warp/tests/test_types.py +2 -2
  173. warp/tests/test_utils.py +24 -18
  174. warp/tests/test_vec.py +38 -408
  175. warp/tests/test_vec_constructors.py +325 -0
  176. warp/tests/tile/test_tile.py +438 -131
  177. warp/tests/tile/test_tile_mathdx.py +518 -14
  178. warp/tests/tile/test_tile_matmul.py +179 -0
  179. warp/tests/tile/test_tile_reduce.py +307 -5
  180. warp/tests/tile/test_tile_shared_memory.py +136 -7
  181. warp/tests/tile/test_tile_sort.py +121 -0
  182. warp/tests/unittest_suites.py +14 -6
  183. warp/types.py +462 -308
  184. warp/utils.py +647 -86
  185. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +189 -175
  187. warp/stubs.py +0 -3381
  188. warp/tests/sim/test_xpbd.py +0 -399
  189. warp/tests/test_mlp.py +0 -282
  190. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/fem/cache.py CHANGED
@@ -15,13 +15,15 @@
15
15
 
16
16
  import ast
17
17
  import bisect
18
+ import hashlib
18
19
  import re
19
20
  import weakref
20
- from copy import copy
21
21
  from typing import Any, Callable, Dict, Optional, Tuple, Union
22
22
 
23
23
  import warp as wp
24
+ from warp.codegen import get_annotations
24
25
  from warp.fem.operator import Integrand
26
+ from warp.fem.types import Domain, Field
25
27
 
26
28
  _kernel_cache = {}
27
29
  _struct_cache = {}
@@ -30,31 +32,88 @@ _func_cache = {}
30
32
  _key_re = re.compile("[^0-9a-zA-Z_]+")
31
33
 
32
34
 
33
- def _make_key(obj, suffix: str, use_qualified_name):
34
- base_name = f"{obj.__module__}.{obj.__qualname__}" if use_qualified_name else obj.__name__
35
- return _key_re.sub("", f"{base_name}_{suffix}")
35
+ def _make_key(obj, suffix: str, options: Optional[Dict[str, Any]] = None):
36
+ # human-readable part
37
+ key = _key_re.sub("", f"{obj.__name__}_{suffix}")
36
38
 
39
+ sorted_opts = sorted(options.items()) if options is not None else ()
40
+ opts_str = "".join(
41
+ (
42
+ obj.__module__,
43
+ obj.__qualname__,
44
+ suffix,
45
+ *(opt[0] for opt in sorted_opts),
46
+ *(str(opt[1]) for opt in sorted_opts),
47
+ )
48
+ )
49
+ uid = hashlib.blake2b(bytes(opts_str, encoding="utf-8"), digest_size=4).hexdigest()
37
50
 
38
- def get_func(func, suffix: str, use_qualified_name: bool = False, code_transformers=None):
39
- key = _make_key(func, suffix, use_qualified_name)
51
+ # avoid long keys, issues on win
52
+ key = f"{key[:64]}_{uid}"
40
53
 
41
- if key not in _func_cache:
42
- _func_cache[key] = wp.Function(
43
- func=func,
44
- key=key,
45
- namespace="",
46
- module=wp.get_module(
47
- func.__module__,
48
- ),
54
+ return key
55
+
56
+
57
+ def _arg_type_name(arg_type):
58
+ if isinstance(arg_type, str):
59
+ return arg_type
60
+ if arg_type in (Field, Domain):
61
+ return ""
62
+ return wp.types.get_type_code(wp.types.type_to_warp(arg_type))
63
+
64
+
65
+ def _make_cache_key(func, key, argspec=None):
66
+ if argspec is None:
67
+ annotations = get_annotations(func)
68
+ else:
69
+ annotations = argspec.annotations
70
+
71
+ sig_key = (key, tuple((k, _arg_type_name(v)) for k, v in annotations.items()))
72
+ return sig_key
73
+
74
+
75
+ def _register_function(
76
+ func,
77
+ key,
78
+ module,
79
+ **kwargs,
80
+ ):
81
+ # wp.Function will override existing func for a given key...
82
+ # manually add back our overloads
83
+ existing = module.functions.get(key)
84
+ new_fn = wp.Function(
85
+ func=func,
86
+ key=key,
87
+ namespace="",
88
+ module=module,
89
+ **kwargs,
90
+ )
91
+
92
+ if existing:
93
+ existing.add_overload(new_fn)
94
+ module.functions[key] = existing
95
+ return module.functions[key]
96
+
97
+
98
+ def get_func(func, suffix: str, code_transformers=None):
99
+ key = _make_key(func, suffix)
100
+ cache_key = _make_cache_key(func, key)
101
+
102
+ if cache_key not in _func_cache:
103
+ module = wp.get_module(func.__module__)
104
+ _func_cache[cache_key] = _register_function(
105
+ func,
106
+ key,
107
+ module,
49
108
  code_transformers=code_transformers,
50
109
  )
51
110
 
52
- return _func_cache[key]
111
+ return _func_cache[cache_key]
53
112
 
54
113
 
55
- def dynamic_func(suffix: str, use_qualified_name=False, code_transformers=None):
114
+ def dynamic_func(suffix: str, code_transformers=None):
56
115
  def wrap_func(func: Callable):
57
- return get_func(func, suffix=suffix, use_qualified_name=use_qualified_name, code_transformers=code_transformers)
116
+ return get_func(func, suffix=suffix, code_transformers=code_transformers)
58
117
 
59
118
  return wrap_func
60
119
 
@@ -62,38 +121,35 @@ def dynamic_func(suffix: str, use_qualified_name=False, code_transformers=None):
62
121
  def get_kernel(
63
122
  func,
64
123
  suffix: str,
65
- use_qualified_name: bool = False,
66
- kernel_options: Dict[str, Any] = None,
124
+ kernel_options: Optional[Dict[str, Any]] = None,
67
125
  ):
68
126
  if kernel_options is None:
69
127
  kernel_options = {}
70
128
 
71
- key = _make_key(func, suffix, use_qualified_name)
129
+ key = _make_key(func, suffix, kernel_options)
130
+ cache_key = _make_cache_key(func, key)
72
131
 
73
- if key not in _kernel_cache:
74
- # Avoid creating too long file names -- can lead to issues on Windows
75
- # We could hash the key, but prefer to keep it human-readable
132
+ if cache_key not in _kernel_cache:
76
133
  module_name = f"{func.__module__}.dyn.{key}"
77
- module_name = module_name[:128] if len(module_name) > 128 else module_name
78
134
  module = wp.get_module(module_name)
79
- module.options = copy(wp.get_module(func.__module__).options)
135
+ module.options = dict(wp.get_module(func.__module__).options)
80
136
  module.options.update(kernel_options)
81
- _kernel_cache[key] = wp.Kernel(func=func, key=key, module=module)
82
- return _kernel_cache[key]
137
+ _kernel_cache[cache_key] = wp.Kernel(func=func, key=key, module=module, options=kernel_options)
138
+ return _kernel_cache[cache_key]
83
139
 
84
140
 
85
- def dynamic_kernel(suffix: str, use_qualified_name=False, kernel_options: Dict[str, Any] = None):
141
+ def dynamic_kernel(suffix: str, kernel_options: Optional[Dict[str, Any]] = None):
86
142
  if kernel_options is None:
87
143
  kernel_options = {}
88
144
 
89
145
  def wrap_kernel(func: Callable):
90
- return get_kernel(func, suffix=suffix, use_qualified_name=use_qualified_name, kernel_options=kernel_options)
146
+ return get_kernel(func, suffix=suffix, kernel_options=kernel_options)
91
147
 
92
148
  return wrap_kernel
93
149
 
94
150
 
95
- def get_struct(struct: type, suffix: str, use_qualified_name: bool = False):
96
- key = _make_key(struct, suffix, use_qualified_name)
151
+ def get_struct(struct: type, suffix: str):
152
+ key = _make_key(struct, suffix)
97
153
  # used in codegen
98
154
  struct.__qualname__ = key
99
155
 
@@ -108,9 +164,9 @@ def get_struct(struct: type, suffix: str, use_qualified_name: bool = False):
108
164
  return _struct_cache[key]
109
165
 
110
166
 
111
- def dynamic_struct(suffix: str, use_qualified_name=False):
167
+ def dynamic_struct(suffix: str):
112
168
  def wrap_struct(struct: type):
113
- return get_struct(struct, suffix=suffix, use_qualified_name=use_qualified_name)
169
+ return get_struct(struct, suffix=suffix)
114
170
 
115
171
  return wrap_struct
116
172
 
@@ -125,35 +181,36 @@ def get_argument_struct(arg_types: Dict[str, type]):
125
181
  setattr(Args, name, None)
126
182
  annotations[name] = arg_type
127
183
 
128
- def arg_type_name(arg_type):
129
- return wp.types.get_type_code(wp.types.type_to_warp(arg_type))
130
-
131
184
  try:
132
185
  Args.__annotations__ = annotations
133
186
  except AttributeError:
134
187
  Args.__dict__.__annotations__ = annotations
135
188
 
136
- suffix = "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])
189
+ suffix = "_".join([f"{name}_{_arg_type_name(arg_type)}" for name, arg_type in annotations.items()])
137
190
 
138
191
  return get_struct(Args, suffix=suffix)
139
192
 
140
193
 
141
- def populate_argument_struct(Args: wp.codegen.Struct, values: Dict[str, Any], func_name: str):
194
+ def populate_argument_struct(
195
+ Args: wp.codegen.Struct, values: Dict[str, Any], func_name: str, value_struct_values: Optional = None
196
+ ):
142
197
  if values is None:
143
198
  values = {}
144
199
 
145
- value_struct_values = Args()
146
- for k, v in values.items():
147
- try:
200
+ if value_struct_values is None:
201
+ value_struct_values = Args()
202
+
203
+ try:
204
+ for k, v in values.items():
148
205
  setattr(value_struct_values, k, v)
149
- except Exception as err:
150
- if k not in Args.vars:
151
- raise ValueError(
152
- f"Passed value argument '{k}' does not match any of the function '{func_name}' parameters"
153
- ) from err
206
+ except Exception as err:
207
+ if k not in Args.vars:
154
208
  raise ValueError(
155
- f"Passed value argument '{k}' of type '{wp.types.type_repr(v)}' is incompatible with the function '{func_name}' parameter of type '{wp.types.type_repr(Args.vars[k].type)}'"
209
+ f"Passed value argument '{k}' does not match any of the function '{func_name}' parameters"
156
210
  ) from err
211
+ raise ValueError(
212
+ f"Passed value argument '{k}' of type '{wp.types.type_repr(v)}' is incompatible with the function '{func_name}' parameter of type '{wp.types.type_repr(Args.vars[k].type)}'"
213
+ ) from err
157
214
 
158
215
  missing_values = Args.vars.keys() - values.keys()
159
216
  if missing_values:
@@ -208,26 +265,26 @@ def get_integrand_function(
208
265
  annotations=None,
209
266
  code_transformers=None,
210
267
  ):
211
- key = _make_key(integrand.func, suffix, use_qualified_name=True)
268
+ key = _make_key(integrand.func, suffix)
269
+ cache_key = _make_cache_key(integrand.func, key, integrand.argspec)
212
270
 
213
- if key not in _func_cache:
214
- _func_cache[key] = wp.Function(
271
+ if cache_key not in _func_cache:
272
+ _func_cache[cache_key] = _register_function(
215
273
  func=integrand.func if func is None else func,
216
274
  key=key,
217
- namespace="",
218
275
  module=integrand.module,
219
276
  overloaded_annotations=annotations,
220
277
  code_transformers=code_transformers,
221
278
  )
222
279
 
223
- return _func_cache[key]
280
+ return _func_cache[cache_key]
224
281
 
225
282
 
226
283
  def get_integrand_kernel(
227
284
  integrand: Integrand,
228
285
  suffix: str,
229
286
  kernel_fn: Optional[Callable] = None,
230
- kernel_options: Dict[str, Any] = None,
287
+ kernel_options: Optional[Dict[str, Any]] = None,
231
288
  code_transformers=None,
232
289
  ):
233
290
  options = integrand.module.options.copy()
@@ -235,15 +292,15 @@ def get_integrand_kernel(
235
292
  if kernel_options is not None:
236
293
  options.update(kernel_options)
237
294
 
238
- kernel_key = _make_key(integrand.func, suffix, use_qualified_name=True)
239
- opts_key = "".join([f"{k}:{v}" for k, v in sorted(options.items())])
240
- cache_key = kernel_key + opts_key
295
+ kernel_key = _make_key(integrand.func, suffix, options=options)
296
+ cache_key = _make_cache_key(integrand, kernel_key, integrand.argspec)
241
297
 
242
298
  if cache_key not in _kernel_cache:
243
299
  if kernel_fn is None:
244
300
  return None
245
301
 
246
- module = wp.get_module(f"{integrand.module.name}.{integrand.name}")
302
+ module = wp.get_module(f"{integrand.module.name}.{kernel_key}")
303
+ module.options = options
247
304
  _kernel_cache[cache_key] = wp.Kernel(
248
305
  func=kernel_fn, key=kernel_key, module=module, code_transformers=code_transformers, options=options
249
306
  )
@@ -272,6 +329,40 @@ def cached_arg_value(func: Callable):
272
329
  return get_arg
273
330
 
274
331
 
332
+ def setup_dynamic_attributes(
333
+ obj,
334
+ cls: Optional[type] = None,
335
+ constructors: Optional[Dict[str, Callable]] = None,
336
+ key: Optional[str] = None,
337
+ ):
338
+ if cls is None:
339
+ cls = type(obj)
340
+
341
+ if key is None:
342
+ key = obj.name
343
+
344
+ if constructors is None:
345
+ constructors = cls._dynamic_attribute_constructors
346
+
347
+ key = (key, frozenset(constructors.keys()))
348
+
349
+ if not hasattr(cls, "_cached_dynamic_attrs"):
350
+ cls._cached_dynamic_attrs = {}
351
+
352
+ attrs = cls._cached_dynamic_attrs.get(key)
353
+ if attrs is None:
354
+ attrs = {}
355
+ # create attributes one-by-one, as some may depend on previous ones
356
+ for k, v in constructors.items():
357
+ attr = v(obj)
358
+ attrs[k] = attr
359
+ setattr(obj, k, attr)
360
+ cls._cached_dynamic_attrs[key] = attrs
361
+ else:
362
+ for k, v in attrs.items():
363
+ setattr(obj, k, v)
364
+
365
+
275
366
  _cached_vec_types = {}
276
367
  _cached_mat_types = {}
277
368
 
@@ -301,7 +392,7 @@ class Temporary:
301
392
  """
302
393
 
303
394
  def __new__(cls, *args, **kwargs):
304
- instance = super(Temporary, cls).__new__(cls)
395
+ instance = super().__new__(cls)
305
396
  instance._pool = None
306
397
  return instance
307
398
 
@@ -447,15 +538,13 @@ class TemporaryStore:
447
538
  dtype = wp.types.type_to_warp(dtype)
448
539
  device = wp.get_device(device)
449
540
 
450
- type_length = wp.types.type_length(dtype)
451
- key = (dtype._type_, type_length, pinned, device.ordinal)
541
+ type_size = wp.types.type_size(dtype)
542
+ key = (dtype._type_, type_size, pinned, device.ordinal)
452
543
 
453
544
  pool = self._temporaries.get(key, None)
454
545
  if pool is None:
455
546
  value_type = (
456
- cached_vec_type(length=type_length, dtype=wp.types.type_scalar_type(dtype))
457
- if type_length > 1
458
- else dtype
547
+ cached_vec_type(length=type_size, dtype=wp.types.type_scalar_type(dtype)) if type_size > 1 else dtype
459
548
  )
460
549
  pool = TemporaryStore.Pool(value_type, device, pinned=pinned)
461
550
  self._temporaries[key] = pool
warp/fem/dirichlet.py CHANGED
@@ -18,7 +18,7 @@ from typing import Any, Optional
18
18
  import warp as wp
19
19
  from warp.fem.linalg import array_axpy, symmetric_eigenvalues_qr
20
20
  from warp.sparse import BsrMatrix, bsr_assign, bsr_axpy, bsr_copy, bsr_mm, bsr_mv
21
- from warp.types import type_is_matrix, type_length
21
+ from warp.types import type_is_matrix, type_size
22
22
 
23
23
 
24
24
  def normalize_dirichlet_projector(projector_matrix: BsrMatrix, fixed_value: Optional[wp.array] = None):
@@ -53,7 +53,7 @@ def normalize_dirichlet_projector(projector_matrix: BsrMatrix, fixed_value: Opti
53
53
  if fixed_value.shape[0] != projector_matrix.nrow:
54
54
  raise ValueError("Fixed value array must be of length equal to the number of rows of blocks")
55
55
 
56
- if type_length(fixed_value.dtype) == 1:
56
+ if type_size(fixed_value.dtype) == 1:
57
57
  # array of scalars, convert to 1d array of vectors
58
58
  fixed_value = wp.array(
59
59
  data=None,
warp/fem/domain.py CHANGED
@@ -13,6 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
+ from functools import cached_property
16
17
  from typing import Any, Optional, Set, Union
17
18
 
18
19
  import warp as wp
@@ -27,7 +28,7 @@ from warp.fem.geometry import (
27
28
  WholeGeometryPartition,
28
29
  )
29
30
  from warp.fem.operator import Operator
30
- from warp.fem.types import ElementKind
31
+ from warp.fem.types import NULL_ELEMENT_INDEX, ElementKind
31
32
 
32
33
  GeometryOrPartition = Union[Geometry, GeometryPartition]
33
34
 
@@ -89,6 +90,9 @@ class GeometryDomain:
89
90
  element_index: wp.Function
90
91
  """Device function for retrieving an ElementIndex from a linearized index"""
91
92
 
93
+ element_partition_index: wp.Function
94
+ """Device function for retrieving linearized index in the domain's partition from an ElementIndex"""
95
+
92
96
  ElementArg: warp.codegen.Struct
93
97
  """Structure containing arguments to be passed to device functions computing element geometry"""
94
98
 
@@ -107,13 +111,34 @@ class GeometryDomain:
107
111
  element_normal: wp.Function
108
112
  """Device function returning the element normal at a sample point"""
109
113
 
114
+ element_closest_point: wp.Function
115
+ """Device function returning the coordinates of the closest point in a given element to a world position"""
116
+
117
+ element_coordinates: wp.Function
118
+ """Device function returning the coordinates corresponding to a world position in a given element reference system"""
119
+
110
120
  element_lookup: wp.Function
111
- """Device function returning the sample point corresponding to a world position"""
121
+ """Device function returning the sample point in the domain's geometry corresponding to a world position"""
122
+
123
+ element_partition_lookup: wp.Function
124
+ """Device function returning the sample point in the domain's geometry partition corresponding to a world position"""
112
125
 
113
126
  def notify_operator_usage(self, ops: Set[Operator]):
114
127
  """Makes the Domain aware that the operators `ops` will be applied"""
115
128
  pass
116
129
 
130
+ @cached_property
131
+ def DomainArg(self):
132
+ return self._make_domain_arg()
133
+
134
+ def _make_domain_arg(self):
135
+ @cache.dynamic_struct(suffix=self.name)
136
+ class DomainArg:
137
+ geo: self.ElementArg
138
+ index: self.ElementIndexArg
139
+
140
+ return DomainArg
141
+
117
142
 
118
143
  class Cells(GeometryDomain):
119
144
  """A Domain containing all cells of the geometry or geometry partition"""
@@ -145,13 +170,23 @@ class Cells(GeometryDomain):
145
170
  def element_index_arg_value(self, device: warp.context.Devicelike) -> warp.codegen.StructInstance:
146
171
  return self.geometry_partition.cell_arg_value(device)
147
172
 
173
+ def fill_element_index_arg(self, arg: ElementIndexArg, device: warp.context.Devicelike):
174
+ self.geometry_partition.fill_cell_arg(arg, device)
175
+
148
176
  @property
149
177
  def element_index(self) -> wp.Function:
150
178
  return self.geometry_partition.cell_index
151
179
 
180
+ @property
181
+ def element_partition_index(self) -> wp.Function:
182
+ return self.geometry_partition.partition_cell_index
183
+
152
184
  def element_arg_value(self, device: warp.context.Devicelike) -> warp.codegen.StructInstance:
153
185
  return self.geometry.cell_arg_value(device)
154
186
 
187
+ def fill_element_arg(self, arg: "ElementArg", device: warp.context.Devicelike):
188
+ self.geometry.fill_cell_arg(arg, device)
189
+
155
190
  @property
156
191
  def ElementArg(self) -> warp.codegen.Struct:
157
192
  return self.geometry.CellArg
@@ -176,10 +211,46 @@ class Cells(GeometryDomain):
176
211
  def element_normal(self) -> wp.Function:
177
212
  return self.geometry.cell_normal
178
213
 
214
+ @property
215
+ def element_closest_point(self) -> wp.Function:
216
+ return self.geometry.cell_closest_point
217
+
218
+ @property
219
+ def element_coordinates(self) -> wp.Function:
220
+ return self.geometry.cell_coordinates
221
+
179
222
  @property
180
223
  def element_lookup(self) -> wp.Function:
181
224
  return self.geometry.cell_lookup
182
225
 
226
+ @property
227
+ def element_partition_lookup(self) -> wp.Function:
228
+ pos_type = cache.cached_vec_type(self.geometry.dimension, dtype=float)
229
+
230
+ @cache.dynamic_func(suffix=self.geometry_partition.name)
231
+ def is_in_partition(args: self.ElementIndexArg, cell_index: int):
232
+ return self.geometry_partition.partition_cell_index(args, cell_index) != NULL_ELEMENT_INDEX
233
+
234
+ filtered_cell_lookup = self.geometry.make_filtered_cell_lookup(filter_func=is_in_partition)
235
+
236
+ # overloads
237
+ filter_target = True
238
+ pos_type = cache.cached_vec_type(self.geometry.dimension, dtype=float)
239
+
240
+ @cache.dynamic_func(suffix=self.name)
241
+ def cell_partition_lookup(args: self.DomainArg, pos: pos_type, max_dist: float):
242
+ return filtered_cell_lookup(args.geo, pos, max_dist, args.index, filter_target)
243
+
244
+ @cache.dynamic_func(suffix=self.name)
245
+ def cell_partition_lookup(args: self.DomainArg, pos: pos_type):
246
+ max_dist = 0.0
247
+ return filtered_cell_lookup(args.geo, pos, max_dist, args.index, filter_target)
248
+
249
+ return cell_partition_lookup
250
+
251
+ def supports_lookup(self, device):
252
+ return self.geometry.supports_cell_lookup(device)
253
+
183
254
  @property
184
255
  def domain_cell_arg(self) -> wp.Function:
185
256
  return Cells._identity_fn
@@ -200,6 +271,11 @@ class Sides(GeometryDomain):
200
271
  super().__init__(geometry)
201
272
 
202
273
  self.element_lookup = None
274
+ self.element_partition_lookup = None
275
+ self.element_filtered_lookup = None
276
+
277
+ def supports_lookup(self, device):
278
+ return False
203
279
 
204
280
  @property
205
281
  def element_kind(self) -> ElementKind:
@@ -225,6 +301,9 @@ class Sides(GeometryDomain):
225
301
  def element_index_arg_value(self, device: warp.context.Devicelike) -> warp.codegen.StructInstance:
226
302
  return self.geometry_partition.side_arg_value(device)
227
303
 
304
+ def fill_element_index_arg(self, arg: "ElementIndexArg", device: warp.context.Devicelike):
305
+ self.geometry_partition.fill_side_arg(arg, device)
306
+
228
307
  @property
229
308
  def element_index(self) -> wp.Function:
230
309
  return self.geometry_partition.side_index
@@ -236,6 +315,9 @@ class Sides(GeometryDomain):
236
315
  def element_arg_value(self, device: warp.context.Devicelike) -> warp.codegen.StructInstance:
237
316
  return self.geometry.side_arg_value(device)
238
317
 
318
+ def fill_element_arg(self, arg: "ElementArg", device: warp.context.Devicelike):
319
+ self.geometry.fill_side_arg(arg, device)
320
+
239
321
  @property
240
322
  def element_position(self) -> wp.Function:
241
323
  return self.geometry.side_position
@@ -256,6 +338,14 @@ class Sides(GeometryDomain):
256
338
  def element_normal(self) -> wp.Function:
257
339
  return self.geometry.side_normal
258
340
 
341
+ @property
342
+ def element_closest_point(self) -> wp.Function:
343
+ return self.geometry.side_closest_point
344
+
345
+ @property
346
+ def element_coordinates(self) -> wp.Function:
347
+ return self.geometry.side_coordinates
348
+
259
349
  @property
260
350
  def element_inner_cell_index(self) -> wp.Function:
261
351
  return self.geometry.side_inner_cell_index
@@ -276,9 +366,18 @@ class Sides(GeometryDomain):
276
366
  def cell_to_element_coords(self) -> wp.Function:
277
367
  return self.geometry.side_from_cell_coords
278
368
 
279
- @property
369
+ @cached_property
280
370
  def domain_cell_arg(self) -> wp.Function:
281
- return self.geometry.side_to_cell_arg
371
+ CellDomainArg = self.cell_domain().DomainArg
372
+
373
+ @cache.dynamic_func(suffix=self.name)
374
+ def domain_cell_arg(x: self.DomainArg):
375
+ return CellDomainArg(
376
+ self.geometry.side_to_cell_arg(x.geo),
377
+ self.geometry_partition.side_to_cell_arg(x.index),
378
+ )
379
+
380
+ return domain_cell_arg
282
381
 
283
382
  def cell_domain(self):
284
383
  return Cells(self.geometry_partition)
@@ -359,11 +458,13 @@ class Subdomain(GeometryDomain):
359
458
  self.geometry_element_count = self._domain.geometry_element_count
360
459
  self.reference_element = self._domain.reference_element
361
460
  self.element_arg_value = self._domain.element_arg_value
461
+ self.fill_element_arg = self._domain.fill_element_arg
362
462
  self.element_measure = self._domain.element_measure
363
463
  self.element_measure_ratio = self._domain.element_measure_ratio
364
464
  self.element_position = self._domain.element_position
365
465
  self.element_deformation_gradient = self._domain.element_deformation_gradient
366
466
  self.element_lookup = self._domain.element_lookup
467
+ self.element_partition_lookup = self._domain.element_partition_lookup
367
468
  self.element_normal = self._domain.element_normal
368
469
 
369
470
  @property
@@ -399,13 +500,42 @@ class Subdomain(GeometryDomain):
399
500
  @cache.cached_arg_value
400
501
  def element_index_arg_value(self, device: warp.context.Devicelike):
401
502
  arg = self.ElementIndexArg()
402
- arg.domain_arg = self._domain.element_index_arg_value(device)
403
- arg.element_indices = self._element_indices.to(device)
503
+ self.fill_element_index_arg(arg, device)
404
504
  return arg
405
505
 
506
+ def fill_element_index_arg(self, arg: "GeometryDomain.ElementIndexArg", device: warp.context.Devicelike):
507
+ self._domain.fill_element_index_arg(arg.domain_arg, device)
508
+ arg.element_indices = self._element_indices.to(device)
509
+
406
510
  def _make_element_index(self) -> wp.Function:
407
511
  @cache.dynamic_func(suffix=self.name)
408
512
  def element_index(arg: self.ElementIndexArg, index: int):
409
513
  return self._domain.element_index(arg.domain_arg, arg.element_indices[index])
410
514
 
411
515
  return element_index
516
+
517
+ def _make_element_partition_index(self) -> wp.Function:
518
+ @cache.dynamic_func(suffix=self.name)
519
+ def element_partition_index(arg: self.ElementIndexArg, element_index: int):
520
+ return self._domain.element_partition_index(arg.domain_arg, element_index)
521
+
522
+ return element_partition_index
523
+
524
+ def supports_lookup(self, device):
525
+ return self._domain.supports_lokup(device)
526
+
527
+ def cell_domain(self):
528
+ return self._domain.cell_domain()
529
+
530
+ @cached_property
531
+ def domain_cell_arg(self) -> wp.Function:
532
+ CellDomainArg = self.cell_domain().DomainArg
533
+
534
+ @cache.dynamic_func(suffix=self.name)
535
+ def domain_cell_arg(x: self.DomainArg):
536
+ return CellDomainArg(
537
+ self.geometry.side_to_cell_arg(x.geo),
538
+ self.geometry_partition.side_to_cell_arg(x.index.domain_arg),
539
+ )
540
+
541
+ return domain_cell_arg