warp-lang 1.7.2__py3-none-macosx_10_13_universal2.whl → 1.8.0__py3-none-macosx_10_13_universal2.whl

This diff shows the changes between publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp-clang.dylib +0 -0
  5. warp/bin/libwarp.dylib +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/fabric.py CHANGED
@@ -22,12 +22,12 @@ from warp.types import *
22
22
 
23
23
 
24
24
  class fabricbucket_t(ctypes.Structure):
25
- _fields_ = [
25
+ _fields_ = (
26
26
  ("index_start", ctypes.c_size_t),
27
27
  ("index_end", ctypes.c_size_t),
28
28
  ("ptr", ctypes.c_void_p),
29
29
  ("lengths", ctypes.c_void_p),
30
- ]
30
+ )
31
31
 
32
32
  def __init__(self, index_start=0, index_end=0, ptr=None, lengths=None):
33
33
  self.index_start = index_start
@@ -37,11 +37,11 @@ class fabricbucket_t(ctypes.Structure):
37
37
 
38
38
 
39
39
  class fabricarray_t(ctypes.Structure):
40
- _fields_ = [
40
+ _fields_ = (
41
41
  ("buckets", ctypes.c_void_p), # array of fabricbucket_t on the correct device
42
42
  ("nbuckets", ctypes.c_size_t),
43
43
  ("size", ctypes.c_size_t),
44
- ]
44
+ )
45
45
 
46
46
  def __init__(self, buckets=None, nbuckets=0, size=0):
47
47
  self.buckets = ctypes.c_void_p(buckets)
@@ -50,11 +50,11 @@ class fabricarray_t(ctypes.Structure):
50
50
 
51
51
 
52
52
  class indexedfabricarray_t(ctypes.Structure):
53
- _fields_ = [
53
+ _fields_ = (
54
54
  ("fa", fabricarray_t),
55
55
  ("indices", ctypes.c_void_p),
56
56
  ("size", ctypes.c_size_t),
57
- ]
57
+ )
58
58
 
59
59
  def __init__(self, fa=None, indices=None):
60
60
  if fa is None:
@@ -121,7 +121,7 @@ class fabricarray(noncontiguous_array_base[T]):
121
121
  _vars = None
122
122
 
123
123
  def __new__(cls, *args, **kwargs):
124
- instance = super(fabricarray, cls).__new__(cls)
124
+ instance = super().__new__(cls)
125
125
  instance.deleter = None
126
126
  return instance
127
127
 
warp/fem/__init__.py CHANGED
@@ -55,6 +55,8 @@ from .operator import (
55
55
  degree,
56
56
  div,
57
57
  div_outer,
58
+ element_closest_point,
59
+ element_coordinates,
58
60
  grad,
59
61
  grad_average,
60
62
  grad_jump,
@@ -65,8 +67,11 @@ from .operator import (
65
67
  lookup,
66
68
  measure,
67
69
  measure_ratio,
70
+ node_count,
71
+ node_index,
68
72
  normal,
69
73
  outer,
74
+ partition_lookup,
70
75
  position,
71
76
  to_cell_side,
72
77
  to_inner_cell,
warp/fem/adaptivity.py CHANGED
@@ -50,7 +50,7 @@ def adaptive_nanogrid_from_hierarchy(
50
50
  # Concatenate voxels for each grid
51
51
  voxel_counts = [grid.get_voxel_count() for grid in grids]
52
52
 
53
- voxel_offsets = np.cumsum(np.array([0] + voxel_counts))
53
+ voxel_offsets = np.cumsum(np.array([0, *voxel_counts]))
54
54
  merged_ijks = cache.borrow_temporary(temporary_store, dtype=wp.vec3i, shape=int(voxel_offsets[-1]), device=device)
55
55
  for l in range(level_count):
56
56
  voxel_count = voxel_counts[l]
warp/fem/cache.py CHANGED
@@ -15,13 +15,15 @@
15
15
 
16
16
  import ast
17
17
  import bisect
18
+ import hashlib
18
19
  import re
19
20
  import weakref
20
- from copy import copy
21
21
  from typing import Any, Callable, Dict, Optional, Tuple, Union
22
22
 
23
23
  import warp as wp
24
+ from warp.codegen import get_annotations
24
25
  from warp.fem.operator import Integrand
26
+ from warp.fem.types import Domain, Field
25
27
 
26
28
  _kernel_cache = {}
27
29
  _struct_cache = {}
@@ -30,31 +32,88 @@ _func_cache = {}
30
32
  _key_re = re.compile("[^0-9a-zA-Z_]+")
31
33
 
32
34
 
33
- def _make_key(obj, suffix: str, use_qualified_name):
34
- base_name = f"{obj.__module__}.{obj.__qualname__}" if use_qualified_name else obj.__name__
35
- return _key_re.sub("", f"{base_name}_{suffix}")
35
+ def _make_key(obj, suffix: str, options: Optional[Dict[str, Any]] = None):
36
+ # human-readable part
37
+ key = _key_re.sub("", f"{obj.__name__}_{suffix}")
36
38
 
39
+ sorted_opts = sorted(options.items()) if options is not None else ()
40
+ opts_str = "".join(
41
+ (
42
+ obj.__module__,
43
+ obj.__qualname__,
44
+ suffix,
45
+ *(opt[0] for opt in sorted_opts),
46
+ *(str(opt[1]) for opt in sorted_opts),
47
+ )
48
+ )
49
+ uid = hashlib.blake2b(bytes(opts_str, encoding="utf-8"), digest_size=4).hexdigest()
37
50
 
38
- def get_func(func, suffix: str, use_qualified_name: bool = False, code_transformers=None):
39
- key = _make_key(func, suffix, use_qualified_name)
51
+ # avoid long keys, issues on win
52
+ key = f"{key[:64]}_{uid}"
40
53
 
41
- if key not in _func_cache:
42
- _func_cache[key] = wp.Function(
43
- func=func,
44
- key=key,
45
- namespace="",
46
- module=wp.get_module(
47
- func.__module__,
48
- ),
54
+ return key
55
+
56
+
57
+ def _arg_type_name(arg_type):
58
+ if isinstance(arg_type, str):
59
+ return arg_type
60
+ if arg_type in (Field, Domain):
61
+ return ""
62
+ return wp.types.get_type_code(wp.types.type_to_warp(arg_type))
63
+
64
+
65
+ def _make_cache_key(func, key, argspec=None):
66
+ if argspec is None:
67
+ annotations = get_annotations(func)
68
+ else:
69
+ annotations = argspec.annotations
70
+
71
+ sig_key = (key, tuple((k, _arg_type_name(v)) for k, v in annotations.items()))
72
+ return sig_key
73
+
74
+
75
+ def _register_function(
76
+ func,
77
+ key,
78
+ module,
79
+ **kwargs,
80
+ ):
81
+ # wp.Function will override existing func for a given key...
82
+ # manually add back our overloads
83
+ existing = module.functions.get(key)
84
+ new_fn = wp.Function(
85
+ func=func,
86
+ key=key,
87
+ namespace="",
88
+ module=module,
89
+ **kwargs,
90
+ )
91
+
92
+ if existing:
93
+ existing.add_overload(new_fn)
94
+ module.functions[key] = existing
95
+ return module.functions[key]
96
+
97
+
98
+ def get_func(func, suffix: str, code_transformers=None):
99
+ key = _make_key(func, suffix)
100
+ cache_key = _make_cache_key(func, key)
101
+
102
+ if cache_key not in _func_cache:
103
+ module = wp.get_module(func.__module__)
104
+ _func_cache[cache_key] = _register_function(
105
+ func,
106
+ key,
107
+ module,
49
108
  code_transformers=code_transformers,
50
109
  )
51
110
 
52
- return _func_cache[key]
111
+ return _func_cache[cache_key]
53
112
 
54
113
 
55
- def dynamic_func(suffix: str, use_qualified_name=False, code_transformers=None):
114
+ def dynamic_func(suffix: str, code_transformers=None):
56
115
  def wrap_func(func: Callable):
57
- return get_func(func, suffix=suffix, use_qualified_name=use_qualified_name, code_transformers=code_transformers)
116
+ return get_func(func, suffix=suffix, code_transformers=code_transformers)
58
117
 
59
118
  return wrap_func
60
119
 
@@ -62,38 +121,35 @@ def dynamic_func(suffix: str, use_qualified_name=False, code_transformers=None):
62
121
  def get_kernel(
63
122
  func,
64
123
  suffix: str,
65
- use_qualified_name: bool = False,
66
- kernel_options: Dict[str, Any] = None,
124
+ kernel_options: Optional[Dict[str, Any]] = None,
67
125
  ):
68
126
  if kernel_options is None:
69
127
  kernel_options = {}
70
128
 
71
- key = _make_key(func, suffix, use_qualified_name)
129
+ key = _make_key(func, suffix, kernel_options)
130
+ cache_key = _make_cache_key(func, key)
72
131
 
73
- if key not in _kernel_cache:
74
- # Avoid creating too long file names -- can lead to issues on Windows
75
- # We could hash the key, but prefer to keep it human-readable
132
+ if cache_key not in _kernel_cache:
76
133
  module_name = f"{func.__module__}.dyn.{key}"
77
- module_name = module_name[:128] if len(module_name) > 128 else module_name
78
134
  module = wp.get_module(module_name)
79
- module.options = copy(wp.get_module(func.__module__).options)
135
+ module.options = dict(wp.get_module(func.__module__).options)
80
136
  module.options.update(kernel_options)
81
- _kernel_cache[key] = wp.Kernel(func=func, key=key, module=module)
82
- return _kernel_cache[key]
137
+ _kernel_cache[cache_key] = wp.Kernel(func=func, key=key, module=module, options=kernel_options)
138
+ return _kernel_cache[cache_key]
83
139
 
84
140
 
85
- def dynamic_kernel(suffix: str, use_qualified_name=False, kernel_options: Dict[str, Any] = None):
141
+ def dynamic_kernel(suffix: str, kernel_options: Optional[Dict[str, Any]] = None):
86
142
  if kernel_options is None:
87
143
  kernel_options = {}
88
144
 
89
145
  def wrap_kernel(func: Callable):
90
- return get_kernel(func, suffix=suffix, use_qualified_name=use_qualified_name, kernel_options=kernel_options)
146
+ return get_kernel(func, suffix=suffix, kernel_options=kernel_options)
91
147
 
92
148
  return wrap_kernel
93
149
 
94
150
 
95
- def get_struct(struct: type, suffix: str, use_qualified_name: bool = False):
96
- key = _make_key(struct, suffix, use_qualified_name)
151
+ def get_struct(struct: type, suffix: str):
152
+ key = _make_key(struct, suffix)
97
153
  # used in codegen
98
154
  struct.__qualname__ = key
99
155
 
@@ -108,9 +164,9 @@ def get_struct(struct: type, suffix: str, use_qualified_name: bool = False):
108
164
  return _struct_cache[key]
109
165
 
110
166
 
111
- def dynamic_struct(suffix: str, use_qualified_name=False):
167
+ def dynamic_struct(suffix: str):
112
168
  def wrap_struct(struct: type):
113
- return get_struct(struct, suffix=suffix, use_qualified_name=use_qualified_name)
169
+ return get_struct(struct, suffix=suffix)
114
170
 
115
171
  return wrap_struct
116
172
 
@@ -125,35 +181,36 @@ def get_argument_struct(arg_types: Dict[str, type]):
125
181
  setattr(Args, name, None)
126
182
  annotations[name] = arg_type
127
183
 
128
- def arg_type_name(arg_type):
129
- return wp.types.get_type_code(wp.types.type_to_warp(arg_type))
130
-
131
184
  try:
132
185
  Args.__annotations__ = annotations
133
186
  except AttributeError:
134
187
  Args.__dict__.__annotations__ = annotations
135
188
 
136
- suffix = "_".join([f"{name}_{arg_type_name(arg_type)}" for name, arg_type in annotations.items()])
189
+ suffix = "_".join([f"{name}_{_arg_type_name(arg_type)}" for name, arg_type in annotations.items()])
137
190
 
138
191
  return get_struct(Args, suffix=suffix)
139
192
 
140
193
 
141
- def populate_argument_struct(Args: wp.codegen.Struct, values: Dict[str, Any], func_name: str):
194
+ def populate_argument_struct(
195
+ Args: wp.codegen.Struct, values: Dict[str, Any], func_name: str, value_struct_values: Optional = None
196
+ ):
142
197
  if values is None:
143
198
  values = {}
144
199
 
145
- value_struct_values = Args()
146
- for k, v in values.items():
147
- try:
200
+ if value_struct_values is None:
201
+ value_struct_values = Args()
202
+
203
+ try:
204
+ for k, v in values.items():
148
205
  setattr(value_struct_values, k, v)
149
- except Exception as err:
150
- if k not in Args.vars:
151
- raise ValueError(
152
- f"Passed value argument '{k}' does not match any of the function '{func_name}' parameters"
153
- ) from err
206
+ except Exception as err:
207
+ if k not in Args.vars:
154
208
  raise ValueError(
155
- f"Passed value argument '{k}' of type '{wp.types.type_repr(v)}' is incompatible with the function '{func_name}' parameter of type '{wp.types.type_repr(Args.vars[k].type)}'"
209
+ f"Passed value argument '{k}' does not match any of the function '{func_name}' parameters"
156
210
  ) from err
211
+ raise ValueError(
212
+ f"Passed value argument '{k}' of type '{wp.types.type_repr(v)}' is incompatible with the function '{func_name}' parameter of type '{wp.types.type_repr(Args.vars[k].type)}'"
213
+ ) from err
157
214
 
158
215
  missing_values = Args.vars.keys() - values.keys()
159
216
  if missing_values:
@@ -208,26 +265,26 @@ def get_integrand_function(
208
265
  annotations=None,
209
266
  code_transformers=None,
210
267
  ):
211
- key = _make_key(integrand.func, suffix, use_qualified_name=True)
268
+ key = _make_key(integrand.func, suffix)
269
+ cache_key = _make_cache_key(integrand.func, key, integrand.argspec)
212
270
 
213
- if key not in _func_cache:
214
- _func_cache[key] = wp.Function(
271
+ if cache_key not in _func_cache:
272
+ _func_cache[cache_key] = _register_function(
215
273
  func=integrand.func if func is None else func,
216
274
  key=key,
217
- namespace="",
218
275
  module=integrand.module,
219
276
  overloaded_annotations=annotations,
220
277
  code_transformers=code_transformers,
221
278
  )
222
279
 
223
- return _func_cache[key]
280
+ return _func_cache[cache_key]
224
281
 
225
282
 
226
283
  def get_integrand_kernel(
227
284
  integrand: Integrand,
228
285
  suffix: str,
229
286
  kernel_fn: Optional[Callable] = None,
230
- kernel_options: Dict[str, Any] = None,
287
+ kernel_options: Optional[Dict[str, Any]] = None,
231
288
  code_transformers=None,
232
289
  ):
233
290
  options = integrand.module.options.copy()
@@ -235,15 +292,15 @@ def get_integrand_kernel(
235
292
  if kernel_options is not None:
236
293
  options.update(kernel_options)
237
294
 
238
- kernel_key = _make_key(integrand.func, suffix, use_qualified_name=True)
239
- opts_key = "".join([f"{k}:{v}" for k, v in sorted(options.items())])
240
- cache_key = kernel_key + opts_key
295
+ kernel_key = _make_key(integrand.func, suffix, options=options)
296
+ cache_key = _make_cache_key(integrand, kernel_key, integrand.argspec)
241
297
 
242
298
  if cache_key not in _kernel_cache:
243
299
  if kernel_fn is None:
244
300
  return None
245
301
 
246
- module = wp.get_module(f"{integrand.module.name}.{integrand.name}")
302
+ module = wp.get_module(f"{integrand.module.name}.{kernel_key}")
303
+ module.options = options
247
304
  _kernel_cache[cache_key] = wp.Kernel(
248
305
  func=kernel_fn, key=kernel_key, module=module, code_transformers=code_transformers, options=options
249
306
  )
@@ -272,6 +329,40 @@ def cached_arg_value(func: Callable):
272
329
  return get_arg
273
330
 
274
331
 
332
+ def setup_dynamic_attributes(
333
+ obj,
334
+ cls: Optional[type] = None,
335
+ constructors: Optional[Dict[str, Callable]] = None,
336
+ key: Optional[str] = None,
337
+ ):
338
+ if cls is None:
339
+ cls = type(obj)
340
+
341
+ if key is None:
342
+ key = obj.name
343
+
344
+ if constructors is None:
345
+ constructors = cls._dynamic_attribute_constructors
346
+
347
+ key = (key, frozenset(constructors.keys()))
348
+
349
+ if not hasattr(cls, "_cached_dynamic_attrs"):
350
+ cls._cached_dynamic_attrs = {}
351
+
352
+ attrs = cls._cached_dynamic_attrs.get(key)
353
+ if attrs is None:
354
+ attrs = {}
355
+ # create attributes one-by-one, as some may depend on previous ones
356
+ for k, v in constructors.items():
357
+ attr = v(obj)
358
+ attrs[k] = attr
359
+ setattr(obj, k, attr)
360
+ cls._cached_dynamic_attrs[key] = attrs
361
+ else:
362
+ for k, v in attrs.items():
363
+ setattr(obj, k, v)
364
+
365
+
275
366
  _cached_vec_types = {}
276
367
  _cached_mat_types = {}
277
368
 
@@ -301,7 +392,7 @@ class Temporary:
301
392
  """
302
393
 
303
394
  def __new__(cls, *args, **kwargs):
304
- instance = super(Temporary, cls).__new__(cls)
395
+ instance = super().__new__(cls)
305
396
  instance._pool = None
306
397
  return instance
307
398
 
@@ -447,15 +538,13 @@ class TemporaryStore:
447
538
  dtype = wp.types.type_to_warp(dtype)
448
539
  device = wp.get_device(device)
449
540
 
450
- type_length = wp.types.type_length(dtype)
451
- key = (dtype._type_, type_length, pinned, device.ordinal)
541
+ type_size = wp.types.type_size(dtype)
542
+ key = (dtype._type_, type_size, pinned, device.ordinal)
452
543
 
453
544
  pool = self._temporaries.get(key, None)
454
545
  if pool is None:
455
546
  value_type = (
456
- cached_vec_type(length=type_length, dtype=wp.types.type_scalar_type(dtype))
457
- if type_length > 1
458
- else dtype
547
+ cached_vec_type(length=type_size, dtype=wp.types.type_scalar_type(dtype)) if type_size > 1 else dtype
459
548
  )
460
549
  pool = TemporaryStore.Pool(value_type, device, pinned=pinned)
461
550
  self._temporaries[key] = pool
warp/fem/dirichlet.py CHANGED
@@ -18,7 +18,7 @@ from typing import Any, Optional
18
18
  import warp as wp
19
19
  from warp.fem.linalg import array_axpy, symmetric_eigenvalues_qr
20
20
  from warp.sparse import BsrMatrix, bsr_assign, bsr_axpy, bsr_copy, bsr_mm, bsr_mv
21
- from warp.types import type_is_matrix, type_length
21
+ from warp.types import type_is_matrix, type_size
22
22
 
23
23
 
24
24
  def normalize_dirichlet_projector(projector_matrix: BsrMatrix, fixed_value: Optional[wp.array] = None):
@@ -53,7 +53,7 @@ def normalize_dirichlet_projector(projector_matrix: BsrMatrix, fixed_value: Opti
53
53
  if fixed_value.shape[0] != projector_matrix.nrow:
54
54
  raise ValueError("Fixed value array must be of length equal to the number of rows of blocks")
55
55
 
56
- if type_length(fixed_value.dtype) == 1:
56
+ if type_size(fixed_value.dtype) == 1:
57
57
  # array of scalars, convert to 1d array of vectors
58
58
  fixed_value = wp.array(
59
59
  data=None,