warp-lang 1.8.0__py3-none-win_amd64.whl → 1.9.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (153)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/examples/interop/example_jax_callable.py CHANGED
@@ -42,7 +42,7 @@ def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtyp
42
42
 
43
43
  # The Python function to call.
44
44
  # Note the argument annotations, just like Warp kernels.
45
- def example_func(
45
+ def scale_func(
46
46
  # inputs
47
47
  a: wp.array(dtype=float),
48
48
  b: wp.array(dtype=wp.vec2),
@@ -55,8 +55,23 @@ def example_func(
55
55
  wp.launch(scale_vec_kernel, dim=b.shape, inputs=[b, s], outputs=[d])
56
56
 
57
57
 
58
+ @wp.kernel
59
+ def accum_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
60
+ tid = wp.tid()
61
+ b[tid] += a[tid]
62
+
63
+
64
+ def in_out_func(
65
+ a: wp.array(dtype=float), # input only
66
+ b: wp.array(dtype=float), # input and output
67
+ c: wp.array(dtype=float), # output only
68
+ ):
69
+ wp.launch(scale_kernel, dim=a.size, inputs=[a, 2.0], outputs=[c])
70
+ wp.launch(accum_kernel, dim=a.size, inputs=[a, b]) # modifies `b`
71
+
72
+
58
73
  def example1():
59
- jax_func = jax_callable(example_func, num_outputs=2, vmap_method="broadcast_all")
74
+ jax_func = jax_callable(scale_func, num_outputs=2)
60
75
 
61
76
  @jax.jit
62
77
  def f():
@@ -78,7 +93,7 @@ def example1():
78
93
 
79
94
 
80
95
  def example2():
81
- jax_func = jax_callable(example_func, num_outputs=2, vmap_method="broadcast_all")
96
+ jax_func = jax_callable(scale_func, num_outputs=2)
82
97
 
83
98
  # NOTE: scalar arguments must be static compile-time constants
84
99
  @partial(jax.jit, static_argnames=["s"])
@@ -100,11 +115,26 @@ def example2():
100
115
  print(r2)
101
116
 
102
117
 
118
+ def example3():
119
+ # Using input-output arguments
120
+
121
+ jax_func = jax_callable(in_out_func, num_outputs=2, in_out_argnames=["b"])
122
+
123
+ f = jax.jit(jax_func)
124
+
125
+ a = jnp.ones(10, dtype=jnp.float32)
126
+ b = jnp.arange(10, dtype=jnp.float32)
127
+
128
+ b, c = f(a, b)
129
+ print(b)
130
+ print(c)
131
+
132
+
103
133
  def main():
104
134
  wp.init()
105
135
  wp.load_module(device=wp.get_device())
106
136
 
107
- examples = [example1, example2]
137
+ examples = [example1, example2, example3]
108
138
 
109
139
  for example in examples:
110
140
  print("\n===========================================================================")
warp/examples/interop/example_jax_kernel.py CHANGED
@@ -72,6 +72,17 @@ def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtyp
72
72
  output[tid] = a[tid] * s
73
73
 
74
74
 
75
+ @wp.kernel
76
+ def in_out_kernel(
77
+ a: wp.array(dtype=float), # input only
78
+ b: wp.array(dtype=float), # input and output
79
+ c: wp.array(dtype=float), # output only
80
+ ):
81
+ tid = wp.tid()
82
+ b[tid] += a[tid]
83
+ c[tid] = 2.0 * a[tid]
84
+
85
+
75
86
  def example1():
76
87
  # two inputs and one output
77
88
  jax_add = jax_kernel(add_kernel)
@@ -189,11 +200,26 @@ def example7():
189
200
  print(f())
190
201
 
191
202
 
203
+ def example8():
204
+ # Using input-output arguments
205
+
206
+ jax_func = jax_kernel(in_out_kernel, num_outputs=2, in_out_argnames=["b"])
207
+
208
+ f = jax.jit(jax_func)
209
+
210
+ a = jnp.ones(10, dtype=jnp.float32)
211
+ b = jnp.arange(10, dtype=jnp.float32)
212
+
213
+ b, c = f(a, b)
214
+ print(b)
215
+ print(c)
216
+
217
+
192
218
  def main():
193
219
  wp.init()
194
220
  wp.load_module(device=wp.get_device())
195
221
 
196
- examples = [example1, example2, example3, example4, example5, example6, example7]
222
+ examples = [example1, example2, example3, example4, example5, example6, example7, example8]
197
223
 
198
224
  for example in examples:
199
225
  print("\n===========================================================================")
warp/fabric.py CHANGED
@@ -211,7 +211,7 @@ class fabricarray(noncontiguous_array_base[T]):
211
211
  allocator = self.device.get_allocator()
212
212
  buckets_ptr = allocator.alloc(buckets_size)
213
213
  cuda_stream = self.device.stream.cuda_stream
214
- runtime.core.memcpy_h2d(
214
+ runtime.core.wp_memcpy_h2d(
215
215
  self.device.context, buckets_ptr, ctypes.addressof(buckets), buckets_size, cuda_stream
216
216
  )
217
217
  self.deleter = allocator.deleter
warp/fem/cache.py CHANGED
@@ -34,7 +34,7 @@ _key_re = re.compile("[^0-9a-zA-Z_]+")
34
34
 
35
35
  def _make_key(obj, suffix: str, options: Optional[Dict[str, Any]] = None):
36
36
  # human-readable part
37
- key = _key_re.sub("", f"{obj.__name__}_{suffix}")
37
+ suffix = str(suffix)
38
38
 
39
39
  sorted_opts = sorted(options.items()) if options is not None else ()
40
40
  opts_str = "".join(
@@ -49,7 +49,7 @@ def _make_key(obj, suffix: str, options: Optional[Dict[str, Any]] = None):
49
49
  uid = hashlib.blake2b(bytes(opts_str, encoding="utf-8"), digest_size=4).hexdigest()
50
50
 
51
51
  # avoid long keys, issues on win
52
- key = f"{key[:64]}_{uid}"
52
+ key = f"{obj.__name__}_{suffix[:32]}_{uid}"
53
53
 
54
54
  return key
55
55
 
@@ -62,7 +62,10 @@ def _arg_type_name(arg_type):
62
62
  return wp.types.get_type_code(wp.types.type_to_warp(arg_type))
63
63
 
64
64
 
65
- def _make_cache_key(func, key, argspec=None):
65
+ def _make_cache_key(func, key, argspec=None, allow_overloads: bool = True):
66
+ if not allow_overloads:
67
+ return key
68
+
66
69
  if argspec is None:
67
70
  annotations = get_annotations(func)
68
71
  else:
@@ -80,6 +83,7 @@ def _register_function(
80
83
  ):
81
84
  # wp.Function will override existing func for a given key...
82
85
  # manually add back our overloads
86
+ key = _key_re.sub("", key)
83
87
  existing = module.functions.get(key)
84
88
  new_fn = wp.Function(
85
89
  func=func,
@@ -95,9 +99,9 @@ def _register_function(
95
99
  return module.functions[key]
96
100
 
97
101
 
98
- def get_func(func, suffix: str, code_transformers=None):
102
+ def get_func(func, suffix: str, code_transformers=None, allow_overloads=False):
99
103
  key = _make_key(func, suffix)
100
- cache_key = _make_cache_key(func, key)
104
+ cache_key = _make_cache_key(func, key, allow_overloads=allow_overloads)
101
105
 
102
106
  if cache_key not in _func_cache:
103
107
  module = wp.get_module(func.__module__)
@@ -111,9 +115,9 @@ def get_func(func, suffix: str, code_transformers=None):
111
115
  return _func_cache[cache_key]
112
116
 
113
117
 
114
- def dynamic_func(suffix: str, code_transformers=None):
118
+ def dynamic_func(suffix: str, code_transformers=None, allow_overloads=False):
115
119
  def wrap_func(func: Callable):
116
- return get_func(func, suffix=suffix, code_transformers=code_transformers)
120
+ return get_func(func, suffix=suffix, code_transformers=code_transformers, allow_overloads=allow_overloads)
117
121
 
118
122
  return wrap_func
119
123
 
@@ -122,46 +126,49 @@ def get_kernel(
122
126
  func,
123
127
  suffix: str,
124
128
  kernel_options: Optional[Dict[str, Any]] = None,
129
+ allow_overloads=False,
125
130
  ):
126
131
  if kernel_options is None:
127
132
  kernel_options = {}
128
133
 
129
134
  key = _make_key(func, suffix, kernel_options)
130
- cache_key = _make_cache_key(func, key)
135
+ cache_key = _make_cache_key(func, key, allow_overloads=allow_overloads)
131
136
 
132
137
  if cache_key not in _kernel_cache:
133
- module_name = f"{func.__module__}.dyn.{key}"
138
+ kernel_key = _key_re.sub("", key)
139
+ module_name = f"{func.__module__}.dyn.{kernel_key}"
134
140
  module = wp.get_module(module_name)
135
141
  module.options = dict(wp.get_module(func.__module__).options)
136
142
  module.options.update(kernel_options)
137
- _kernel_cache[cache_key] = wp.Kernel(func=func, key=key, module=module, options=kernel_options)
143
+ _kernel_cache[cache_key] = wp.Kernel(func=func, key=kernel_key, module=module, options=kernel_options)
138
144
  return _kernel_cache[cache_key]
139
145
 
140
146
 
141
- def dynamic_kernel(suffix: str, kernel_options: Optional[Dict[str, Any]] = None):
147
+ def dynamic_kernel(suffix: str, kernel_options: Optional[Dict[str, Any]] = None, allow_overloads=False):
142
148
  if kernel_options is None:
143
149
  kernel_options = {}
144
150
 
145
151
  def wrap_kernel(func: Callable):
146
- return get_kernel(func, suffix=suffix, kernel_options=kernel_options)
152
+ return get_kernel(func, suffix=suffix, kernel_options=kernel_options, allow_overloads=allow_overloads)
147
153
 
148
154
  return wrap_kernel
149
155
 
150
156
 
151
157
  def get_struct(struct: type, suffix: str):
152
158
  key = _make_key(struct, suffix)
153
- # used in codegen
154
- struct.__qualname__ = key
159
+ cache_key = key
155
160
 
156
- if key not in _struct_cache:
161
+ if cache_key not in _struct_cache:
162
+ # used in codegen
163
+ struct.__qualname__ = _key_re.sub("", key)
157
164
  module = wp.get_module(struct.__module__)
158
- _struct_cache[key] = wp.codegen.Struct(
159
- key=key,
165
+ _struct_cache[cache_key] = wp.codegen.Struct(
166
+ key=struct.__qualname__,
160
167
  cls=struct,
161
168
  module=module,
162
169
  )
163
170
 
164
- return _struct_cache[key]
171
+ return _struct_cache[cache_key]
165
172
 
166
173
 
167
174
  def dynamic_struct(suffix: str):
@@ -293,12 +300,13 @@ def get_integrand_kernel(
293
300
  options.update(kernel_options)
294
301
 
295
302
  kernel_key = _make_key(integrand.func, suffix, options=options)
296
- cache_key = _make_cache_key(integrand, kernel_key, integrand.argspec)
303
+ cache_key = _make_cache_key(integrand, kernel_key, integrand.argspec, allow_overloads=True)
297
304
 
298
305
  if cache_key not in _kernel_cache:
299
306
  if kernel_fn is None:
300
307
  return None
301
308
 
309
+ kernel_key = _key_re.sub("", kernel_key)
302
310
  module = wp.get_module(f"{integrand.module.name}.{kernel_key}")
303
311
  module.options = options
304
312
  _kernel_cache[cache_key] = wp.Kernel(
warp/fem/domain.py CHANGED
@@ -237,11 +237,11 @@ class Cells(GeometryDomain):
237
237
  filter_target = True
238
238
  pos_type = cache.cached_vec_type(self.geometry.dimension, dtype=float)
239
239
 
240
- @cache.dynamic_func(suffix=self.name)
240
+ @cache.dynamic_func(suffix=self.name, allow_overloads=True)
241
241
  def cell_partition_lookup(args: self.DomainArg, pos: pos_type, max_dist: float):
242
242
  return filtered_cell_lookup(args.geo, pos, max_dist, args.index, filter_target)
243
243
 
244
- @cache.dynamic_func(suffix=self.name)
244
+ @cache.dynamic_func(suffix=self.name, allow_overloads=True)
245
245
  def cell_partition_lookup(args: self.DomainArg, pos: pos_type):
246
246
  max_dist = 0.0
247
247
  return filtered_cell_lookup(args.geo, pos, max_dist, args.index, filter_target)
warp/fem/field/nodal_field.py CHANGED
@@ -232,7 +232,7 @@ class NodalFieldBase(DiscreteField):
232
232
  @cache.dynamic_func(suffix=self.name)
233
233
  def eval_grad_outer_world_space(args: self.ElementEvalArg, s: Sample):
234
234
  grad_transform = self.space.element_outer_reference_gradient_transform(args.elt_arg, s)
235
- return eval_grad_outer_ref_space(args, s, grad_transform)
235
+ return eval_grad_outer(args, s, grad_transform)
236
236
 
237
237
  return eval_grad_outer_world_space
238
238
  else:
@@ -240,7 +240,7 @@ class NodalFieldBase(DiscreteField):
240
240
  @cache.dynamic_func(suffix=self.name)
241
241
  def eval_grad_outer_ref_space(args: self.ElementEvalArg, s: Sample):
242
242
  grad_transform = 1.0
243
- return eval_grad_outer_ref_space(args, s, grad_transform)
243
+ return eval_grad_outer(args, s, grad_transform)
244
244
 
245
245
  return eval_grad_outer_ref_space
246
246