gstaichi 2.1.1rc3__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,13 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang import impl
4
+
5
+
6
def sync():
    """Wait for all previously launched GsTaichi kernels to complete.

    The calling thread is blocked until the runtime's ``sync()`` returns.
    """
    runtime = impl.get_runtime()
    runtime.sync()
11
+
12
+
13
+ __all__ = ["sync"]
gstaichi/lang/shell.py ADDED
@@ -0,0 +1,35 @@
1
+ # type: ignore
2
+
3
+ import functools
4
+ import os
5
+ import sys
6
+
7
+ from gstaichi._lib import core as _ti_core
8
+ from gstaichi._logging import info
9
+
10
# TI_ENABLE_PYBUF defaults to "1"; an empty value also counts as enabled,
# any other value is parsed as an integer flag.
_env_enable_pybuf = os.environ.get("TI_ENABLE_PYBUF", "1")
if _env_enable_pybuf and not int(_env_enable_pybuf):
    pybuf_enabled = False
else:
    # When using in Jupyter / IDLE, the sys.stdout will be their wrapped ones.
    # While sys.__stdout__ should always be the raw console stdout.
    pybuf_enabled = sys.stdout is not sys.__stdout__

_ti_core.toggle_python_print_buffer(pybuf_enabled)
18
+
19
+
20
def _shell_pop_print(old_call):
    """Wrap ``old_call`` so the buffered kernel output is printed after it runs.

    When ``pybuf_enabled`` is false, ``old_call`` is returned untouched.
    Otherwise an info message is logged once and a wrapper is returned that
    calls ``old_call``, prints the popped Python print buffer, and passes the
    original return value through.
    """
    if pybuf_enabled:
        info("Graphical python shell detected, using wrapped sys.stdout")

        @functools.wraps(old_call)
        def new_call(*args, **kwargs):
            result = old_call(*args, **kwargs)
            # print's in kernel won't take effect until ti.sync(), discussion:
            # https://github.com/taichi-dev/gstaichi/pull/1303#discussion_r444897102
            buffered_output = _ti_core.pop_python_print_buffer()
            print(buffered_output, end="")
            return result

        return new_call

    # zero-overhead!
    return old_call
@@ -0,0 +1,5 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang.simt import block, grid, subgroup, warp
4
+
5
+ __all__ = ["warp", "subgroup", "block", "grid"]
@@ -0,0 +1,94 @@
1
+ # type: ignore
2
+
3
+ from gstaichi._lib import core as _ti_core
4
+ from gstaichi.lang import impl
5
+ from gstaichi.lang.expr import make_expr_group
6
+ from gstaichi.lang.util import gstaichi_scope
7
+
8
+
9
def arch_uses_spv(arch):
    """Return whether ``arch`` is ``_ti_core.vulkan`` or ``_ti_core.metal``."""
    if arch == _ti_core.vulkan:
        return True
    return arch == _ti_core.metal
11
+
12
+
13
def sync():
    """Emit a block-level barrier for the current arch.

    CUDA and AMDGPU emit the ``block_barrier`` internal call; Vulkan and Metal
    (see ``arch_uses_spv``) emit ``workgroupBarrier``.

    Returns:
        The result of ``impl.call_internal`` for the selected barrier.

    Raises:
        ValueError: If the current arch is none of the above.
    """
    arch = impl.get_runtime().prog.config().arch
    if arch == _ti_core.cuda or arch == _ti_core.amdgpu:
        return impl.call_internal("block_barrier", with_runtime_context=False)
    if arch_uses_spv(arch):
        return impl.call_internal("workgroupBarrier", with_runtime_context=False)
    # The message previously named ti.block.shared_array, which is not this function.
    raise ValueError(f"ti.block.sync is not supported for arch {arch}")
20
+
21
+
22
def sync_all_nonzero(predicate):
    """Emit the ``block_barrier_and_i32`` internal call with ``predicate``.

    Only CUDA is accepted; any other arch raises ``ValueError``.
    """
    target_arch = impl.get_runtime().prog.config().arch
    if target_arch == _ti_core.cuda:
        outcome = impl.call_internal("block_barrier_and_i32", predicate, with_runtime_context=False)
    else:
        raise ValueError(f"ti.block.sync_all_nonzero is not supported for arch {target_arch}")
    return outcome
27
+
28
+
29
def sync_any_nonzero(predicate):
    """Emit the ``block_barrier_or_i32`` internal call with ``predicate``.

    Only CUDA is accepted; any other arch raises ``ValueError``.
    """
    target_arch = impl.get_runtime().prog.config().arch
    if target_arch == _ti_core.cuda:
        outcome = impl.call_internal("block_barrier_or_i32", predicate, with_runtime_context=False)
    else:
        raise ValueError(f"ti.block.sync_any_nonzero is not supported for arch {target_arch}")
    return outcome
34
+
35
+
36
def sync_count_nonzero(predicate):
    """Emit the ``block_barrier_count_i32`` internal call with ``predicate``.

    Only CUDA is accepted; any other arch raises ``ValueError``.
    """
    target_arch = impl.get_runtime().prog.config().arch
    if target_arch == _ti_core.cuda:
        outcome = impl.call_internal("block_barrier_count_i32", predicate, with_runtime_context=False)
    else:
        raise ValueError(f"ti.block.sync_count_nonzero is not supported for arch {target_arch}")
    return outcome
41
+
42
+
43
def mem_sync():
    """Emit a block-level memory barrier for the current arch.

    CUDA emits ``block_barrier``; Vulkan and Metal (see ``arch_uses_spv``)
    emit ``workgroupMemoryBarrier``. Any other arch raises ``ValueError``.
    """
    target_arch = impl.get_runtime().prog.config().arch
    if target_arch == _ti_core.cuda:
        internal_name = "block_barrier"
    elif arch_uses_spv(target_arch):
        internal_name = "workgroupMemoryBarrier"
    else:
        raise ValueError(f"ti.block.mem_sync is not supported for arch {target_arch}")
    return impl.call_internal(internal_name, with_runtime_context=False)
50
+
51
+
52
def thread_idx():
    """Emit the ``localInvocationId`` internal call.

    Only archs accepted by ``arch_uses_spv`` are supported; any other arch
    raises ``ValueError``.
    """
    target_arch = impl.get_runtime().prog.config().arch
    if arch_uses_spv(target_arch):
        outcome = impl.call_internal("localInvocationId", with_runtime_context=False)
    else:
        raise ValueError(f"ti.block.thread_idx is not supported for arch {target_arch}")
    return outcome
57
+
58
+
59
def global_thread_idx():
    """Return an expression for the global thread index on the current arch.

    CUDA and AMDGPU use the AST builder's ``insert_thread_idx_expr()``;
    Vulkan and Metal (see ``arch_uses_spv``) emit the ``globalInvocationId``
    internal call.

    Returns:
        The expression produced by the selected backend path.

    Raises:
        ValueError: If the current arch is none of the above.
    """
    arch = impl.get_runtime().prog.config().arch
    # Both backends must be compared against `arch` explicitly. The previous
    # `arch == cuda or _ti_core.amdgpu` tested a constant operand that does
    # not depend on `arch`, so the branches below could never be reached.
    if arch == _ti_core.cuda or arch == _ti_core.amdgpu:
        return impl.get_runtime().compiling_callable.ast_builder().insert_thread_idx_expr()
    if arch_uses_spv(arch):
        return impl.call_internal("globalInvocationId", with_runtime_context=False)
    raise ValueError(f"ti.block.global_thread_idx is not supported for arch {arch}")
66
+
67
+
68
class SharedArray:
    """Handle for an array created through ``impl.expr_init_shared_array``.

    Attributes set by ``__init__``: ``shape`` (the given tuple/list, or a
    1-tuple for an int), ``dtype`` (a ``MatrixType`` is replaced by its
    ``tensor_type``) and ``shared_array_proxy`` (the value returned by
    ``impl.expr_init_shared_array``).
    """

    _is_gstaichi_class = True

    def __init__(self, shape, dtype):
        """Validate ``shape``, normalise ``dtype`` and create the proxy.

        Raises:
            ValueError: If ``shape`` is neither an int nor a tuple/list of ints.
        """
        if isinstance(shape, int):
            shape = (shape,)
        elif not (isinstance(shape, (tuple, list)) and all(isinstance(extent, int) for extent in shape)):
            raise ValueError(
                f"ti.simt.block.shared_array shape must be an integer or a tuple of integers, but got {shape}"
            )
        self.shape = shape

        element_type = dtype.tensor_type if isinstance(dtype, impl.MatrixType) else dtype
        self.dtype = element_type
        self.shared_array_proxy = impl.expr_init_shared_array(self.shape, element_type)

    @gstaichi_scope
    def subscript(self, *indices):
        """Build an ``impl.Expr`` that subscripts the proxy with ``indices``."""
        builder = impl.get_runtime().compiling_callable.ast_builder()
        index_group = make_expr_group(*indices)
        debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        subscripted = builder.expr_subscript(self.shared_array_proxy, index_group, debug_info)
        return impl.Expr(subscripted)
@@ -0,0 +1,7 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang import impl
4
+
5
+
6
def memfence():
    """Emit the ``grid_memfence`` internal call and return its result."""
    fence = impl.call_internal("grid_memfence", with_runtime_context=False)
    return fence
@@ -0,0 +1,191 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang import impl
4
+
5
+
6
def barrier():
    """Emit the ``subgroupBarrier`` internal call and return its result."""
    return impl.call_internal(
        "subgroupBarrier",
        with_runtime_context=False,
    )


def memory_barrier():
    """Emit the ``subgroupMemoryBarrier`` internal call and return its result."""
    return impl.call_internal(
        "subgroupMemoryBarrier",
        with_runtime_context=False,
    )


def elect():
    """Emit the ``subgroupElect`` internal call and return its result."""
    return impl.call_internal(
        "subgroupElect",
        with_runtime_context=False,
    )
16
+
17
+
18
def all_true(cond):
    """Placeholder: not implemented; ignores ``cond`` and returns ``None``."""
    # TODO
    return None


def any_true(cond):
    """Placeholder: not implemented; ignores ``cond`` and returns ``None``."""
    # TODO
    return None


def all_equal(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None


def broadcast_first(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None
36
+
37
+
38
def broadcast(value, index):
    """Emit the ``subgroupBroadcast`` internal call with ``value`` and ``index``."""
    return impl.call_internal(
        "subgroupBroadcast",
        value,
        index,
        with_runtime_context=False,
    )


def group_size():
    """Emit the ``subgroupSize`` internal call and return its result."""
    return impl.call_internal(
        "subgroupSize",
        with_runtime_context=False,
    )


def invocation_id():
    """Emit the ``subgroupInvocationId`` internal call and return its result."""
    return impl.call_internal(
        "subgroupInvocationId",
        with_runtime_context=False,
    )
48
+
49
+
50
def reduce_add(value):
    """Emit the ``subgroupAdd`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupAdd",
        value,
        with_runtime_context=False,
    )


def reduce_mul(value):
    """Emit the ``subgroupMul`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupMul",
        value,
        with_runtime_context=False,
    )


def reduce_min(value):
    """Emit the ``subgroupMin`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupMin",
        value,
        with_runtime_context=False,
    )


def reduce_max(value):
    """Emit the ``subgroupMax`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupMax",
        value,
        with_runtime_context=False,
    )


def reduce_and(value):
    """Emit the ``subgroupAnd`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupAnd",
        value,
        with_runtime_context=False,
    )


def reduce_or(value):
    """Emit the ``subgroupOr`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupOr",
        value,
        with_runtime_context=False,
    )


def reduce_xor(value):
    """Emit the ``subgroupXor`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupXor",
        value,
        with_runtime_context=False,
    )
76
+
77
+
78
def inclusive_add(value):
    """Emit the ``subgroupInclusiveAdd`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupInclusiveAdd",
        value,
        with_runtime_context=False,
    )


def inclusive_mul(value):
    """Emit the ``subgroupInclusiveMul`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupInclusiveMul",
        value,
        with_runtime_context=False,
    )


def inclusive_min(value):
    """Emit the ``subgroupInclusiveMin`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupInclusiveMin",
        value,
        with_runtime_context=False,
    )


def inclusive_max(value):
    """Emit the ``subgroupInclusiveMax`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupInclusiveMax",
        value,
        with_runtime_context=False,
    )


def inclusive_and(value):
    """Emit the ``subgroupInclusiveAnd`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupInclusiveAnd",
        value,
        with_runtime_context=False,
    )


def inclusive_or(value):
    """Emit the ``subgroupInclusiveOr`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupInclusiveOr",
        value,
        with_runtime_context=False,
    )


def inclusive_xor(value):
    """Emit the ``subgroupInclusiveXor`` internal call with ``value``."""
    return impl.call_internal(
        "subgroupInclusiveXor",
        value,
        with_runtime_context=False,
    )
104
+
105
+
106
def exclusive_add(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None


def exclusive_mul(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None


def exclusive_min(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None


def exclusive_max(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None


def exclusive_and(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None


def exclusive_or(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None


def exclusive_xor(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None
139
+
140
+
141
def shuffle(value, index):
    """Emit the ``subgroupShuffle`` internal call with ``value`` and ``index``."""
    return impl.call_internal(
        "subgroupShuffle",
        value,
        index,
        with_runtime_context=False,
    )


def shuffle_xor(value, mask):
    """Placeholder: not implemented; ignores its arguments and returns ``None``."""
    # TODO
    return None


def shuffle_up(value, offset):
    """Emit the ``subgroupShuffleUp`` internal call with ``value`` and ``offset``."""
    return impl.call_internal(
        "subgroupShuffleUp",
        value,
        offset,
        with_runtime_context=False,
    )


def shuffle_down(value, offset):
    """Emit the ``subgroupShuffleDown`` internal call with ``value`` and ``offset``."""
    return impl.call_internal(
        "subgroupShuffleDown",
        value,
        offset,
        with_runtime_context=False,
    )
156
+
157
+
158
# Public API of this module, listed in definition order. `broadcast`,
# `group_size` and `invocation_id` are defined above like every other entry,
# so they are exported too.
__all__ = [
    "barrier",
    "memory_barrier",
    "elect",
    "all_true",
    "any_true",
    "all_equal",
    "broadcast_first",
    "broadcast",
    "group_size",
    "invocation_id",
    "reduce_add",
    "reduce_mul",
    "reduce_min",
    "reduce_max",
    "reduce_and",
    "reduce_or",
    "reduce_xor",
    "inclusive_add",
    "inclusive_mul",
    "inclusive_min",
    "inclusive_max",
    "inclusive_and",
    "inclusive_or",
    "inclusive_xor",
    "exclusive_add",
    "exclusive_mul",
    "exclusive_min",
    "exclusive_max",
    "exclusive_and",
    "exclusive_or",
    "exclusive_xor",
    "shuffle",
    "shuffle_xor",
    "shuffle_up",
    "shuffle_down",
]
@@ -0,0 +1,96 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang import impl
4
+
5
+
6
def all_nonzero(mask, predicate):
    """Emit the ``cuda_all_sync_i32`` internal call with ``mask`` and ``predicate``."""
    return impl.call_internal(
        "cuda_all_sync_i32",
        mask,
        predicate,
        with_runtime_context=False,
    )


def any_nonzero(mask, predicate):
    """Emit the ``cuda_any_sync_i32`` internal call with ``mask`` and ``predicate``."""
    return impl.call_internal(
        "cuda_any_sync_i32",
        mask,
        predicate,
        with_runtime_context=False,
    )


def unique(mask, predicate):
    """Emit the ``cuda_uni_sync_i32`` internal call with ``mask`` and ``predicate``."""
    return impl.call_internal(
        "cuda_uni_sync_i32",
        mask,
        predicate,
        with_runtime_context=False,
    )


def ballot(predicate):
    """Emit the ``cuda_ballot_i32`` internal call with ``predicate``."""
    return impl.call_internal(
        "cuda_ballot_i32",
        predicate,
        with_runtime_context=False,
    )
20
+
21
+
22
def shfl_sync_i32(mask, val, offset):
    """Emit ``cuda_shfl_sync_i32`` with ``mask``, ``val``, ``offset`` and a fixed 31."""
    # lane offset is 31 for warp size 32
    return impl.call_internal(
        "cuda_shfl_sync_i32",
        mask,
        val,
        offset,
        31,
        with_runtime_context=False,
    )


def shfl_sync_f32(mask, val, offset):
    """Emit ``cuda_shfl_sync_f32`` with ``mask``, ``val``, ``offset`` and a fixed 31."""
    # lane offset is 31 for warp size 32
    return impl.call_internal(
        "cuda_shfl_sync_f32",
        mask,
        val,
        offset,
        31,
        with_runtime_context=False,
    )


def shfl_up_i32(mask, val, offset):
    """Emit ``cuda_shfl_up_sync_i32`` with ``mask``, ``val``, ``offset`` and a fixed 0."""
    # lane offset is 0 for warp size 32
    return impl.call_internal(
        "cuda_shfl_up_sync_i32",
        mask,
        val,
        offset,
        0,
        with_runtime_context=False,
    )


def shfl_up_f32(mask, val, offset):
    """Emit ``cuda_shfl_up_sync_f32`` with ``mask``, ``val``, ``offset`` and a fixed 0."""
    # lane offset is 0 for warp size 32
    return impl.call_internal(
        "cuda_shfl_up_sync_f32",
        mask,
        val,
        offset,
        0,
        with_runtime_context=False,
    )


def shfl_down_i32(mask, val, offset):
    """Emit ``cuda_shfl_down_sync_i32`` with ``mask``, ``val``, ``offset`` and a fixed 31."""
    # lane offset is 31 for warp size 32
    return impl.call_internal(
        "cuda_shfl_down_sync_i32",
        mask,
        val,
        offset,
        31,
        with_runtime_context=False,
    )


def shfl_down_f32(mask, val, offset):
    """Emit ``cuda_shfl_down_sync_f32`` with ``mask``, ``val``, ``offset`` and a fixed 31."""
    # lane offset is 31 for warp size 32
    return impl.call_internal(
        "cuda_shfl_down_sync_f32",
        mask,
        val,
        offset,
        31,
        with_runtime_context=False,
    )


def shfl_xor_i32(mask, val, offset):
    """Emit ``cuda_shfl_xor_sync_i32`` with ``mask``, ``val``, ``offset`` and a fixed 31."""
    return impl.call_internal(
        "cuda_shfl_xor_sync_i32",
        mask,
        val,
        offset,
        31,
        with_runtime_context=False,
    )
54
+
55
+
56
def match_any(mask, value):
    """Emit the ``cuda_match_any_sync_i32`` internal call with ``mask`` and ``value``.

    Raises:
        AssertionError: If ``impl.get_cuda_compute_capability()`` is below 70.
    """
    # These intrinsics are only available on compute_70 or higher
    # https://docs.nvidia.com/cuda/pdf/NVVM_IR_Specification.pdf
    capability = impl.get_cuda_compute_capability()
    if capability < 70:
        raise AssertionError("match_any intrinsic only available on compute_70 or higher")
    return impl.call_internal(
        "cuda_match_any_sync_i32",
        mask,
        value,
        with_runtime_context=False,
    )
62
+
63
+
64
def match_all(mask, val):
    """Emit the ``cuda_match_all_sync_i32`` internal call with ``mask`` and ``val``.

    Raises:
        AssertionError: If ``impl.get_cuda_compute_capability()`` is below 70.
    """
    # These intrinsics are only available on compute_70 or higher
    # https://docs.nvidia.com/cuda/pdf/NVVM_IR_Specification.pdf
    capability = impl.get_cuda_compute_capability()
    if capability < 70:
        raise AssertionError("match_all intrinsic only available on compute_70 or higher")
    return impl.call_internal(
        "cuda_match_all_sync_i32",
        mask,
        val,
        with_runtime_context=False,
    )
70
+
71
+
72
def active_mask():
    """Emit the ``cuda_active_mask`` internal call and return its result."""
    return impl.call_internal(
        "cuda_active_mask",
        with_runtime_context=False,
    )


def sync(mask):
    """Emit the ``warp_barrier`` internal call with ``mask``."""
    return impl.call_internal(
        "warp_barrier",
        mask,
        with_runtime_context=False,
    )
78
+
79
+
80
+ __all__ = [
81
+ "all_nonzero",
82
+ "any_nonzero",
83
+ "unique",
84
+ "ballot",
85
+ "shfl_sync_i32",
86
+ "shfl_sync_f32",
87
+ "shfl_up_i32",
88
+ "shfl_up_f32",
89
+ "shfl_down_i32",
90
+ "shfl_down_f32",
91
+ "shfl_xor_i32",
92
+ "match_any",
93
+ "match_all",
94
+ "active_mask",
95
+ "sync",
96
+ ]