warp-lang 1.8.1-py3-none-win_amd64.whl → 1.9.1-py3-none-win_amd64.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

This release of warp-lang has been flagged as potentially problematic by the registry.

Files changed (141)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +1904 -114
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +331 -101
  7. warp/builtins.py +1244 -160
  8. warp/codegen.py +317 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1465 -789
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_kernel.py +2 -1
  18. warp/fabric.py +1 -1
  19. warp/fem/cache.py +27 -19
  20. warp/fem/domain.py +2 -2
  21. warp/fem/field/nodal_field.py +2 -2
  22. warp/fem/field/virtual.py +264 -166
  23. warp/fem/geometry/geometry.py +5 -5
  24. warp/fem/integrate.py +129 -51
  25. warp/fem/space/restriction.py +4 -0
  26. warp/fem/space/shape/tet_shape_function.py +3 -10
  27. warp/jax_experimental/custom_call.py +25 -2
  28. warp/jax_experimental/ffi.py +22 -1
  29. warp/jax_experimental/xla_ffi.py +16 -7
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +99 -4
  32. warp/native/builtin.h +86 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +8 -2
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +41 -10
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +2 -2
  48. warp/native/mat.h +1910 -116
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +4 -2
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +331 -14
  59. warp/native/range.h +7 -1
  60. warp/native/reduce.cpp +10 -10
  61. warp/native/reduce.cu +13 -14
  62. warp/native/runlength_encode.cpp +2 -2
  63. warp/native/runlength_encode.cu +5 -5
  64. warp/native/scan.cpp +3 -3
  65. warp/native/scan.cu +4 -4
  66. warp/native/sort.cpp +10 -10
  67. warp/native/sort.cu +40 -31
  68. warp/native/sort.h +2 -0
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +13 -13
  71. warp/native/spatial.h +366 -17
  72. warp/native/temp_buffer.h +2 -2
  73. warp/native/tile.h +471 -82
  74. warp/native/vec.h +328 -14
  75. warp/native/volume.cpp +54 -54
  76. warp/native/volume.cu +1 -1
  77. warp/native/volume.h +2 -1
  78. warp/native/volume_builder.cu +30 -37
  79. warp/native/warp.cpp +150 -149
  80. warp/native/warp.cu +377 -216
  81. warp/native/warp.h +227 -226
  82. warp/optim/linear.py +736 -271
  83. warp/render/imgui_manager.py +289 -0
  84. warp/render/render_opengl.py +99 -18
  85. warp/render/render_usd.py +1 -0
  86. warp/sim/graph_coloring.py +2 -2
  87. warp/sparse.py +558 -175
  88. warp/tests/aux_test_module_aot.py +7 -0
  89. warp/tests/cuda/test_async.py +3 -3
  90. warp/tests/cuda/test_conditional_captures.py +101 -0
  91. warp/tests/geometry/test_hash_grid.py +38 -0
  92. warp/tests/geometry/test_marching_cubes.py +233 -12
  93. warp/tests/interop/test_jax.py +608 -28
  94. warp/tests/sim/test_coloring.py +6 -6
  95. warp/tests/test_array.py +58 -5
  96. warp/tests/test_codegen.py +4 -3
  97. warp/tests/test_context.py +8 -15
  98. warp/tests/test_enum.py +136 -0
  99. warp/tests/test_examples.py +2 -2
  100. warp/tests/test_fem.py +49 -6
  101. warp/tests/test_fixedarray.py +229 -0
  102. warp/tests/test_func.py +18 -15
  103. warp/tests/test_future_annotations.py +7 -5
  104. warp/tests/test_linear_solvers.py +30 -0
  105. warp/tests/test_map.py +15 -1
  106. warp/tests/test_mat.py +1518 -378
  107. warp/tests/test_mat_assign_copy.py +178 -0
  108. warp/tests/test_mat_constructors.py +574 -0
  109. warp/tests/test_module_aot.py +287 -0
  110. warp/tests/test_print.py +69 -0
  111. warp/tests/test_quat.py +140 -34
  112. warp/tests/test_quat_assign_copy.py +145 -0
  113. warp/tests/test_reload.py +2 -1
  114. warp/tests/test_sparse.py +71 -0
  115. warp/tests/test_spatial.py +140 -34
  116. warp/tests/test_spatial_assign_copy.py +160 -0
  117. warp/tests/test_struct.py +43 -3
  118. warp/tests/test_tuple.py +96 -0
  119. warp/tests/test_types.py +61 -20
  120. warp/tests/test_vec.py +179 -34
  121. warp/tests/test_vec_assign_copy.py +143 -0
  122. warp/tests/tile/test_tile.py +245 -18
  123. warp/tests/tile/test_tile_cholesky.py +605 -0
  124. warp/tests/tile/test_tile_load.py +169 -0
  125. warp/tests/tile/test_tile_mathdx.py +2 -558
  126. warp/tests/tile/test_tile_matmul.py +1 -1
  127. warp/tests/tile/test_tile_mlp.py +1 -1
  128. warp/tests/tile/test_tile_shared_memory.py +5 -5
  129. warp/tests/unittest_suites.py +6 -0
  130. warp/tests/walkthrough_debug.py +1 -1
  131. warp/thirdparty/unittest_parallel.py +108 -9
  132. warp/types.py +571 -267
  133. warp/utils.py +68 -86
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
  135. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
  136. warp/native/marching.cpp +0 -19
  137. warp/native/marching.cu +0 -514
  138. warp/native/marching.h +0 -19
  139. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
  140. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
  141. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/native/quat.h CHANGED
@@ -459,14 +459,19 @@ inline CUDA_CALLABLE quat_t<Type> quat_from_matrix(const mat_t<Rows,Cols,Type>&
  template<typename Type>
  inline CUDA_CALLABLE Type extract(const quat_t<Type>& a, int idx)
  {
- #if FP_CHECK
-     if (idx < 0 || idx > 3)
+ #ifndef NDEBUG
+     if (idx < -4 || idx >= 4)
      {
          printf("quat_t index %d out of bounds at %s %d", idx, __FILE__, __LINE__);
          assert(0);
      }
  #endif
  
+     if (idx < 0)
+     {
+         idx += 4;
+     }
+ 
      /*
       * Because quat data is not stored in an array, we index the quaternion by checking all possible idx values.
       * (&a.x)[idx] would be the preferred access strategy, but this results in undefined behavior in the clang compiler
@@ -478,17 +483,48 @@ inline CUDA_CALLABLE Type extract(const quat_t<Type>& a, int idx)
      else {return a.w;}
  }
  
+ template<unsigned SliceLength, typename Type>
+ inline CUDA_CALLABLE vec_t<SliceLength, Type> extract(const quat_t<Type> & a, slice_t slice)
+ {
+     vec_t<SliceLength, Type> ret;
+ 
+     assert(slice.start >= 0 && slice.start <= 4);
+     assert(slice.stop >= -1 && slice.stop <= 4);
+     assert(slice.step != 0 && slice.step < 0 ? slice.start >= slice.stop : slice.start <= slice.stop);
+     assert(slice_get_length(slice) == SliceLength);
+ 
+     bool is_reversed = slice.step < 0;
+ 
+     int idx = 0;
+     for (
+         int i = slice.start;
+         is_reversed ? (i > slice.stop) : (i < slice.stop);
+         i += slice.step
+     )
+     {
+         ret[idx] = a[i];
+         ++idx;
+     }
+ 
+     return ret;
+ }
+ 
  template<typename Type>
  inline CUDA_CALLABLE Type* index(quat_t<Type>& q, int idx)
  {
  #ifndef NDEBUG
-     if (idx < 0 || idx > 3)
+     if (idx < -4 || idx >= 4)
      {
          printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
          assert(0);
      }
  #endif
  
+     if (idx < 0)
+     {
+         idx += 4;
+     }
+ 
      return &q[idx];
  }
  
@@ -496,13 +532,18 @@ template<typename Type>
  inline CUDA_CALLABLE Type* indexref(quat_t<Type>* q, int idx)
  {
  #ifndef NDEBUG
-     if (idx < 0 || idx > 3)
+     if (idx < -4 || idx >= 4)
      {
          printf("quat store %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
          assert(0);
      }
  #endif
  
+     if (idx < 0)
+     {
+         idx += 4;
+     }
+ 
      return &((*q)[idx]);
  }
  
@@ -526,120 +567,328 @@ template<typename Type>
  inline CUDA_CALLABLE void add_inplace(quat_t<Type>& q, int idx, Type value)
  {
  #ifndef NDEBUG
-     if (idx < 0 || idx > 3)
+     if (idx < -4 || idx >= 4)
      {
          printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
          assert(0);
      }
  #endif
  
+     if (idx < 0)
+     {
+         idx += 4;
+     }
+ 
      q[idx] += value;
  }
  
  
+ template<unsigned SliceLength, typename Type>
+ inline CUDA_CALLABLE void add_inplace(quat_t<Type>& q, slice_t slice, const vec_t<SliceLength, Type> &a)
+ {
+     assert(slice.start >= 0 && slice.start <= 4);
+     assert(slice.stop >= -1 && slice.stop <= 4);
+     assert(slice.step != 0 && slice.step < 0 ? slice.start >= slice.stop : slice.start <= slice.stop);
+     assert(slice_get_length(slice) == SliceLength);
+ 
+     bool is_reversed = slice.step < 0;
+ 
+     int ii = 0;
+     for (
+         int i = slice.start;
+         is_reversed ? (i > slice.stop) : (i < slice.stop);
+         i += slice.step
+     )
+     {
+         q[i] += a[ii];
+         ++ii;
+     }
+ 
+     assert(ii == SliceLength);
+ }
+ 
+ 
  template<typename Type>
  inline CUDA_CALLABLE void adj_add_inplace(quat_t<Type>& q, int idx, Type value,
                                            quat_t<Type>& adj_q, int adj_idx, Type& adj_value)
  {
  #ifndef NDEBUG
-     if (idx < 0 || idx > 3)
+     if (idx < -4 || idx >= 4)
      {
          printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
          assert(0);
      }
  #endif
  
+     if (idx < 0)
+     {
+         idx += 4;
+     }
+ 
      adj_value += adj_q[idx];
  }
  
  
+ template<unsigned SliceLength, typename Type>
+ inline CUDA_CALLABLE void adj_add_inplace(
+     const quat_t<Type>& q, slice_t slice, const vec_t<SliceLength, Type> &a,
+     quat_t<Type>& adj_q, slice_t& adj_slice, vec_t<SliceLength, Type>& adj_a
+ )
+ {
+     assert(slice.start >= 0 && slice.start <= 4);
+     assert(slice.stop >= -1 && slice.stop <= 4);
+     assert(slice.step != 0 && slice.step < 0 ? slice.start >= slice.stop : slice.start <= slice.stop);
+     assert(slice_get_length(slice) == SliceLength);
+ 
+     bool is_reversed = slice.step < 0;
+ 
+     int ii = 0;
+     for (
+         int i = slice.start;
+         is_reversed ? (i > slice.stop) : (i < slice.stop);
+         i += slice.step
+     )
+     {
+         adj_a[ii] += adj_q[i];
+         ++ii;
+     }
+ 
+     assert(ii == SliceLength);
+ }
+ 
+ 
  template<typename Type>
  inline CUDA_CALLABLE void sub_inplace(quat_t<Type>& q, int idx, Type value)
  {
  #ifndef NDEBUG
-     if (idx < 0 || idx > 3)
+     if (idx < -4 || idx >= 4)
      {
          printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
          assert(0);
      }
  #endif
  
+     if (idx < 0)
+     {
+         idx += 4;
+     }
+ 
      q[idx] -= value;
  }
  
  
+ template<unsigned SliceLength, typename Type>
+ inline CUDA_CALLABLE void sub_inplace(quat_t<Type>& q, slice_t slice, const vec_t<SliceLength, Type> &a)
+ {
+     assert(slice.start >= 0 && slice.start <= 4);
+     assert(slice.stop >= -1 && slice.stop <= 4);
+     assert(slice.step != 0 && slice.step < 0 ? slice.start >= slice.stop : slice.start <= slice.stop);
+     assert(slice_get_length(slice) == SliceLength);
+ 
+     bool is_reversed = slice.step < 0;
+ 
+     int ii = 0;
+     for (
+         int i = slice.start;
+         is_reversed ? (i > slice.stop) : (i < slice.stop);
+         i += slice.step
+     )
+     {
+         q[i] -= a[ii];
+         ++ii;
+     }
+ 
+     assert(ii == SliceLength);
+ }
+ 
+ 
  template<typename Type>
  inline CUDA_CALLABLE void adj_sub_inplace(quat_t<Type>& q, int idx, Type value,
                                            quat_t<Type>& adj_q, int adj_idx, Type& adj_value)
  {
  #ifndef NDEBUG
-     if (idx < 0 || idx > 3)
+     if (idx < -4 || idx >= 4)
      {
          printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
          assert(0);
      }
  #endif
  
+     if (idx < 0)
+     {
+         idx += 4;
+     }
+ 
      adj_value -= adj_q[idx];
  }
  
  
+ template<unsigned SliceLength, typename Type>
+ inline CUDA_CALLABLE void adj_sub_inplace(
+     const quat_t<Type>& q, slice_t slice, const vec_t<SliceLength, Type> &a,
+     quat_t<Type>& adj_q, slice_t& adj_slice, vec_t<SliceLength, Type>& adj_a
+ )
+ {
+     assert(slice.start >= 0 && slice.start <= 4);
+     assert(slice.stop >= -1 && slice.stop <= 4);
+     assert(slice.step != 0 && slice.step < 0 ? slice.start >= slice.stop : slice.start <= slice.stop);
+     assert(slice_get_length(slice) == SliceLength);
+ 
+     bool is_reversed = slice.step < 0;
+ 
+     int ii = 0;
+     for (
+         int i = slice.start;
+         is_reversed ? (i > slice.stop) : (i < slice.stop);
+         i += slice.step
+     )
+     {
+         adj_a[ii] -= adj_q[i];
+         ++ii;
+     }
+ 
+     assert(ii == SliceLength);
+ }
+ 
+ 
  template<typename Type>
  inline CUDA_CALLABLE void assign_inplace(quat_t<Type>& q, int idx, Type value)
  {
  #ifndef NDEBUG
-     if (idx < 0 || idx > 3)
+     if (idx < -4 || idx >= 4)
      {
          printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
          assert(0);
      }
  #endif
  
+     if (idx < 0)
+     {
+         idx += 4;
+     }
+ 
      q[idx] = value;
  }
  
+ 
+ template<unsigned SliceLength, typename Type>
+ inline CUDA_CALLABLE void assign_inplace(quat_t<Type>& q, slice_t slice, const vec_t<SliceLength, Type> &a)
+ {
+     assert(slice.start >= 0 && slice.start <= 4);
+     assert(slice.stop >= -1 && slice.stop <= 4);
+     assert(slice.step != 0 && slice.step < 0 ? slice.start >= slice.stop : slice.start <= slice.stop);
+     assert(slice_get_length(slice) == SliceLength);
+ 
+     bool is_reversed = slice.step < 0;
+ 
+     int ii = 0;
+     for (
+         int i = slice.start;
+         is_reversed ? (i > slice.stop) : (i < slice.stop);
+         i += slice.step
+     )
+     {
+         q[i] = a[ii];
+         ++ii;
+     }
+ 
+     assert(ii == SliceLength);
+ }
+ 
+ 
  template<typename Type>
  inline CUDA_CALLABLE void adj_assign_inplace(quat_t<Type>& q, int idx, Type value, quat_t<Type>& adj_q, int& adj_idx, Type& adj_value)
  {
  #ifndef NDEBUG
-     if (idx < 0 || idx > 3)
+     if (idx < -4 || idx >= 4)
      {
          printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
          assert(0);
      }
  #endif
  
+     if (idx < 0)
+     {
+         idx += 4;
+     }
+ 
      adj_value += adj_q[idx];
  }
  
  
+ template<unsigned SliceLength, typename Type>
+ inline CUDA_CALLABLE void adj_assign_inplace(
+     const quat_t<Type>& q, slice_t slice, const vec_t<SliceLength, Type> &a,
+     quat_t<Type>& adj_q, slice_t& adj_slice, vec_t<SliceLength, Type>& adj_a
+ )
+ {
+     assert(slice.start >= 0 && slice.start <= 4);
+     assert(slice.stop >= -1 && slice.stop <= 4);
+     assert(slice.step != 0 && slice.step < 0 ? slice.start >= slice.stop : slice.start <= slice.stop);
+     assert(slice_get_length(slice) == SliceLength);
+ 
+     bool is_reversed = slice.step < 0;
+ 
+     int ii = 0;
+     for (
+         int i = slice.start;
+         is_reversed ? (i > slice.stop) : (i < slice.stop);
+         i += slice.step
+     )
+     {
+         adj_a[ii] += adj_q[i];
+         ++ii;
+     }
+ 
+     assert(ii == SliceLength);
+ }
+ 
+ 
  template<typename Type>
  inline CUDA_CALLABLE quat_t<Type> assign_copy(quat_t<Type>& q, int idx, Type value)
  {
  #ifndef NDEBUG
-     if (idx < 0 || idx > 3)
+     if (idx < -4 || idx >= 4)
      {
          printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
          assert(0);
      }
  #endif
  
+     if (idx < 0)
+     {
+         idx += 4;
+     }
+ 
      quat_t<Type> ret(q);
      ret[idx] = value;
      return ret;
  }
  
+ template<unsigned SliceLength, typename Type>
+ inline CUDA_CALLABLE quat_t<Type> assign_copy(quat_t<Type>& q, slice_t slice, const vec_t<SliceLength, Type> &a)
+ {
+     quat_t<Type> ret(q);
+     assign_inplace<SliceLength>(ret, slice, a);
+     return ret;
+ }
+ 
  template<typename Type>
  inline CUDA_CALLABLE void adj_assign_copy(quat_t<Type>& q, int idx, Type value, quat_t<Type>& adj_q, int& adj_idx, Type& adj_value, const quat_t<Type>& adj_ret)
  {
  #ifndef NDEBUG
-     if (idx < 0 || idx > 3)
+     if (idx < -4 || idx >= 4)
      {
          printf("quat index %d out of bounds at %s %d\n", idx, __FILE__, __LINE__);
          assert(0);
      }
  #endif
  
+     if (idx < 0)
+     {
+         idx += 4;
+     }
+ 
      adj_value += adj_ret[idx];
      for(unsigned i=0; i < 4; ++i)
      {
@@ -648,6 +897,41 @@ inline CUDA_CALLABLE void adj_assign_copy(quat_t<Type>& q, int idx, Type value,
      }
  }
  
+ template<unsigned SliceLength, typename Type>
+ inline CUDA_CALLABLE void adj_assign_copy(
+     quat_t<Type>& q, slice_t slice, const vec_t<SliceLength, Type> &a,
+     quat_t<Type>& adj_q, slice_t& adj_slice, vec_t<SliceLength, Type>& adj_a,
+     const quat_t<Type>& adj_ret
+ )
+ {
+     assert(slice.start >= 0 && slice.start <= 4);
+     assert(slice.stop >= -1 && slice.stop <= 4);
+     assert(slice.step != 0 && slice.step < 0 ? slice.start >= slice.stop : slice.start <= slice.stop);
+     assert(slice_get_length(slice) == SliceLength);
+ 
+     bool is_reversed = slice.step < 0;
+ 
+     int ii = 0;
+     for (int i = 0; i < 4; ++i)
+     {
+         bool in_slice = is_reversed
+             ? (i <= slice.start && i > slice.stop && (slice.start - i) % (-slice.step) == 0)
+             : (i >= slice.start && i < slice.stop && (i - slice.start) % slice.step == 0);
+ 
+         if (!in_slice)
+         {
+             adj_q[i] += adj_ret[i];
+         }
+         else
+         {
+             adj_a[ii] += adj_ret[i];
+             ++ii;
+         }
+     }
+ 
+     assert(ii == SliceLength);
+ }
+ 
  
  template<typename Type>
  CUDA_CALLABLE inline quat_t<Type> lerp(const quat_t<Type>& a, const quat_t<Type>& b, Type t)
@@ -666,14 +950,19 @@ CUDA_CALLABLE inline void adj_lerp(const quat_t<Type>& a, const quat_t<Type>& b,
  template<typename Type>
  inline CUDA_CALLABLE void adj_extract(const quat_t<Type>& a, int idx, quat_t<Type>& adj_a, int & adj_idx, Type & adj_ret)
  {
- #if FP_CHECK
-     if (idx < 0 || idx > 3)
+ #ifndef NDEBUG
+     if (idx < -4 || idx >= 4)
      {
          printf("quat_t index %d out of bounds at %s %d", idx, __FILE__, __LINE__);
          assert(0);
      }
  #endif
  
+     if (idx < 0)
+     {
+         idx += 4;
+     }
+ 
      // See wp::extract(const quat_t<Type>& a, int idx) note
      if (idx == 0) {adj_a.x += adj_ret;}
      else if (idx == 1) {adj_a.y += adj_ret;}
@@ -681,6 +970,34 @@ inline CUDA_CALLABLE void adj_extract(const quat_t<Type>& a, int idx, quat_t<Typ
      else {adj_a.w += adj_ret;}
  }
  
+ template<unsigned SliceLength, typename Type>
+ inline CUDA_CALLABLE void adj_extract(
+     const quat_t<Type>& a, slice_t slice,
+     quat_t<Type>& adj_a, slice_t& adj_slice,
+     const vec_t<SliceLength, Type>& adj_ret
+ )
+ {
+     assert(slice.start >= 0 && slice.start <= 4);
+     assert(slice.stop >= -1 && slice.stop <= 4);
+     assert(slice.step != 0 && slice.step < 0 ? slice.start >= slice.stop : slice.start <= slice.stop);
+     assert(slice_get_length(slice) == SliceLength);
+ 
+     bool is_reversed = slice.step < 0;
+ 
+     int ii = 0;
+     for (
+         int i = slice.start;
+         is_reversed ? (i > slice.stop) : (i < slice.stop);
+         i += slice.step
+     )
+     {
+         adj_a[i] += adj_ret[ii];
+         ++ii;
+     }
+ 
+     assert(ii == SliceLength);
+ }
+ 
  
  // backward methods
  template<typename Type>
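Taken together, the quat.h changes relax the bounds checks to accept Python-style negative component indices (normalized by adding 4) and add slice-based overloads of extract / assign / add / sub along with their adjoints. Below is a minimal, standalone C++ sketch of that normalization and slice-traversal logic; slice_t and quatf here are simplified stand-ins defined for the example, not the actual Warp types.

    #include <cassert>
    #include <cstdio>

    // Simplified stand-ins for the wp::slice_t / quat_t types used in the diff above.
    struct slice_t { int start, stop, step; };
    struct quatf { float v[4]; float& operator[](int i) { return v[i]; } };

    // Python-style negative index normalization, mirroring the new `if (idx < 0) idx += 4;` blocks.
    int normalize_index(int idx, int length)
    {
        assert(idx >= -length && idx < length);
        return idx < 0 ? idx + length : idx;
    }

    // Slice traversal in the same style as the added slice overloads: walk from start toward stop
    // by step, forward or reversed depending on the sign of step, copying visited components.
    int slice_visit(const slice_t& s, float* out, quatf& q)
    {
        bool is_reversed = s.step < 0;
        int ii = 0;
        for (int i = s.start; is_reversed ? (i > s.stop) : (i < s.stop); i += s.step)
            out[ii++] = q[i];
        return ii;  // number of elements copied, i.e. the slice length
    }

    int main()
    {
        quatf q{{1.0f, 2.0f, 3.0f, 4.0f}};
        int i = normalize_index(-1, 4);           // -1 wraps to the last component
        printf("q[-1] -> q[%d] = %g\n", i, q[i]);

        float buf[4];
        slice_t s{3, -1, -2};                     // analogous to a [3::-2] slice -> {w, y}
        int n = slice_visit(s, buf, q);
        printf("slice length %d: %g %g\n", n, buf[0], buf[1]);
        return 0;
    }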
warp/native/range.h CHANGED
@@ -115,7 +115,13 @@ CUDA_CALLABLE inline range_t iter_reverse(const range_t& r)
      // generates a reverse range, equivalent to reversed(range())
      range_t rev;
  
-     if (r.step > 0)
+     if (r.step == 0)
+     {
+         // degenerate case where step == 0, return empty range
+         rev.start = r.start;
+         rev.end = r.start;
+     }
+     else if (r.step > 0)
      {
          rev.start = r.start + int((r.end - r.start - 1) / r.step) * r.step;
      }
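The range.h change guards iter_reverse against a degenerate step of zero by returning an empty range. The following standalone sketch illustrates the reversal logic with that guard; only the step == 0 branch and the positive-step start computation come from the diff above, while the remaining bookkeeping (the reversed end/step and the negative-step branch) is filled in here just to make the example self-contained and is not taken from the Warp source.

    #include <cstdio>

    // Simplified stand-in for wp::range_t (the real struct carries an iteration cursor as well).
    struct range_t { int start, end, step; };

    range_t iter_reverse(const range_t& r)
    {
        range_t rev;
        if (r.step == 0)
        {
            // degenerate case: return an empty range that starts and ends at the same point
            rev.start = r.start;
            rev.end = r.start;
            rev.step = 1;  // any non-zero step keeps iteration over the empty range well defined
            return rev;
        }
        if (r.step > 0)
            rev.start = r.start + int((r.end - r.start - 1) / r.step) * r.step;   // last value produced
        else
            rev.start = r.start + int((r.end - r.start + 1) / r.step) * r.step;   // assumed analogue for step < 0
        rev.end = r.start - r.step;   // one step past the original start
        rev.step = -r.step;
        return rev;
    }

    int main()
    {
        range_t r{0, 10, 3};             // yields 0, 3, 6, 9
        range_t rev = iter_reverse(r);   // should yield 9, 6, 3, 0
        for (int i = rev.start; (rev.step > 0) ? (i < rev.end) : (i > rev.end); i += rev.step)
            printf("%d ", i);
        printf("\n");
        return 0;
    }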
warp/native/reduce.cpp CHANGED
@@ -119,7 +119,7 @@ template <typename T> void array_sum_host(const T *ptr_a, T *ptr_out, int count,
          accumulate_func(ptr_a + i * stride, ptr_out, type_length);
  }
  
- void array_inner_float_host(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
+ void wp_array_inner_float_host(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
                              int type_length)
  {
      const float *ptr_a = (const float *)(a);
@@ -129,7 +129,7 @@ void array_inner_float_host(uint64_t a, uint64_t b, uint64_t out, int count, int
      array_inner_host(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_length);
  }
  
- void array_inner_double_host(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
+ void wp_array_inner_double_host(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
                               int type_length)
  {
      const double *ptr_a = (const double *)(a);
@@ -139,14 +139,14 @@ void array_inner_double_host(uint64_t a, uint64_t b, uint64_t out, int count, in
      array_inner_host(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_length);
  }
  
- void array_sum_float_host(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
+ void wp_array_sum_float_host(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
  {
      const float *ptr_a = (const float *)(a);
      float *ptr_out = (float *)(out);
      array_sum_host(ptr_a, ptr_out, count, byte_stride_a, type_length);
  }
  
- void array_sum_double_host(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
+ void wp_array_sum_double_host(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
  {
      const double *ptr_a = (const double *)(a);
      double *ptr_out = (double *)(out);
@@ -154,21 +154,21 @@ void array_sum_double_host(uint64_t a, uint64_t out, int count, int byte_stride_
  }
  
  #if !WP_ENABLE_CUDA
- void array_inner_float_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
-                               int type_length)
+ void wp_array_inner_float_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
+                                  int type_length)
  {
  }
  
- void array_inner_double_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
-                                int type_length)
+ void wp_array_inner_double_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
+                                   int type_length)
  {
  }
  
- void array_sum_float_device(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
+ void wp_array_sum_float_device(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
  {
  }
  
- void array_sum_double_device(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
+ void wp_array_sum_double_device(uint64_t a, uint64_t out, int count, int byte_stride_a, int type_length)
  {
  }
  #endif
warp/native/reduce.cu CHANGED
@@ -22,7 +22,6 @@
  
  #define THRUST_IGNORE_CUB_VERSION_CHECK
  #include <cub/device/device_reduce.cuh>
- #include <cub/iterator/counting_input_iterator.cuh>
  
  namespace
  {
@@ -119,14 +118,14 @@ template <typename T> void array_sum_device(const T *ptr_a, T *ptr_out, int coun
      assert((byte_stride % sizeof(T)) == 0);
      const int stride = byte_stride / sizeof(T);
  
-     ContextGuard guard(cuda_context_get_current());
-     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
+     ContextGuard guard(wp_cuda_context_get_current());
+     cudaStream_t stream = static_cast<cudaStream_t>(wp_cuda_stream_get_current());
  
      cub_strided_iterator<const T> ptr_strided{ptr_a, stride};
  
      size_t buff_size = 0;
      check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, ptr_strided, ptr_out, count, stream));
-     void* temp_buffer = alloc_device(WP_CURRENT_CONTEXT, buff_size);
+     void* temp_buffer = wp_alloc_device(WP_CURRENT_CONTEXT, buff_size);
  
      for (int k = 0; k < type_length; ++k)
      {
@@ -134,7 +133,7 @@ template <typename T> void array_sum_device(const T *ptr_a, T *ptr_out, int coun
          check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, ptr_strided, ptr_out + k, count, stream));
      }
  
-     free_device(WP_CURRENT_CONTEXT, temp_buffer);
+     wp_free_device(WP_CURRENT_CONTEXT, temp_buffer);
  }
  
  template <typename T>
@@ -280,18 +279,18 @@ void array_inner_device(const ElemT *ptr_a, const ElemT *ptr_b, ScalarT *ptr_out
      const int stride_a = byte_stride_a / sizeof(ElemT);
      const int stride_b = byte_stride_b / sizeof(ElemT);
  
-     ContextGuard guard(cuda_context_get_current());
-     cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream_get_current());
+     ContextGuard guard(wp_cuda_context_get_current());
+     cudaStream_t stream = static_cast<cudaStream_t>(wp_cuda_stream_get_current());
  
      cub_inner_product_iterator<ElemT, ScalarT> inner_iterator{ptr_a, ptr_b, stride_a, stride_b, type_length};
  
      size_t buff_size = 0;
      check_cuda(cub::DeviceReduce::Sum(nullptr, buff_size, inner_iterator, ptr_out, count, stream));
-     void* temp_buffer = alloc_device(WP_CURRENT_CONTEXT, buff_size);
+     void* temp_buffer = wp_alloc_device(WP_CURRENT_CONTEXT, buff_size);
  
      check_cuda(cub::DeviceReduce::Sum(temp_buffer, buff_size, inner_iterator, ptr_out, count, stream));
  
-     free_device(WP_CURRENT_CONTEXT, temp_buffer);
+     wp_free_device(WP_CURRENT_CONTEXT, temp_buffer);
  }
  
  template <typename T>
@@ -327,10 +326,10 @@ void array_inner_device_dispatch(const T *ptr_a, const T *ptr_b, T *ptr_out, int
  
  } // anonymous namespace
  
- void array_inner_float_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
+ void wp_array_inner_float_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
                                int type_len)
  {
-     void *context = cuda_context_get_current();
+     void *context = wp_cuda_context_get_current();
  
      const float *ptr_a = (const float *)(a);
      const float *ptr_b = (const float *)(b);
@@ -339,7 +338,7 @@ void array_inner_float_device(uint64_t a, uint64_t b, uint64_t out, int count, i
      array_inner_device_dispatch(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_len);
  }
  
- void array_inner_double_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
+ void wp_array_inner_double_device(uint64_t a, uint64_t b, uint64_t out, int count, int byte_stride_a, int byte_stride_b,
                                 int type_len)
  {
      const double *ptr_a = (const double *)(a);
@@ -349,14 +348,14 @@ void array_inner_double_device(uint64_t a, uint64_t b, uint64_t out, int count,
      array_inner_device_dispatch(ptr_a, ptr_b, ptr_out, count, byte_stride_a, byte_stride_b, type_len);
  }
  
- void array_sum_float_device(uint64_t a, uint64_t out, int count, int byte_stride, int type_length)
+ void wp_array_sum_float_device(uint64_t a, uint64_t out, int count, int byte_stride, int type_length)
  {
      const float *ptr_a = (const float *)(a);
      float *ptr_out = (float *)(out);
      array_sum_device_dispatch(ptr_a, ptr_out, count, byte_stride, type_length);
  }
  
- void array_sum_double_device(uint64_t a, uint64_t out, int count, int byte_stride, int type_length)
+ void wp_array_sum_double_device(uint64_t a, uint64_t out, int count, int byte_stride, int type_length)
  {
      const double *ptr_a = (const double *)(a);
      double *ptr_out = (double *)(out);
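The reduce.cu edits are essentially symbol renames: the reductions still follow CUB's standard two-phase temporary-storage protocol, where a first cub::DeviceReduce::Sum call with a null scratch pointer only reports the required buffer size and a second call performs the reduction, now routed through the wp_ prefixed context, stream, and allocation helpers. For reference, a minimal standalone version of that protocol (compiled with nvcc, using plain cudaMallocAsync instead of Warp's allocator, and a hypothetical sum_floats helper written for this example) could look like:

    #include <cuda_runtime.h>
    #include <cub/device/device_reduce.cuh>
    #include <cstdio>

    // Lightweight error check standing in for Warp's check_cuda().
    #define CHECK_CUDA(expr) do { cudaError_t e = (expr); if (e != cudaSuccess) \
        printf("CUDA error: %s\n", cudaGetErrorString(e)); } while (0)

    void sum_floats(const float* d_in, float* d_out, int count, cudaStream_t stream)
    {
        size_t temp_bytes = 0;

        // Phase 1: size query only (null temporary-storage pointer).
        CHECK_CUDA(cub::DeviceReduce::Sum(nullptr, temp_bytes, d_in, d_out, count, stream));

        // Phase 2: allocate the scratch buffer and run the actual reduction.
        void* d_temp = nullptr;
        CHECK_CUDA(cudaMallocAsync(&d_temp, temp_bytes, stream));
        CHECK_CUDA(cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, count, stream));
        CHECK_CUDA(cudaFreeAsync(d_temp, stream));
    }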