warp-lang 1.9.0__py3-none-manylinux_2_34_aarch64.whl → 1.9.1__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.pyi +1420 -2
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build_dll.py +322 -72
- warp/builtins.py +289 -23
- warp/codegen.py +5 -0
- warp/config.py +1 -1
- warp/context.py +243 -32
- warp/examples/interop/example_jax_kernel.py +2 -1
- warp/jax_experimental/custom_call.py +24 -1
- warp/jax_experimental/ffi.py +20 -0
- warp/jax_experimental/xla_ffi.py +16 -7
- warp/native/builtin.h +4 -4
- warp/native/sort.cu +22 -13
- warp/native/sort.h +2 -0
- warp/native/tile.h +188 -13
- warp/native/vec.h +0 -53
- warp/native/warp.cpp +3 -3
- warp/native/warp.cu +60 -30
- warp/native/warp.h +3 -3
- warp/render/render_opengl.py +14 -12
- warp/render/render_usd.py +1 -0
- warp/tests/geometry/test_hash_grid.py +38 -0
- warp/tests/interop/test_jax.py +608 -28
- warp/tests/test_array.py +2 -0
- warp/tests/test_codegen.py +1 -1
- warp/tests/test_fem.py +4 -4
- warp/tests/test_map.py +14 -0
- warp/tests/test_tuple.py +96 -0
- warp/tests/test_types.py +61 -0
- warp/tests/tile/test_tile.py +61 -0
- warp/types.py +17 -3
- {warp_lang-1.9.0.dist-info → warp_lang-1.9.1.dist-info}/METADATA +5 -8
- {warp_lang-1.9.0.dist-info → warp_lang-1.9.1.dist-info}/RECORD +37 -37
- {warp_lang-1.9.0.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.9.0.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/tests/interop/test_jax.py
CHANGED
|
@@ -15,11 +15,13 @@
|
|
|
15
15
|
|
|
16
16
|
import os
|
|
17
17
|
import unittest
|
|
18
|
+
from functools import partial
|
|
18
19
|
from typing import Any
|
|
19
20
|
|
|
20
21
|
import numpy as np
|
|
21
22
|
|
|
22
23
|
import warp as wp
|
|
24
|
+
from warp.jax import get_jax_device
|
|
23
25
|
from warp.tests.unittest_utils import *
|
|
24
26
|
|
|
25
27
|
|
|
@@ -132,15 +134,19 @@ def test_device_conversion(test, device):
|
|
|
132
134
|
test.assertEqual(warp_device, device)
|
|
133
135
|
|
|
134
136
|
|
|
135
|
-
|
|
136
|
-
def test_jax_kernel_basic(test, device):
|
|
137
|
+
def test_jax_kernel_basic(test, device, use_ffi=False):
|
|
137
138
|
import jax.numpy as jp
|
|
138
139
|
|
|
139
|
-
|
|
140
|
+
if use_ffi:
|
|
141
|
+
from warp.jax_experimental.ffi import jax_kernel
|
|
140
142
|
|
|
141
|
-
|
|
143
|
+
jax_triple = jax_kernel(triple_kernel)
|
|
144
|
+
else:
|
|
145
|
+
from warp.jax_experimental.custom_call import jax_kernel
|
|
146
|
+
|
|
147
|
+
jax_triple = jax_kernel(triple_kernel, quiet=True) # suppress deprecation warnings
|
|
142
148
|
|
|
143
|
-
|
|
149
|
+
n = 64
|
|
144
150
|
|
|
145
151
|
@jax.jit
|
|
146
152
|
def f():
|
|
@@ -157,11 +163,17 @@ def test_jax_kernel_basic(test, device):
|
|
|
157
163
|
assert_np_equal(result, expected)
|
|
158
164
|
|
|
159
165
|
|
|
160
|
-
|
|
161
|
-
def test_jax_kernel_scalar(test, device):
|
|
166
|
+
def test_jax_kernel_scalar(test, device, use_ffi=False):
|
|
162
167
|
import jax.numpy as jp
|
|
163
168
|
|
|
164
|
-
|
|
169
|
+
if use_ffi:
|
|
170
|
+
from warp.jax_experimental.ffi import jax_kernel
|
|
171
|
+
|
|
172
|
+
kwargs = {}
|
|
173
|
+
else:
|
|
174
|
+
from warp.jax_experimental.custom_call import jax_kernel
|
|
175
|
+
|
|
176
|
+
kwargs = {"quiet": True}
|
|
165
177
|
|
|
166
178
|
n = 64
|
|
167
179
|
|
|
@@ -173,7 +185,7 @@ def test_jax_kernel_scalar(test, device):
|
|
|
173
185
|
# get the concrete overload
|
|
174
186
|
kernel_instance = triple_kernel_scalar.add_overload([wp.array(dtype=T), wp.array(dtype=T)])
|
|
175
187
|
|
|
176
|
-
jax_triple = jax_kernel(kernel_instance)
|
|
188
|
+
jax_triple = jax_kernel(kernel_instance, **kwargs)
|
|
177
189
|
|
|
178
190
|
@jax.jit
|
|
179
191
|
def f(jax_triple=jax_triple, jp_dtype=jp_dtype):
|
|
@@ -190,11 +202,17 @@ def test_jax_kernel_scalar(test, device):
|
|
|
190
202
|
assert_np_equal(result, expected)
|
|
191
203
|
|
|
192
204
|
|
|
193
|
-
|
|
194
|
-
def test_jax_kernel_vecmat(test, device):
|
|
205
|
+
def test_jax_kernel_vecmat(test, device, use_ffi=False):
|
|
195
206
|
import jax.numpy as jp
|
|
196
207
|
|
|
197
|
-
|
|
208
|
+
if use_ffi:
|
|
209
|
+
from warp.jax_experimental.ffi import jax_kernel
|
|
210
|
+
|
|
211
|
+
kwargs = {}
|
|
212
|
+
else:
|
|
213
|
+
from warp.jax_experimental.custom_call import jax_kernel
|
|
214
|
+
|
|
215
|
+
kwargs = {"quiet": True}
|
|
198
216
|
|
|
199
217
|
for T in [*vector_types, *matrix_types]:
|
|
200
218
|
jp_dtype = wp.dtype_to_jax(T._wp_scalar_type_)
|
|
@@ -208,7 +226,7 @@ def test_jax_kernel_vecmat(test, device):
|
|
|
208
226
|
# get the concrete overload
|
|
209
227
|
kernel_instance = triple_kernel_vecmat.add_overload([wp.array(dtype=T), wp.array(dtype=T)])
|
|
210
228
|
|
|
211
|
-
jax_triple = jax_kernel(kernel_instance)
|
|
229
|
+
jax_triple = jax_kernel(kernel_instance, **kwargs)
|
|
212
230
|
|
|
213
231
|
@jax.jit
|
|
214
232
|
def f(jax_triple=jax_triple, jp_dtype=jp_dtype, scalar_len=scalar_len, scalar_shape=scalar_shape):
|
|
@@ -225,15 +243,19 @@ def test_jax_kernel_vecmat(test, device):
|
|
|
225
243
|
assert_np_equal(result, expected)
|
|
226
244
|
|
|
227
245
|
|
|
228
|
-
|
|
229
|
-
def test_jax_kernel_multiarg(test, device):
|
|
246
|
+
def test_jax_kernel_multiarg(test, device, use_ffi=False):
|
|
230
247
|
import jax.numpy as jp
|
|
231
248
|
|
|
232
|
-
|
|
249
|
+
if use_ffi:
|
|
250
|
+
from warp.jax_experimental.ffi import jax_kernel
|
|
233
251
|
|
|
234
|
-
|
|
252
|
+
jax_multiarg = jax_kernel(multiarg_kernel, num_outputs=2)
|
|
253
|
+
else:
|
|
254
|
+
from warp.jax_experimental.custom_call import jax_kernel
|
|
235
255
|
|
|
236
|
-
|
|
256
|
+
jax_multiarg = jax_kernel(multiarg_kernel, quiet=True)
|
|
257
|
+
|
|
258
|
+
n = 64
|
|
237
259
|
|
|
238
260
|
@jax.jit
|
|
239
261
|
def f():
|
|
@@ -254,11 +276,17 @@ def test_jax_kernel_multiarg(test, device):
|
|
|
254
276
|
assert_np_equal(result_y, expected_y)
|
|
255
277
|
|
|
256
278
|
|
|
257
|
-
|
|
258
|
-
def test_jax_kernel_launch_dims(test, device):
|
|
279
|
+
def test_jax_kernel_launch_dims(test, device, use_ffi=False):
|
|
259
280
|
import jax.numpy as jp
|
|
260
281
|
|
|
261
|
-
|
|
282
|
+
if use_ffi:
|
|
283
|
+
from warp.jax_experimental.ffi import jax_kernel
|
|
284
|
+
|
|
285
|
+
kwargs = {}
|
|
286
|
+
else:
|
|
287
|
+
from warp.jax_experimental.custom_call import jax_kernel
|
|
288
|
+
|
|
289
|
+
kwargs = {"quiet": True}
|
|
262
290
|
|
|
263
291
|
n = 64
|
|
264
292
|
m = 32
|
|
@@ -270,7 +298,7 @@ def test_jax_kernel_launch_dims(test, device):
|
|
|
270
298
|
y[tid] = x[tid] + 1.0
|
|
271
299
|
|
|
272
300
|
jax_add_one = jax_kernel(
|
|
273
|
-
add_one_kernel, launch_dims=(n - 2,)
|
|
301
|
+
add_one_kernel, launch_dims=(n - 2,), **kwargs
|
|
274
302
|
) # Intentionally not the same as the first dimension of the input
|
|
275
303
|
|
|
276
304
|
@jax.jit
|
|
@@ -285,7 +313,7 @@ def test_jax_kernel_launch_dims(test, device):
|
|
|
285
313
|
y[i, j] = x[i, j] + 1.0
|
|
286
314
|
|
|
287
315
|
jax_add_one_2d = jax_kernel(
|
|
288
|
-
add_one_2d_kernel, launch_dims=(n - 2, m - 2)
|
|
316
|
+
add_one_2d_kernel, launch_dims=(n - 2, m - 2), **kwargs
|
|
289
317
|
) # Intentionally not the same as the first dimension of the input
|
|
290
318
|
|
|
291
319
|
@jax.jit
|
|
@@ -308,6 +336,462 @@ def test_jax_kernel_launch_dims(test, device):
|
|
|
308
336
|
assert_np_equal(result_2d, expected_2d)
|
|
309
337
|
|
|
310
338
|
|
|
339
|
+
# =========================================================================================================
|
|
340
|
+
# JAX FFI
|
|
341
|
+
# =========================================================================================================
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
@wp.kernel
def add_kernel(a: wp.array(dtype=int), b: wp.array(dtype=int), output: wp.array(dtype=int)):
    # Element-wise integer sum: the thread with index i writes output[i] = a[i] + b[i].
    i = wp.tid()
    output[i] = a[i] + b[i]
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
@wp.kernel
def sincos_kernel(angle: wp.array(dtype=float), sin_out: wp.array(dtype=float), cos_out: wp.array(dtype=float)):
    # One input, two outputs: sine and cosine of each angle go to separate arrays.
    i = wp.tid()
    theta = angle[i]
    sin_out[i] = wp.sin(theta)
    cos_out[i] = wp.cos(theta)
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
@wp.kernel
def diagonal_kernel(output: wp.array(dtype=wp.mat33)):
    # No inputs: thread i writes the diagonal matrix diag(k, 2k, 3k) with k = i + 1.
    i = wp.tid()
    k = float(i + 1)
    output[i] = wp.mat33(
        k, 0.0, 0.0,
        0.0, k * 2.0, 0.0,
        0.0, 0.0, k * 3.0,
    )
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    # Scales each float element by the scalar s: output[i] = a[i] * s.
    i = wp.tid()
    output[i] = a[i] * s
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    # Scales each vec2 element by the scalar s: output[i] = a[i] * s.
    i = wp.tid()
    output[i] = a[i] * s
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
@wp.kernel
def accum_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
    # Accumulates a into b in place (b is modified by the launch).
    # NOTE(review): keep the augmented `+=` form; as far as I know Warp lowers `+=` on an
    # array element to an atomic add, so `b[i] = b[i] + a[i]` would not be equivalent.
    i = wp.tid()
    b[i] += a[i]
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
@wp.kernel
def matmul_kernel(
    a: wp.array2d(dtype=float),  # NxK
    b: wp.array2d(dtype=float),  # KxM
    c: wp.array2d(dtype=float),  # NxM
):
    # Naive matrix product c = a @ b; expected launch dims are (N, M).
    # Threads whose (row, col) lies outside the N x M output do nothing.
    row, col = wp.tid()
    num_rows = a.shape[0]
    inner = a.shape[1]
    num_cols = b.shape[1]
    if row < num_rows and col < num_cols:
        acc = wp.float32(0)
        for k in range(inner):
            acc += a[row, k] * b[k, col]
        c[row, col] = acc
|
|
398
|
+
|
|
399
|
+
|
|
400
|
+
@wp.kernel
def in_out_kernel(
    a: wp.array(dtype=float),  # input only
    b: wp.array(dtype=float),  # input and output
    c: wp.array(dtype=float),  # output only
):
    # b is updated in place (b += a); c receives 2 * a.
    i = wp.tid()
    b[i] += a[i]
    c[i] = 2.0 * a[i]
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def scale_func(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=wp.vec2),
    s: float,
    # outputs
    c: wp.array(dtype=float),
    d: wp.array(dtype=wp.vec2),
):
    """Scale ``a`` and ``b`` by ``s``, writing the results into ``c`` and ``d``.

    This is a plain Python function (not a Warp kernel) that the tests wrap with
    ``jax_callable``. Its arguments are annotated with Warp types, just like a kernel.
    """
    # (kernel, source array, destination array), launched in this order
    launches = (
        (scale_kernel, a, c),
        (scale_vec_kernel, b, d),
    )
    for kernel, src, dst in launches:
        wp.launch(kernel, dim=src.shape, inputs=[src, s], outputs=[dst])
|
|
424
|
+
|
|
425
|
+
|
|
426
|
+
def in_out_func(
    a: wp.array(dtype=float),  # input only
    b: wp.array(dtype=float),  # input and output
    c: wp.array(dtype=float),  # output only
):
    """Write ``2 * a`` into ``c``, then accumulate ``a`` into ``b`` in place."""
    count = a.size
    wp.launch(scale_kernel, dim=count, inputs=[a, 2.0], outputs=[c])
    wp.launch(accum_kernel, dim=count, inputs=[a, b])  # modifies `b`
|
|
433
|
+
|
|
434
|
+
|
|
435
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_add(test, device):
    """ffi.jax_kernel() with two inputs and one output."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    # Size is defined once, outside the jitted function, so the expected result below
    # stays in sync with the inputs (previously it was hard-coded as arange(1, 11)).
    n = 10

    jax_add = jax_kernel(add_kernel)

    @jax.jit
    def f():
        a = jp.arange(n, dtype=jp.int32)
        b = jp.ones(n, dtype=jp.int32)
        return jax_add(a, b)

    with jax.default_device(wp.device_to_jax(device)):
        (y,) = f()

    result = np.asarray(y)
    # a + b == arange(n) + 1
    expected = np.arange(1, n + 1, dtype=np.int32)

    assert_np_equal(result, expected)
|
|
458
|
+
|
|
459
|
+
|
|
460
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_sincos(test, device):
    """ffi.jax_kernel() with one input and two outputs (num_outputs=2)."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    n = 32
    jax_sincos = jax_kernel(sincos_kernel, num_outputs=2)

    @jax.jit
    def f():
        angles = jp.linspace(0, 2 * jp.pi, n, dtype=jp.float32)
        return jax_sincos(angles)

    with jax.default_device(wp.device_to_jax(device)):
        sin_result, cos_result = f()

    angles_np = np.linspace(0, 2 * np.pi, n, dtype=np.float32)
    assert_np_equal(np.asarray(sin_result), np.sin(angles_np), tol=1e-4)
    assert_np_equal(np.asarray(cos_result), np.cos(angles_np), tol=1e-4)
|
|
487
|
+
|
|
488
|
+
|
|
489
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_diagonal(test, device):
    """ffi.jax_kernel() with no inputs and one output whose size comes from launch_dims."""
    from warp.jax_experimental.ffi import jax_kernel

    jax_diagonal = jax_kernel(diagonal_kernel)

    @jax.jit
    def f():
        # launch dimensions determine output size
        return jax_diagonal(launch_dims=4)

    with jax.default_device(wp.device_to_jax(device)):
        (d,) = f()

    # element i holds diag(k, 2k, 3k) with k = i + 1
    expected = np.array(
        [np.diag([k, 2 * k, 3 * k]) for k in range(1, 5)],
        dtype=np.float32,
    )

    assert_np_equal(np.asarray(d), expected)
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_in_out(test, device):
    """ffi.jax_kernel() with an in-out argument ``b`` that is both read and returned."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    f = jax.jit(jax_kernel(in_out_kernel, num_outputs=2, in_out_argnames=["b"]))

    with jax.default_device(wp.device_to_jax(device)):
        a = jp.ones(10, dtype=jp.float32)
        b_in = jp.arange(10, dtype=jp.float32)
        # results unpack as (updated b, c)
        b_out, c = f(a, b_in)

    # b_out == b_in + a, c == 2 * a
    assert_np_equal(b_out, np.arange(1, 11, dtype=np.float32))
    assert_np_equal(c, np.full(10, 2, dtype=np.float32))
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_scale_vec_constant(test, device):
    """ffi.jax_kernel() multiplying vec2 elements by a scalar constant defined inside jit."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    jax_scale_vec = jax_kernel(scale_vec_kernel)

    @jax.jit
    def f():
        vecs = jp.arange(10, dtype=jp.float32).reshape((5, 2))  # array of vec2
        return jax_scale_vec(vecs, 2.0)

    with jax.default_device(wp.device_to_jax(device)):
        (scaled,) = f()

    assert_np_equal(scaled, 2 * np.arange(10, dtype=np.float32).reshape((5, 2)))
|
|
559
|
+
|
|
560
|
+
|
|
561
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_scale_vec_static(test, device):
    """ffi.jax_kernel() multiplying vec2 elements by a scalar passed as a static jit argument."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    jax_scale_vec = jax_kernel(scale_vec_kernel)

    def scale(a, s):
        return jax_scale_vec(a, s)

    # NOTE: scalar arguments must be static compile-time constants
    f = jax.jit(scale, static_argnames=["s"])

    vecs = jp.arange(10, dtype=jp.float32).reshape((5, 2))  # array of vec2

    with jax.default_device(wp.device_to_jax(device)):
        (scaled,) = f(vecs, 3.0)

    assert_np_equal(scaled, 3 * np.arange(10, dtype=np.float32).reshape((5, 2)))
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_launch_dims_default(test, device):
    """ffi.jax_kernel() using default launch dims given at wrap time."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    N, M, K = 3, 4, 2
    A_VAL, B_VAL = 2, 3  # fill values for a and b

    jax_matmul = jax_kernel(matmul_kernel, launch_dims=(N, M))

    @jax.jit
    def f():
        a = jp.full((N, K), A_VAL, dtype=jp.float32)
        b = jp.full((K, M), B_VAL, dtype=jp.float32)

        # use default launch dims
        return jax_matmul(a, b)

    with jax.default_device(wp.device_to_jax(device)):
        (result,) = f()

    # Derived from the dimensions instead of being hard-coded, so changing N/M/K keeps
    # the test consistent: every entry of a @ b is a sum of K products A_VAL * B_VAL.
    expected = np.full((N, M), A_VAL * B_VAL * K, dtype=np.float32)

    test.assertEqual(result.shape, expected.shape)
    assert_np_equal(result, expected)
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_launch_dims_custom(test, device):
    """ffi.jax_kernel() with custom launch dims supplied per call."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    # Dimensions live outside the jitted function so the expected results below are
    # derived from them rather than hard-coded.
    dims1 = (3, 4, 2)  # (N, M, K)
    dims2 = (4, 3, 2)  # different (N, M, K) for the second call
    A_VAL, B_VAL = 2, 3  # fill values for a and b

    jax_matmul = jax_kernel(matmul_kernel)

    def _matmul(n, m, k):
        a = jp.full((n, k), A_VAL, dtype=jp.float32)
        b = jp.full((k, m), B_VAL, dtype=jp.float32)
        # use custom launch dims
        return jax_matmul(a, b, launch_dims=(n, m))

    @jax.jit
    def f():
        result1 = _matmul(*dims1)
        # use different custom launch dims
        result2 = _matmul(*dims2)
        return result1[0], result2[0]

    with jax.default_device(wp.device_to_jax(device)):
        result1, result2 = f()

    for result, (n, m, k) in ((result1, dims1), (result2, dims2)):
        # every entry of a @ b is a sum of k products A_VAL * B_VAL
        expected = np.full((n, m), A_VAL * B_VAL * k, dtype=np.float32)
        test.assertEqual(result.shape, expected.shape)
        assert_np_equal(result, expected)
|
|
651
|
+
|
|
652
|
+
|
|
653
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_scale_constant(test, device):
    """ffi.jax_callable() scaling two arrays by a constant defined inside jit."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    jax_func = jax_callable(scale_func, num_outputs=2)

    @jax.jit
    def f():
        a = jp.arange(10, dtype=jp.float32)
        b = jp.arange(10, dtype=jp.float32).reshape((5, 2))  # wp.vec2
        # output shapes mirror the corresponding inputs
        c, d = jax_func(a, b, 2.0, output_dims={"c": a.shape, "d": b.shape})
        return c, d

    with jax.default_device(wp.device_to_jax(device)):
        result1, result2 = f()

    base = np.arange(10, dtype=np.float32)
    assert_np_equal(result1, 2 * base)
    assert_np_equal(result2, 2 * base.reshape((5, 2)))
|
|
684
|
+
|
|
685
|
+
|
|
686
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_scale_static(test, device):
    """ffi.jax_callable() scaling two arrays by a scalar passed as a static jit argument."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    jax_func = jax_callable(scale_func, num_outputs=2)

    def run(a, b, s):
        # output shapes mirror the corresponding inputs
        c, d = jax_func(a, b, s, output_dims={"c": a.shape, "d": b.shape})
        return c, d

    # NOTE: scalar arguments must be static compile-time constants
    f = jax.jit(run, static_argnames=["s"])

    with jax.default_device(wp.device_to_jax(device)):
        a = jp.arange(10, dtype=jp.float32)
        b = jp.arange(10, dtype=jp.float32).reshape((5, 2))  # wp.vec2
        result1, result2 = f(a, b, 3.0)

    base = np.arange(10, dtype=np.float32)
    assert_np_equal(result1, 3 * base)
    assert_np_equal(result2, 3 * base.reshape((5, 2)))
|
|
717
|
+
|
|
718
|
+
|
|
719
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_in_out(test, device):
    """ffi.jax_callable() with an in-out argument ``b`` that is both read and returned."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    f = jax.jit(jax_callable(in_out_func, num_outputs=2, in_out_argnames=["b"]))

    with jax.default_device(wp.device_to_jax(device)):
        a = jp.ones(10, dtype=jp.float32)
        b_in = jp.arange(10, dtype=jp.float32)
        # results unpack as (updated b, c)
        b_out, c = f(a, b_in)

    # b_out == b_in + a, c == 2 * a
    assert_np_equal(b_out, np.arange(1, 11, dtype=np.float32))
    assert_np_equal(c, np.full(10, 2, dtype=np.float32))
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_callback(test, device):
    """Call a Python function registered with register_ffi_callback() via jax.ffi.ffi_call()."""
    import jax.numpy as jp

    from warp.jax_experimental.ffi import register_ffi_callback

    # the Python function to call
    def warp_func(inputs, outputs, attrs, ctx):
        # input arrays
        a = inputs[0]
        b = inputs[1]

        # scalar attributes
        s = attrs["scale"]

        # output arrays
        c = outputs[0]
        d = outputs[1]

        # Map the JAX device reported by get_jax_device() to a Warp device and wrap the
        # CUDA stream passed in ctx. Named warp_device to avoid shadowing the test's
        # `device` parameter.
        warp_device = wp.device_from_jax(get_jax_device())
        stream = wp.Stream(warp_device, cuda_stream=ctx.stream)

        with wp.ScopedStream(stream):
            # launch with arrays of scalars
            wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])

            # launch with arrays of vec2
            # NOTE: the input shapes are from JAX arrays, we need to strip the inner dimension for vec2 arrays
            wp.launch(scale_vec_kernel, dim=b.shape[0], inputs=[b, s], outputs=[d])

    # register callback
    register_ffi_callback("warp_func", warp_func)

    n = 10
    s = 2.0

    with jax.default_device(wp.device_to_jax(device)):
        # inputs
        a = jp.arange(n, dtype=jp.float32)
        b = jp.arange(n, dtype=jp.float32).reshape((n // 2, 2))  # array of wp.vec2

        # set up call
        out_types = [
            jax.ShapeDtypeStruct(a.shape, jp.float32),
            jax.ShapeDtypeStruct(b.shape, jp.float32),  # array of wp.vec2
        ]
        call = jax.ffi.ffi_call("warp_func", out_types)

        # call it
        c, d = call(a, b, scale=s)

    # expected values derived from n and s rather than hard-coded
    expected = s * np.arange(n, dtype=np.float32)
    assert_np_equal(c, expected)
    assert_np_equal(d, expected.reshape((n // 2, 2)))
|
|
793
|
+
|
|
794
|
+
|
|
311
795
|
class TestJax(unittest.TestCase):
    # Intentionally empty: test methods are attached at module level via
    # add_function_test(TestJax, ...), once per test function and device.
    pass
|
|
313
797
|
|
|
@@ -346,20 +830,116 @@ try:
|
|
|
346
830
|
add_function_test(TestJax, "test_device_conversion", test_device_conversion, devices=jax_compatible_devices)
|
|
347
831
|
|
|
348
832
|
if jax_compatible_cuda_devices:
|
|
349
|
-
|
|
833
|
+
# tests for both custom_call and ffi variants of jax_kernel(), selected by installed JAX version
|
|
834
|
+
if jax.__version_info__ < (0, 4, 25):
|
|
835
|
+
# no interop supported
|
|
836
|
+
ffi_opts = []
|
|
837
|
+
elif jax.__version_info__ < (0, 5, 0):
|
|
838
|
+
# only custom_call supported
|
|
839
|
+
ffi_opts = [False]
|
|
840
|
+
elif jax.__version_info__ < (0, 8, 0):
|
|
841
|
+
# both custom_call and ffi supported
|
|
842
|
+
ffi_opts = [False, True]
|
|
843
|
+
else:
|
|
844
|
+
# only ffi supported
|
|
845
|
+
ffi_opts = [True]
|
|
846
|
+
|
|
847
|
+
for use_ffi in ffi_opts:
|
|
848
|
+
suffix = "ffi" if use_ffi else "cc"
|
|
849
|
+
add_function_test(
|
|
850
|
+
TestJax,
|
|
851
|
+
f"test_jax_kernel_basic_{suffix}",
|
|
852
|
+
test_jax_kernel_basic,
|
|
853
|
+
devices=jax_compatible_cuda_devices,
|
|
854
|
+
use_ffi=use_ffi,
|
|
855
|
+
)
|
|
856
|
+
add_function_test(
|
|
857
|
+
TestJax,
|
|
858
|
+
f"test_jax_kernel_scalar_{suffix}",
|
|
859
|
+
test_jax_kernel_scalar,
|
|
860
|
+
devices=jax_compatible_cuda_devices,
|
|
861
|
+
use_ffi=use_ffi,
|
|
862
|
+
)
|
|
863
|
+
add_function_test(
|
|
864
|
+
TestJax,
|
|
865
|
+
f"test_jax_kernel_vecmat_{suffix}",
|
|
866
|
+
test_jax_kernel_vecmat,
|
|
867
|
+
devices=jax_compatible_cuda_devices,
|
|
868
|
+
use_ffi=use_ffi,
|
|
869
|
+
)
|
|
870
|
+
add_function_test(
|
|
871
|
+
TestJax,
|
|
872
|
+
f"test_jax_kernel_multiarg_{suffix}",
|
|
873
|
+
test_jax_kernel_multiarg,
|
|
874
|
+
devices=jax_compatible_cuda_devices,
|
|
875
|
+
use_ffi=use_ffi,
|
|
876
|
+
)
|
|
877
|
+
add_function_test(
|
|
878
|
+
TestJax,
|
|
879
|
+
f"test_jax_kernel_launch_dims_{suffix}",
|
|
880
|
+
test_jax_kernel_launch_dims,
|
|
881
|
+
devices=jax_compatible_cuda_devices,
|
|
882
|
+
use_ffi=use_ffi,
|
|
883
|
+
)
|
|
884
|
+
|
|
885
|
+
# ffi.jax_kernel() tests
|
|
886
|
+
add_function_test(
|
|
887
|
+
TestJax, "test_ffi_jax_kernel_add", test_ffi_jax_kernel_add, devices=jax_compatible_cuda_devices
|
|
888
|
+
)
|
|
350
889
|
add_function_test(
|
|
351
|
-
TestJax, "
|
|
890
|
+
TestJax, "test_ffi_jax_kernel_sincos", test_ffi_jax_kernel_sincos, devices=jax_compatible_cuda_devices
|
|
352
891
|
)
|
|
353
892
|
add_function_test(
|
|
354
|
-
TestJax, "
|
|
893
|
+
TestJax, "test_ffi_jax_kernel_diagonal", test_ffi_jax_kernel_diagonal, devices=jax_compatible_cuda_devices
|
|
355
894
|
)
|
|
356
895
|
add_function_test(
|
|
357
|
-
TestJax, "
|
|
896
|
+
TestJax, "test_ffi_jax_kernel_in_out", test_ffi_jax_kernel_in_out, devices=jax_compatible_cuda_devices
|
|
897
|
+
)
|
|
898
|
+
add_function_test(
|
|
899
|
+
TestJax,
|
|
900
|
+
"test_ffi_jax_kernel_scale_vec_constant",
|
|
901
|
+
test_ffi_jax_kernel_scale_vec_constant,
|
|
902
|
+
devices=jax_compatible_cuda_devices,
|
|
903
|
+
)
|
|
904
|
+
add_function_test(
|
|
905
|
+
TestJax,
|
|
906
|
+
"test_ffi_jax_kernel_scale_vec_static",
|
|
907
|
+
test_ffi_jax_kernel_scale_vec_static,
|
|
908
|
+
devices=jax_compatible_cuda_devices,
|
|
909
|
+
)
|
|
910
|
+
add_function_test(
|
|
911
|
+
TestJax,
|
|
912
|
+
"test_ffi_jax_kernel_launch_dims_default",
|
|
913
|
+
test_ffi_jax_kernel_launch_dims_default,
|
|
914
|
+
devices=jax_compatible_cuda_devices,
|
|
915
|
+
)
|
|
916
|
+
add_function_test(
|
|
917
|
+
TestJax,
|
|
918
|
+
"test_ffi_jax_kernel_launch_dims_custom",
|
|
919
|
+
test_ffi_jax_kernel_launch_dims_custom,
|
|
920
|
+
devices=jax_compatible_cuda_devices,
|
|
358
921
|
)
|
|
359
922
|
|
|
923
|
+
# ffi.jax_callable() tests
|
|
924
|
+
add_function_test(
|
|
925
|
+
TestJax,
|
|
926
|
+
"test_ffi_jax_callable_scale_constant",
|
|
927
|
+
test_ffi_jax_callable_scale_constant,
|
|
928
|
+
devices=jax_compatible_cuda_devices,
|
|
929
|
+
)
|
|
360
930
|
add_function_test(
|
|
361
|
-
TestJax,
|
|
931
|
+
TestJax,
|
|
932
|
+
"test_ffi_jax_callable_scale_static",
|
|
933
|
+
test_ffi_jax_callable_scale_static,
|
|
934
|
+
devices=jax_compatible_cuda_devices,
|
|
362
935
|
)
|
|
936
|
+
add_function_test(
|
|
937
|
+
TestJax, "test_ffi_jax_callable_in_out", test_ffi_jax_callable_in_out, devices=jax_compatible_cuda_devices
|
|
938
|
+
)
|
|
939
|
+
|
|
940
|
+
# ffi callback tests
|
|
941
|
+
add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)
|
|
942
|
+
|
|
363
943
|
|
|
364
944
|
except Exception as e:
|
|
365
945
|
print(f"Skipping Jax tests due to exception: {e}")
|