warp-lang 1.9.0__py3-none-macosx_10_13_universal2.whl → 1.9.1__py3-none-macosx_10_13_universal2.whl

This diff shows the differences between two publicly released versions of this package, as published to a supported package registry. It is provided for informational purposes only.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

@@ -15,11 +15,13 @@
15
15
 
16
16
  import os
17
17
  import unittest
18
+ from functools import partial
18
19
  from typing import Any
19
20
 
20
21
  import numpy as np
21
22
 
22
23
  import warp as wp
24
+ from warp.jax import get_jax_device
23
25
  from warp.tests.unittest_utils import *
24
26
 
25
27
 
@@ -132,15 +134,19 @@ def test_device_conversion(test, device):
132
134
  test.assertEqual(warp_device, device)
133
135
 
134
136
 
135
- @unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
136
- def test_jax_kernel_basic(test, device):
137
+ def test_jax_kernel_basic(test, device, use_ffi=False):
137
138
  import jax.numpy as jp
138
139
 
139
- from warp.jax_experimental import jax_kernel
140
+ if use_ffi:
141
+ from warp.jax_experimental.ffi import jax_kernel
140
142
 
141
- n = 64
143
+ jax_triple = jax_kernel(triple_kernel)
144
+ else:
145
+ from warp.jax_experimental.custom_call import jax_kernel
146
+
147
+ jax_triple = jax_kernel(triple_kernel, quiet=True) # suppress deprecation warnings
142
148
 
143
- jax_triple = jax_kernel(triple_kernel)
149
+ n = 64
144
150
 
145
151
  @jax.jit
146
152
  def f():
@@ -157,11 +163,17 @@ def test_jax_kernel_basic(test, device):
157
163
  assert_np_equal(result, expected)
158
164
 
159
165
 
160
- @unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
161
- def test_jax_kernel_scalar(test, device):
166
+ def test_jax_kernel_scalar(test, device, use_ffi=False):
162
167
  import jax.numpy as jp
163
168
 
164
- from warp.jax_experimental import jax_kernel
169
+ if use_ffi:
170
+ from warp.jax_experimental.ffi import jax_kernel
171
+
172
+ kwargs = {}
173
+ else:
174
+ from warp.jax_experimental.custom_call import jax_kernel
175
+
176
+ kwargs = {"quiet": True}
165
177
 
166
178
  n = 64
167
179
 
@@ -173,7 +185,7 @@ def test_jax_kernel_scalar(test, device):
173
185
  # get the concrete overload
174
186
  kernel_instance = triple_kernel_scalar.add_overload([wp.array(dtype=T), wp.array(dtype=T)])
175
187
 
176
- jax_triple = jax_kernel(kernel_instance)
188
+ jax_triple = jax_kernel(kernel_instance, **kwargs)
177
189
 
178
190
  @jax.jit
179
191
  def f(jax_triple=jax_triple, jp_dtype=jp_dtype):
@@ -190,11 +202,17 @@ def test_jax_kernel_scalar(test, device):
190
202
  assert_np_equal(result, expected)
191
203
 
192
204
 
193
- @unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
194
- def test_jax_kernel_vecmat(test, device):
205
+ def test_jax_kernel_vecmat(test, device, use_ffi=False):
195
206
  import jax.numpy as jp
196
207
 
197
- from warp.jax_experimental import jax_kernel
208
+ if use_ffi:
209
+ from warp.jax_experimental.ffi import jax_kernel
210
+
211
+ kwargs = {}
212
+ else:
213
+ from warp.jax_experimental.custom_call import jax_kernel
214
+
215
+ kwargs = {"quiet": True}
198
216
 
199
217
  for T in [*vector_types, *matrix_types]:
200
218
  jp_dtype = wp.dtype_to_jax(T._wp_scalar_type_)
@@ -208,7 +226,7 @@ def test_jax_kernel_vecmat(test, device):
208
226
  # get the concrete overload
209
227
  kernel_instance = triple_kernel_vecmat.add_overload([wp.array(dtype=T), wp.array(dtype=T)])
210
228
 
211
- jax_triple = jax_kernel(kernel_instance)
229
+ jax_triple = jax_kernel(kernel_instance, **kwargs)
212
230
 
213
231
  @jax.jit
214
232
  def f(jax_triple=jax_triple, jp_dtype=jp_dtype, scalar_len=scalar_len, scalar_shape=scalar_shape):
@@ -225,15 +243,19 @@ def test_jax_kernel_vecmat(test, device):
225
243
  assert_np_equal(result, expected)
226
244
 
227
245
 
228
- @unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
229
- def test_jax_kernel_multiarg(test, device):
246
+ def test_jax_kernel_multiarg(test, device, use_ffi=False):
230
247
  import jax.numpy as jp
231
248
 
232
- from warp.jax_experimental import jax_kernel
249
+ if use_ffi:
250
+ from warp.jax_experimental.ffi import jax_kernel
233
251
 
234
- n = 64
252
+ jax_multiarg = jax_kernel(multiarg_kernel, num_outputs=2)
253
+ else:
254
+ from warp.jax_experimental.custom_call import jax_kernel
235
255
 
236
- jax_multiarg = jax_kernel(multiarg_kernel)
256
+ jax_multiarg = jax_kernel(multiarg_kernel, quiet=True)
257
+
258
+ n = 64
237
259
 
238
260
  @jax.jit
239
261
  def f():
@@ -254,11 +276,17 @@ def test_jax_kernel_multiarg(test, device):
254
276
  assert_np_equal(result_y, expected_y)
255
277
 
256
278
 
257
- @unittest.skipUnless(_jax_version() >= (0, 4, 25), "Jax version too old")
258
- def test_jax_kernel_launch_dims(test, device):
279
+ def test_jax_kernel_launch_dims(test, device, use_ffi=False):
259
280
  import jax.numpy as jp
260
281
 
261
- from warp.jax_experimental import jax_kernel
282
+ if use_ffi:
283
+ from warp.jax_experimental.ffi import jax_kernel
284
+
285
+ kwargs = {}
286
+ else:
287
+ from warp.jax_experimental.custom_call import jax_kernel
288
+
289
+ kwargs = {"quiet": True}
262
290
 
263
291
  n = 64
264
292
  m = 32
@@ -270,7 +298,7 @@ def test_jax_kernel_launch_dims(test, device):
270
298
  y[tid] = x[tid] + 1.0
271
299
 
272
300
  jax_add_one = jax_kernel(
273
- add_one_kernel, launch_dims=(n - 2,)
301
+ add_one_kernel, launch_dims=(n - 2,), **kwargs
274
302
  ) # Intentionally not the same as the first dimension of the input
275
303
 
276
304
  @jax.jit
@@ -285,7 +313,7 @@ def test_jax_kernel_launch_dims(test, device):
285
313
  y[i, j] = x[i, j] + 1.0
286
314
 
287
315
  jax_add_one_2d = jax_kernel(
288
- add_one_2d_kernel, launch_dims=(n - 2, m - 2)
316
+ add_one_2d_kernel, launch_dims=(n - 2, m - 2), **kwargs
289
317
  ) # Intentionally not the same as the first dimension of the input
290
318
 
291
319
  @jax.jit
@@ -308,6 +336,462 @@ def test_jax_kernel_launch_dims(test, device):
308
336
  assert_np_equal(result_2d, expected_2d)
309
337
 
310
338
 
339
+ # =========================================================================================================
340
+ # JAX FFI
341
+ # =========================================================================================================
342
+
343
+
344
@wp.kernel
def add_kernel(a: wp.array(dtype=int), b: wp.array(dtype=int), output: wp.array(dtype=int)):
    # Element-wise integer sum: the thread index selects which element of
    # `a` and `b` is added and where the result lands in `output`.
    idx = wp.tid()
    output[idx] = a[idx] + b[idx]
348
+
349
+
350
@wp.kernel
def sincos_kernel(angle: wp.array(dtype=float), sin_out: wp.array(dtype=float), cos_out: wp.array(dtype=float)):
    # One input array, two output arrays: writes sin and cos of the same
    # angle element to the matching slot of each output.
    idx = wp.tid()
    theta = angle[idx]
    sin_out[idx] = wp.sin(theta)
    cos_out[idx] = wp.cos(theta)
355
+
356
+
357
@wp.kernel
def diagonal_kernel(output: wp.array(dtype=wp.mat33)):
    # No input arrays: thread `idx` writes the 3x3 diagonal matrix
    # diag(k, 2k, 3k) with k = idx + 1.
    idx = wp.tid()
    k = float(idx + 1)
    output[idx] = wp.mat33(
        k, 0.0, 0.0,
        0.0, k * 2.0, 0.0,
        0.0, 0.0, k * 3.0,
    )
362
+
363
+
364
@wp.kernel
def scale_kernel(a: wp.array(dtype=float), s: float, output: wp.array(dtype=float)):
    # Multiply each float element of `a` by the scalar `s`.
    idx = wp.tid()
    output[idx] = a[idx] * s
368
+
369
+
370
@wp.kernel
def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
    # Multiply each vec2 element of `a` by the scalar `s`.
    idx = wp.tid()
    output[idx] = a[idx] * s
374
+
375
+
376
@wp.kernel
def accum_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
    # Accumulate `a` into `b` in place; `b` is both read and written.
    idx = wp.tid()
    b[idx] += a[idx]
380
+
381
+
382
@wp.kernel
def matmul_kernel(
    a: wp.array2d(dtype=float),  # NxK
    b: wp.array2d(dtype=float),  # KxM
    c: wp.array2d(dtype=float),  # NxM
):
    # Naive matrix product c = a @ b with one thread per output entry;
    # the launch is expected to be two-dimensional with dims (N, M).
    row, col = wp.tid()
    n_rows = a.shape[0]
    n_inner = a.shape[1]
    n_cols = b.shape[1]
    # threads that fall outside the output matrix do nothing
    if row < n_rows and col < n_cols:
        acc = wp.float32(0)
        for k in range(n_inner):
            acc += a[row, k] * b[k, col]
        c[row, col] = acc
398
+
399
+
400
@wp.kernel
def in_out_kernel(
    a: wp.array(dtype=float),  # input only
    b: wp.array(dtype=float),  # input and output
    c: wp.array(dtype=float),  # output only
):
    # Mixes the three argument roles: `a` is only read, `b` is updated in
    # place by adding `a`, and `c` receives twice the value of `a`.
    idx = wp.tid()
    b[idx] += a[idx]
    c[idx] = 2.0 * a[idx]
409
+
410
+
411
# Plain Python function used as the target of the jax_callable() tests.
# Its arguments are annotated the same way Warp kernel arguments are.
def scale_func(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=wp.vec2),
    s: float,
    # outputs
    c: wp.array(dtype=float),
    d: wp.array(dtype=wp.vec2),
):
    # scale the array of floats: c = a * s
    wp.launch(
        scale_kernel,
        dim=a.shape,
        inputs=[a, s],
        outputs=[c],
    )
    # scale the array of vec2: d = b * s
    wp.launch(
        scale_vec_kernel,
        dim=b.shape,
        inputs=[b, s],
        outputs=[d],
    )
424
+
425
+
426
def in_out_func(
    a: wp.array(dtype=float),  # input only
    b: wp.array(dtype=float),  # input and output
    c: wp.array(dtype=float),  # output only
):
    # Callable counterpart of in_out_kernel, built from two launches.
    # write c = 2 * a
    wp.launch(
        scale_kernel,
        dim=a.size,
        inputs=[a, 2.0],
        outputs=[c],
    )
    # accumulate a into b (modifies `b`)
    wp.launch(
        accum_kernel,
        dim=a.size,
        inputs=[a, b],
    )
433
+
434
+
435
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_add(test, device):
    # FFI jax_kernel() with two input arrays and a single output array
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    add_op = jax_kernel(add_kernel)

    def build_and_add():
        n = 10
        ramp = jp.arange(n, dtype=jp.int32)
        ones = jp.ones(n, dtype=jp.int32)
        return add_op(ramp, ones)

    jitted = jax.jit(build_and_add)

    # run on the JAX device that corresponds to the Warp device under test
    with jax.default_device(wp.device_to_jax(device)):
        (summed,) = jitted()

    # 0..9 plus one gives 1..10
    assert_np_equal(np.asarray(summed), np.arange(1, 11, dtype=np.int32))
458
+
459
+
460
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_sincos(test, device):
    # FFI jax_kernel() with one input array and two output arrays
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    sincos_op = jax_kernel(sincos_kernel, num_outputs=2)
    n = 32

    def compute():
        angles = jp.linspace(0, 2 * jp.pi, n, dtype=jp.float32)
        return sincos_op(angles)

    jitted = jax.jit(compute)

    with jax.default_device(wp.device_to_jax(device)):
        sin_jax, cos_jax = jitted()

    # NumPy reference over the same sample points
    angles_ref = np.linspace(0, 2 * np.pi, n, dtype=np.float32)

    assert_np_equal(np.asarray(sin_jax), np.sin(angles_ref), tol=1e-4)
    assert_np_equal(np.asarray(cos_jax), np.cos(angles_ref), tol=1e-4)
487
+
488
+
489
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_diagonal(test, device):
    # FFI jax_kernel() with no input arrays and a single output array
    from warp.jax_experimental.ffi import jax_kernel

    diagonal_op = jax_kernel(diagonal_kernel)

    def make_diagonals():
        # with no inputs, the launch dimensions determine the output size
        return diagonal_op(launch_dims=4)

    jitted = jax.jit(make_diagonals)

    with jax.default_device(wp.device_to_jax(device)):
        (matrices,) = jitted()

    # matrix number k (1-based) is diag(k, 2k, 3k)
    reference = np.stack([np.diag([k, 2.0 * k, 3.0 * k]) for k in (1.0, 2.0, 3.0, 4.0)]).astype(np.float32)

    assert_np_equal(np.asarray(matrices), reference)
516
+
517
+
518
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_in_out(test, device):
    # FFI jax_kernel() where argument "b" is declared as both input and output
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    in_out_op = jax_kernel(in_out_kernel, num_outputs=2, in_out_argnames=["b"])
    jitted = jax.jit(in_out_op)

    with jax.default_device(wp.device_to_jax(device)):
        ones = jp.ones(10, dtype=jp.float32)
        ramp = jp.arange(10, dtype=jp.float32)
        accumulated, doubled = jitted(ones, ramp)

    # b comes back as b + a, c as 2 * a
    assert_np_equal(accumulated, np.arange(1, 11, dtype=np.float32))
    assert_np_equal(doubled, np.full(10, 2, dtype=np.float32))
536
+
537
+
538
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_scale_vec_constant(test, device):
    # scale an array of vec2 by a scalar that is a constant inside the jitted function
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    scale_op = jax_kernel(scale_vec_kernel)

    def run():
        vectors = jp.arange(10, dtype=jp.float32).reshape((5, 2))  # array of vec2
        return scale_op(vectors, 2.0)

    jitted = jax.jit(run)

    with jax.default_device(wp.device_to_jax(device)):
        (scaled,) = jitted()

    reference = 2 * np.arange(10, dtype=np.float32).reshape((5, 2))

    assert_np_equal(scaled, reference)
559
+
560
+
561
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_scale_vec_static(test, device):
    # multiply vectors by scalar (static arg)
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    jax_scale_vec = jax_kernel(scale_vec_kernel)

    # NOTE: scalar arguments must be static compile-time constants
    @partial(jax.jit, static_argnames=["s"])
    def f(a, s):
        return jax_scale_vec(a, s)

    with jax.default_device(wp.device_to_jax(device)):
        # Create the inputs inside the device scope, as
        # test_ffi_jax_callable_scale_static does, so they are allocated on
        # the device under test rather than on JAX's global default device.
        a = jp.arange(10, dtype=jp.float32).reshape((5, 2))  # array of vec2
        s = 3.0

        (b,) = f(a, s)

    expected = 3 * np.arange(10, dtype=np.float32).reshape((5, 2))

    assert_np_equal(b, expected)
584
+
585
+
586
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_launch_dims_default(test, device):
    # launch dims given once to jax_kernel() act as the default for every call
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    rows, cols, inner = 3, 4, 2

    matmul_op = jax_kernel(matmul_kernel, launch_dims=(rows, cols))

    def multiply():
        lhs = jp.full((rows, inner), 2, dtype=jp.float32)
        rhs = jp.full((inner, cols), 3, dtype=jp.float32)

        # no launch_dims passed here, so the default above applies
        return matmul_op(lhs, rhs)

    jitted = jax.jit(multiply)

    with jax.default_device(wp.device_to_jax(device)):
        (product,) = jitted()

    # every entry is the sum over the two inner terms of 2 * 3
    reference = np.full((3, 4), 12, dtype=np.float32)

    test.assertEqual(product.shape, reference.shape)
    assert_np_equal(product, reference)
612
+
613
+
614
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_kernel_launch_dims_custom(test, device):
    # launch dims passed per call, with a different shape on each call
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_kernel

    matmul_op = jax_kernel(matmul_kernel)

    def multiply_twice():
        products = []
        # (N, M, K) for the first and second multiplication
        for rows, cols, inner in ((3, 4, 2), (4, 3, 2)):
            lhs = jp.full((rows, inner), 2, dtype=jp.float32)
            rhs = jp.full((inner, cols), 3, dtype=jp.float32)

            # custom launch dims for this particular call
            outputs = matmul_op(lhs, rhs, launch_dims=(rows, cols))
            products.append(outputs[0])

        return products[0], products[1]

    jitted = jax.jit(multiply_twice)

    with jax.default_device(wp.device_to_jax(device)):
        first, second = jitted()

    reference_first = np.full((3, 4), 12, dtype=np.float32)
    reference_second = np.full((4, 3), 12, dtype=np.float32)

    test.assertEqual(first.shape, reference_first.shape)
    test.assertEqual(second.shape, reference_second.shape)
    assert_np_equal(first, reference_first)
    assert_np_equal(second, reference_second)
651
+
652
+
653
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_scale_constant(test, device):
    # jax_callable() scaling two arrays by a scalar that is a constant inside the jitted function
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    scale_op = jax_callable(scale_func, num_outputs=2)

    def run():
        # inputs
        scalars = jp.arange(10, dtype=jp.float32)
        vectors = jp.arange(10, dtype=jp.float32).reshape((5, 2))  # wp.vec2
        factor = 2.0

        # each output takes the shape of the input it is derived from
        shapes = {"c": scalars.shape, "d": vectors.shape}

        scaled_scalars, scaled_vectors = scale_op(scalars, vectors, factor, output_dims=shapes)
        return scaled_scalars, scaled_vectors

    jitted = jax.jit(run)

    with jax.default_device(wp.device_to_jax(device)):
        got_scalars, got_vectors = jitted()

    reference = 2 * np.arange(10, dtype=np.float32)

    assert_np_equal(got_scalars, reference)
    assert_np_equal(got_vectors, reference.reshape((5, 2)))
684
+
685
+
686
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_scale_static(test, device):
    # jax_callable() scaling two arrays by a scalar passed as a static jit argument
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    scale_op = jax_callable(scale_func, num_outputs=2)

    def run(a, b, s):
        # each output takes the shape of the input it is derived from
        shapes = {"c": a.shape, "d": b.shape}
        scaled_a, scaled_b = scale_op(a, b, s, output_dims=shapes)
        return scaled_a, scaled_b

    # NOTE: scalar arguments must be static compile-time constants
    jitted = jax.jit(run, static_argnames=["s"])

    with jax.default_device(wp.device_to_jax(device)):
        # inputs
        scalars = jp.arange(10, dtype=jp.float32)
        vectors = jp.arange(10, dtype=jp.float32).reshape((5, 2))  # wp.vec2
        factor = 3.0
        got_scalars, got_vectors = jitted(scalars, vectors, factor)

    reference = 3 * np.arange(10, dtype=np.float32)

    assert_np_equal(got_scalars, reference)
    assert_np_equal(got_vectors, reference.reshape((5, 2)))
717
+
718
+
719
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_jax_callable_in_out(test, device):
    # jax_callable() where argument "b" is declared as both input and output
    import jax.numpy as jp

    from warp.jax_experimental.ffi import jax_callable

    in_out_op = jax_callable(in_out_func, num_outputs=2, in_out_argnames=["b"])
    jitted = jax.jit(in_out_op)

    with jax.default_device(wp.device_to_jax(device)):
        ones = jp.ones(10, dtype=jp.float32)
        ramp = jp.arange(10, dtype=jp.float32)
        accumulated, doubled = jitted(ones, ramp)

    # b comes back as b + a, c as 2 * a
    assert_np_equal(accumulated, np.arange(1, 11, dtype=np.float32))
    assert_np_equal(doubled, np.full(10, 2, dtype=np.float32))
737
+
738
+
739
@unittest.skipUnless(_jax_version() >= (0, 5, 0), "Jax version too old")
def test_ffi_callback(test, device):
    # Register a plain Python function as an FFI callback and invoke it through
    # jax.ffi.ffi_call with two input arrays, one scalar attribute, and two
    # output arrays.
    import jax.numpy as jp

    from warp.jax_experimental.ffi import register_ffi_callback

    # the Python function to call
    def warp_func(inputs, outputs, attrs, ctx):
        # input arrays
        a = inputs[0]
        b = inputs[1]

        # scalar attributes
        s = attrs["scale"]

        # output arrays
        c = outputs[0]
        d = outputs[1]

        # Wrap the stream handle passed in through `ctx` so that the launches
        # below are issued on it. The local is named `callback_device` to avoid
        # shadowing the enclosing test's `device` parameter.
        callback_device = wp.device_from_jax(get_jax_device())
        stream = wp.Stream(callback_device, cuda_stream=ctx.stream)

        with wp.ScopedStream(stream):
            # launch with arrays of scalars
            wp.launch(scale_kernel, dim=a.shape, inputs=[a, s], outputs=[c])

            # launch with arrays of vec2
            # NOTE: the input shapes are from JAX arrays, we need to strip the inner dimension for vec2 arrays
            wp.launch(scale_vec_kernel, dim=b.shape[0], inputs=[b, s], outputs=[d])

    # register callback
    register_ffi_callback("warp_func", warp_func)

    n = 10

    with jax.default_device(wp.device_to_jax(device)):
        # inputs
        a = jp.arange(n, dtype=jp.float32)
        b = jp.arange(n, dtype=jp.float32).reshape((n // 2, 2))  # array of wp.vec2
        s = 2.0

        # set up call
        out_types = [
            jax.ShapeDtypeStruct(a.shape, jp.float32),
            jax.ShapeDtypeStruct(b.shape, jp.float32),  # array of wp.vec2
        ]
        call = jax.ffi.ffi_call("warp_func", out_types)

        # call it
        c, d = call(a, b, scale=s)

    # Derive the references from `n` so they stay in step with the inputs
    # instead of repeating the literal sizes.
    expected_c = 2 * np.arange(n, dtype=np.float32)
    expected_d = 2 * np.arange(n, dtype=np.float32).reshape((n // 2, 2))

    assert_np_equal(c, expected_c)
    assert_np_equal(d, expected_d)
793
+
794
+
311
795
class TestJax(unittest.TestCase):
    # Intentionally empty: the test methods are attached to this class at
    # module level through add_function_test().
    pass
313
797
 
@@ -346,20 +830,116 @@ try:
346
830
  add_function_test(TestJax, "test_device_conversion", test_device_conversion, devices=jax_compatible_devices)
347
831
 
348
832
  if jax_compatible_cuda_devices:
349
- add_function_test(TestJax, "test_jax_kernel_basic", test_jax_kernel_basic, devices=jax_compatible_cuda_devices)
833
+ # tests for both custom_call and ffi variants of jax_kernel(), selected by installed JAX version
834
+ if jax.__version_info__ < (0, 4, 25):
835
+ # no interop supported
836
+ ffi_opts = []
837
+ elif jax.__version_info__ < (0, 5, 0):
838
+ # only custom_call supported
839
+ ffi_opts = [False]
840
+ elif jax.__version_info__ < (0, 8, 0):
841
+ # both custom_call and ffi supported
842
+ ffi_opts = [False, True]
843
+ else:
844
+ # only ffi supported
845
+ ffi_opts = [True]
846
+
847
+ for use_ffi in ffi_opts:
848
+ suffix = "ffi" if use_ffi else "cc"
849
+ add_function_test(
850
+ TestJax,
851
+ f"test_jax_kernel_basic_{suffix}",
852
+ test_jax_kernel_basic,
853
+ devices=jax_compatible_cuda_devices,
854
+ use_ffi=use_ffi,
855
+ )
856
+ add_function_test(
857
+ TestJax,
858
+ f"test_jax_kernel_scalar_{suffix}",
859
+ test_jax_kernel_scalar,
860
+ devices=jax_compatible_cuda_devices,
861
+ use_ffi=use_ffi,
862
+ )
863
+ add_function_test(
864
+ TestJax,
865
+ f"test_jax_kernel_vecmat_{suffix}",
866
+ test_jax_kernel_vecmat,
867
+ devices=jax_compatible_cuda_devices,
868
+ use_ffi=use_ffi,
869
+ )
870
+ add_function_test(
871
+ TestJax,
872
+ f"test_jax_kernel_multiarg_{suffix}",
873
+ test_jax_kernel_multiarg,
874
+ devices=jax_compatible_cuda_devices,
875
+ use_ffi=use_ffi,
876
+ )
877
+ add_function_test(
878
+ TestJax,
879
+ f"test_jax_kernel_launch_dims_{suffix}",
880
+ test_jax_kernel_launch_dims,
881
+ devices=jax_compatible_cuda_devices,
882
+ use_ffi=use_ffi,
883
+ )
884
+
885
+ # ffi.jax_kernel() tests
886
+ add_function_test(
887
+ TestJax, "test_ffi_jax_kernel_add", test_ffi_jax_kernel_add, devices=jax_compatible_cuda_devices
888
+ )
350
889
  add_function_test(
351
- TestJax, "test_jax_kernel_scalar", test_jax_kernel_scalar, devices=jax_compatible_cuda_devices
890
+ TestJax, "test_ffi_jax_kernel_sincos", test_ffi_jax_kernel_sincos, devices=jax_compatible_cuda_devices
352
891
  )
353
892
  add_function_test(
354
- TestJax, "test_jax_kernel_vecmat", test_jax_kernel_vecmat, devices=jax_compatible_cuda_devices
893
+ TestJax, "test_ffi_jax_kernel_diagonal", test_ffi_jax_kernel_diagonal, devices=jax_compatible_cuda_devices
355
894
  )
356
895
  add_function_test(
357
- TestJax, "test_jax_kernel_multiarg", test_jax_kernel_multiarg, devices=jax_compatible_cuda_devices
896
+ TestJax, "test_ffi_jax_kernel_in_out", test_ffi_jax_kernel_in_out, devices=jax_compatible_cuda_devices
897
+ )
898
+ add_function_test(
899
+ TestJax,
900
+ "test_ffi_jax_kernel_scale_vec_constant",
901
+ test_ffi_jax_kernel_scale_vec_constant,
902
+ devices=jax_compatible_cuda_devices,
903
+ )
904
+ add_function_test(
905
+ TestJax,
906
+ "test_ffi_jax_kernel_scale_vec_static",
907
+ test_ffi_jax_kernel_scale_vec_static,
908
+ devices=jax_compatible_cuda_devices,
909
+ )
910
+ add_function_test(
911
+ TestJax,
912
+ "test_ffi_jax_kernel_launch_dims_default",
913
+ test_ffi_jax_kernel_launch_dims_default,
914
+ devices=jax_compatible_cuda_devices,
915
+ )
916
+ add_function_test(
917
+ TestJax,
918
+ "test_ffi_jax_kernel_launch_dims_custom",
919
+ test_ffi_jax_kernel_launch_dims_custom,
920
+ devices=jax_compatible_cuda_devices,
358
921
  )
359
922
 
923
+ # ffi.jax_callable() tests
924
+ add_function_test(
925
+ TestJax,
926
+ "test_ffi_jax_callable_scale_constant",
927
+ test_ffi_jax_callable_scale_constant,
928
+ devices=jax_compatible_cuda_devices,
929
+ )
360
930
  add_function_test(
361
- TestJax, "test_jax_kernel_launch_dims", test_jax_kernel_launch_dims, devices=jax_compatible_cuda_devices
931
+ TestJax,
932
+ "test_ffi_jax_callable_scale_static",
933
+ test_ffi_jax_callable_scale_static,
934
+ devices=jax_compatible_cuda_devices,
362
935
  )
936
+ add_function_test(
937
+ TestJax, "test_ffi_jax_callable_in_out", test_ffi_jax_callable_in_out, devices=jax_compatible_cuda_devices
938
+ )
939
+
940
+ # ffi callback tests
941
+ add_function_test(TestJax, "test_ffi_callback", test_ffi_callback, devices=jax_compatible_cuda_devices)
942
+
363
943
 
364
944
  except Exception as e:
365
945
  print(f"Skipping Jax tests due to exception: {e}")