gstaichi 2.1.1rc3__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
gstaichi/lang/shell.py
ADDED
@@ -0,0 +1,35 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import functools
|
4
|
+
import os
|
5
|
+
import sys
|
6
|
+
|
7
|
+
from gstaichi._lib import core as _ti_core
|
8
|
+
from gstaichi._logging import info
|
9
|
+
|
10
|
+
# Python-side print buffering is considered unless TI_ENABLE_PYBUF parses to the
# integer 0. An unset variable defaults to "1"; an empty string also counts as
# "requested" because the emptiness check short-circuits before int().
_env_enable_pybuf = os.environ.get("TI_ENABLE_PYBUF", "1")
if not _env_enable_pybuf or int(_env_enable_pybuf):
    # When using in Jupyter / IDLE, sys.stdout is their wrapped stream,
    # while sys.__stdout__ should always be the raw console stdout.
    pybuf_enabled = sys.stdout is not sys.__stdout__
else:
    pybuf_enabled = False

_ti_core.toggle_python_print_buffer(pybuf_enabled)
|
18
|
+
|
19
|
+
|
20
|
+
def _shell_pop_print(old_call):
    """Wrap ``old_call`` so the core print buffer is flushed after each call.

    When ``pybuf_enabled`` is false the callable is returned untouched.
    Otherwise an info message is logged once (at wrap time) and a wrapper is
    returned that calls ``old_call``, prints whatever
    ``_ti_core.pop_python_print_buffer()`` yields (without a trailing newline)
    and then returns the original result.
    """
    if pybuf_enabled:
        info("Graphical python shell detected, using wrapped sys.stdout")

        @functools.wraps(old_call)
        def new_call(*args, **kwargs):
            result = old_call(*args, **kwargs)
            # print's in kernel won't take effect until ti.sync(), discussion:
            # https://github.com/taichi-dev/gstaichi/pull/1303#discussion_r444897102
            buffered_output = _ti_core.pop_python_print_buffer()
            print(buffered_output, end="")
            return result

        return new_call

    # zero-overhead!
    return old_call
|
@@ -0,0 +1,94 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi._lib import core as _ti_core
|
4
|
+
from gstaichi.lang import impl
|
5
|
+
from gstaichi.lang.expr import make_expr_group
|
6
|
+
from gstaichi.lang.util import gstaichi_scope
|
7
|
+
|
8
|
+
|
9
|
+
def arch_uses_spv(arch):
    """Return whether ``arch`` equals ``_ti_core.vulkan`` or ``_ti_core.metal``."""
    if arch == _ti_core.vulkan:
        return True
    return arch == _ti_core.metal
|
11
|
+
|
12
|
+
|
13
|
+
def sync():
    """Insert a block-level barrier for the currently configured arch.

    The arch is read from the active runtime's program config:

    * ``cuda`` / ``amdgpu`` -> internal op ``"block_barrier"``
    * archs accepted by ``arch_uses_spv`` -> internal op ``"workgroupBarrier"``

    Returns:
        Whatever ``impl.call_internal`` returns for the selected op.

    Raises:
        ValueError: if the arch matches none of the cases above.
    """
    arch = impl.get_runtime().prog.config().arch
    if arch == _ti_core.cuda or arch == _ti_core.amdgpu:
        return impl.call_internal("block_barrier", with_runtime_context=False)
    if arch_uses_spv(arch):
        return impl.call_internal("workgroupBarrier", with_runtime_context=False)
    # The message previously named ``shared_array``; report the function that
    # actually failed, matching the sibling helpers in this module.
    raise ValueError(f"ti.block.sync is not supported for arch {arch}")
|
20
|
+
|
21
|
+
|
22
|
+
def sync_all_nonzero(predicate):
    """Call the ``block_barrier_and_i32`` internal op with ``predicate``.

    Only the ``cuda`` arch is accepted; any other arch raises ``ValueError``.
    """
    arch = impl.get_runtime().prog.config().arch
    on_cuda = arch == _ti_core.cuda
    if not on_cuda:
        raise ValueError(f"ti.block.sync_all_nonzero is not supported for arch {arch}")
    return impl.call_internal("block_barrier_and_i32", predicate, with_runtime_context=False)
|
27
|
+
|
28
|
+
|
29
|
+
def sync_any_nonzero(predicate):
    """Call the ``block_barrier_or_i32`` internal op with ``predicate``.

    Only the ``cuda`` arch is accepted; any other arch raises ``ValueError``.
    """
    arch = impl.get_runtime().prog.config().arch
    on_cuda = arch == _ti_core.cuda
    if not on_cuda:
        raise ValueError(f"ti.block.sync_any_nonzero is not supported for arch {arch}")
    return impl.call_internal("block_barrier_or_i32", predicate, with_runtime_context=False)
|
34
|
+
|
35
|
+
|
36
|
+
def sync_count_nonzero(predicate):
    """Call the ``block_barrier_count_i32`` internal op with ``predicate``.

    Only the ``cuda`` arch is accepted; any other arch raises ``ValueError``.
    """
    arch = impl.get_runtime().prog.config().arch
    on_cuda = arch == _ti_core.cuda
    if not on_cuda:
        raise ValueError(f"ti.block.sync_count_nonzero is not supported for arch {arch}")
    return impl.call_internal("block_barrier_count_i32", predicate, with_runtime_context=False)
|
41
|
+
|
42
|
+
|
43
|
+
def mem_sync():
    """Insert a block-level memory barrier for the currently configured arch.

    ``cuda`` maps to the ``"block_barrier"`` internal op, archs accepted by
    ``arch_uses_spv`` map to ``"workgroupMemoryBarrier"``; anything else
    raises ``ValueError``.
    """
    arch = impl.get_runtime().prog.config().arch
    if arch == _ti_core.cuda:
        op_name = "block_barrier"
    elif arch_uses_spv(arch):
        op_name = "workgroupMemoryBarrier"
    else:
        raise ValueError(f"ti.block.mem_sync is not supported for arch {arch}")
    return impl.call_internal(op_name, with_runtime_context=False)
|
50
|
+
|
51
|
+
|
52
|
+
def thread_idx():
    """Call the ``localInvocationId`` internal op.

    Only archs accepted by ``arch_uses_spv`` are supported; any other arch
    raises ``ValueError``.
    """
    arch = impl.get_runtime().prog.config().arch
    if not arch_uses_spv(arch):
        raise ValueError(f"ti.block.thread_idx is not supported for arch {arch}")
    return impl.call_internal("localInvocationId", with_runtime_context=False)
|
57
|
+
|
58
|
+
|
59
|
+
def global_thread_idx():
    """Return an expression for the global thread index on the current arch.

    The arch is read from the active runtime's program config:

    * ``cuda`` / ``amdgpu`` -> ``insert_thread_idx_expr()`` on the AST builder
      of the callable currently being compiled
    * archs accepted by ``arch_uses_spv`` -> internal op ``"globalInvocationId"``

    Returns:
        The value produced by the selected builder/internal call.

    Raises:
        ValueError: if the arch matches none of the cases above.
    """
    arch = impl.get_runtime().prog.config().arch
    # Both operands must be compared against ``arch``. The previous form
    # ``arch == cuda or amdgpu`` evaluated the truthiness of the amdgpu constant
    # itself, so this branch was always taken and the code below was unreachable.
    if arch == _ti_core.cuda or arch == _ti_core.amdgpu:
        return impl.get_runtime().compiling_callable.ast_builder().insert_thread_idx_expr()
    if arch_uses_spv(arch):
        return impl.call_internal("globalInvocationId", with_runtime_context=False)
    raise ValueError(f"ti.block.global_thread_idx is not supported for arch {arch}")
|
66
|
+
|
67
|
+
|
68
|
+
class SharedArray:
    """Handle around the proxy returned by ``impl.expr_init_shared_array``.

    Attributes set by ``__init__``:
        shape: ``(n,)`` when an int was given, otherwise the tuple/list passed in.
        dtype: the given dtype, or its ``tensor_type`` when it is an
            ``impl.MatrixType``.
        shared_array_proxy: result of ``impl.expr_init_shared_array(shape, dtype)``.
    """

    _is_gstaichi_class = True

    def __init__(self, shape, dtype):
        """Validate ``shape``, normalise ``dtype`` and create the proxy.

        Raises:
            ValueError: if ``shape`` is neither an int nor a tuple/list whose
                items are all ints.
        """
        if isinstance(shape, int):
            normalized_shape = (shape,)
        elif isinstance(shape, (tuple, list)) and all(isinstance(dim, int) for dim in shape):
            normalized_shape = shape
        else:
            raise ValueError(
                f"ti.simt.block.shared_array shape must be an integer or a tuple of integers, but got {shape}"
            )
        self.shape = normalized_shape

        # Matrix types are replaced by their underlying tensor type.
        element_type = dtype.tensor_type if isinstance(dtype, impl.MatrixType) else dtype
        self.dtype = element_type
        self.shared_array_proxy = impl.expr_init_shared_array(self.shape, element_type)

    @gstaichi_scope
    def subscript(self, *indices):
        """Build an ``impl.Expr`` that subscripts the proxy with ``indices``.

        Debug info is taken from the runtime's current source info.
        """
        ast_builder = impl.get_runtime().compiling_callable.ast_builder()
        index_group = make_expr_group(*indices)
        debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        subscripted = ast_builder.expr_subscript(self.shared_array_proxy, index_group, debug_info)
        return impl.Expr(subscripted)
|
@@ -0,0 +1,191 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi.lang import impl
|
4
|
+
|
5
|
+
|
6
|
+
def barrier():
    """Invoke the ``subgroupBarrier`` internal op (no runtime context) and return its result."""
    result = impl.call_internal("subgroupBarrier", with_runtime_context=False)
    return result
|
8
|
+
|
9
|
+
|
10
|
+
def memory_barrier():
    """Invoke the ``subgroupMemoryBarrier`` internal op (no runtime context) and return its result."""
    result = impl.call_internal("subgroupMemoryBarrier", with_runtime_context=False)
    return result
|
12
|
+
|
13
|
+
|
14
|
+
def elect():
    """Invoke the ``subgroupElect`` internal op (no runtime context) and return its result."""
    result = impl.call_internal("subgroupElect", with_runtime_context=False)
    return result
|
16
|
+
|
17
|
+
|
18
|
+
def all_true(cond):
    """Placeholder: not implemented; ignores ``cond`` and returns ``None``."""
    # TODO
    return None
|
21
|
+
|
22
|
+
|
23
|
+
def any_true(cond):
    """Placeholder: not implemented; ignores ``cond`` and returns ``None``."""
    # TODO
    return None
|
26
|
+
|
27
|
+
|
28
|
+
def all_equal(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None
|
31
|
+
|
32
|
+
|
33
|
+
def broadcast_first(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None
|
36
|
+
|
37
|
+
|
38
|
+
def broadcast(value, index):
    """Invoke the ``subgroupBroadcast`` internal op with ``value`` and ``index``; return its result."""
    result = impl.call_internal("subgroupBroadcast", value, index, with_runtime_context=False)
    return result
|
40
|
+
|
41
|
+
|
42
|
+
def group_size():
    """Invoke the ``subgroupSize`` internal op (no runtime context) and return its result."""
    result = impl.call_internal("subgroupSize", with_runtime_context=False)
    return result
|
44
|
+
|
45
|
+
|
46
|
+
def invocation_id():
    """Invoke the ``subgroupInvocationId`` internal op (no runtime context) and return its result."""
    result = impl.call_internal("subgroupInvocationId", with_runtime_context=False)
    return result
|
48
|
+
|
49
|
+
|
50
|
+
def reduce_add(value):
    """Invoke the ``subgroupAdd`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupAdd", value, with_runtime_context=False)
    return result
|
52
|
+
|
53
|
+
|
54
|
+
def reduce_mul(value):
    """Invoke the ``subgroupMul`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupMul", value, with_runtime_context=False)
    return result
|
56
|
+
|
57
|
+
|
58
|
+
def reduce_min(value):
    """Invoke the ``subgroupMin`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupMin", value, with_runtime_context=False)
    return result
|
60
|
+
|
61
|
+
|
62
|
+
def reduce_max(value):
    """Invoke the ``subgroupMax`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupMax", value, with_runtime_context=False)
    return result
|
64
|
+
|
65
|
+
|
66
|
+
def reduce_and(value):
    """Invoke the ``subgroupAnd`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupAnd", value, with_runtime_context=False)
    return result
|
68
|
+
|
69
|
+
|
70
|
+
def reduce_or(value):
    """Invoke the ``subgroupOr`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupOr", value, with_runtime_context=False)
    return result
|
72
|
+
|
73
|
+
|
74
|
+
def reduce_xor(value):
    """Invoke the ``subgroupXor`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupXor", value, with_runtime_context=False)
    return result
|
76
|
+
|
77
|
+
|
78
|
+
def inclusive_add(value):
    """Invoke the ``subgroupInclusiveAdd`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupInclusiveAdd", value, with_runtime_context=False)
    return result
|
80
|
+
|
81
|
+
|
82
|
+
def inclusive_mul(value):
    """Invoke the ``subgroupInclusiveMul`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupInclusiveMul", value, with_runtime_context=False)
    return result
|
84
|
+
|
85
|
+
|
86
|
+
def inclusive_min(value):
    """Invoke the ``subgroupInclusiveMin`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupInclusiveMin", value, with_runtime_context=False)
    return result
|
88
|
+
|
89
|
+
|
90
|
+
def inclusive_max(value):
    """Invoke the ``subgroupInclusiveMax`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupInclusiveMax", value, with_runtime_context=False)
    return result
|
92
|
+
|
93
|
+
|
94
|
+
def inclusive_and(value):
    """Invoke the ``subgroupInclusiveAnd`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupInclusiveAnd", value, with_runtime_context=False)
    return result
|
96
|
+
|
97
|
+
|
98
|
+
def inclusive_or(value):
    """Invoke the ``subgroupInclusiveOr`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupInclusiveOr", value, with_runtime_context=False)
    return result
|
100
|
+
|
101
|
+
|
102
|
+
def inclusive_xor(value):
    """Invoke the ``subgroupInclusiveXor`` internal op on ``value`` and return its result."""
    result = impl.call_internal("subgroupInclusiveXor", value, with_runtime_context=False)
    return result
|
104
|
+
|
105
|
+
|
106
|
+
def exclusive_add(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None
|
109
|
+
|
110
|
+
|
111
|
+
def exclusive_mul(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None
|
114
|
+
|
115
|
+
|
116
|
+
def exclusive_min(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None
|
119
|
+
|
120
|
+
|
121
|
+
def exclusive_max(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None
|
124
|
+
|
125
|
+
|
126
|
+
def exclusive_and(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None
|
129
|
+
|
130
|
+
|
131
|
+
def exclusive_or(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None
|
134
|
+
|
135
|
+
|
136
|
+
def exclusive_xor(value):
    """Placeholder: not implemented; ignores ``value`` and returns ``None``."""
    # TODO
    return None
|
139
|
+
|
140
|
+
|
141
|
+
def shuffle(value, index):
    """Invoke the ``subgroupShuffle`` internal op with ``value`` and ``index``; return its result."""
    result = impl.call_internal("subgroupShuffle", value, index, with_runtime_context=False)
    return result
|
143
|
+
|
144
|
+
|
145
|
+
def shuffle_xor(value, mask):
    """Placeholder: not implemented; ignores both arguments and returns ``None``."""
    # TODO
    return None
|
148
|
+
|
149
|
+
|
150
|
+
def shuffle_up(value, offset):
    """Invoke the ``subgroupShuffleUp`` internal op with ``value`` and ``offset``; return its result."""
    result = impl.call_internal("subgroupShuffleUp", value, offset, with_runtime_context=False)
    return result
|
152
|
+
|
153
|
+
|
154
|
+
def shuffle_down(value, offset):
    """Invoke the ``subgroupShuffleDown`` internal op with ``value`` and ``offset``; return its result."""
    result = impl.call_internal("subgroupShuffleDown", value, offset, with_runtime_context=False)
    return result
|
156
|
+
|
157
|
+
|
158
|
+
# Public API of this module. ``broadcast``, ``group_size`` and ``invocation_id``
# are public functions defined above that were previously missing from this
# list, so ``from ... import *`` silently skipped them.
__all__ = [
    "barrier",
    "memory_barrier",
    "elect",
    "all_true",
    "any_true",
    "all_equal",
    "broadcast_first",
    "broadcast",
    "group_size",
    "invocation_id",
    "reduce_add",
    "reduce_mul",
    "reduce_min",
    "reduce_max",
    "reduce_and",
    "reduce_or",
    "reduce_xor",
    "inclusive_add",
    "inclusive_mul",
    "inclusive_min",
    "inclusive_max",
    "inclusive_and",
    "inclusive_or",
    "inclusive_xor",
    "exclusive_add",
    "exclusive_mul",
    "exclusive_min",
    "exclusive_max",
    "exclusive_and",
    "exclusive_or",
    "exclusive_xor",
    "shuffle",
    "shuffle_xor",
    "shuffle_up",
    "shuffle_down",
]
|
@@ -0,0 +1,96 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi.lang import impl
|
4
|
+
|
5
|
+
|
6
|
+
def all_nonzero(mask, predicate):
    """Invoke the ``cuda_all_sync_i32`` internal op with ``mask`` and ``predicate``; return its result."""
    result = impl.call_internal("cuda_all_sync_i32", mask, predicate, with_runtime_context=False)
    return result
|
8
|
+
|
9
|
+
|
10
|
+
def any_nonzero(mask, predicate):
    """Invoke the ``cuda_any_sync_i32`` internal op with ``mask`` and ``predicate``; return its result."""
    result = impl.call_internal("cuda_any_sync_i32", mask, predicate, with_runtime_context=False)
    return result
|
12
|
+
|
13
|
+
|
14
|
+
def unique(mask, predicate):
    """Invoke the ``cuda_uni_sync_i32`` internal op with ``mask`` and ``predicate``; return its result."""
    result = impl.call_internal("cuda_uni_sync_i32", mask, predicate, with_runtime_context=False)
    return result
|
16
|
+
|
17
|
+
|
18
|
+
def ballot(predicate):
    """Invoke the ``cuda_ballot_i32`` internal op with ``predicate`` and return its result."""
    result = impl.call_internal("cuda_ballot_i32", predicate, with_runtime_context=False)
    return result
|
20
|
+
|
21
|
+
|
22
|
+
def shfl_sync_i32(mask, val, offset):
    """Invoke the ``cuda_shfl_sync_i32`` internal op with ``mask``, ``val``, ``offset`` and 31."""
    # lane offset is 31 for warp size 32
    result = impl.call_internal("cuda_shfl_sync_i32", mask, val, offset, 31, with_runtime_context=False)
    return result
|
25
|
+
|
26
|
+
|
27
|
+
def shfl_sync_f32(mask, val, offset):
    """Invoke the ``cuda_shfl_sync_f32`` internal op with ``mask``, ``val``, ``offset`` and 31."""
    # lane offset is 31 for warp size 32
    result = impl.call_internal("cuda_shfl_sync_f32", mask, val, offset, 31, with_runtime_context=False)
    return result
|
30
|
+
|
31
|
+
|
32
|
+
def shfl_up_i32(mask, val, offset):
    """Invoke the ``cuda_shfl_up_sync_i32`` internal op with ``mask``, ``val``, ``offset`` and 0."""
    # lane offset is 0 for warp size 32
    result = impl.call_internal("cuda_shfl_up_sync_i32", mask, val, offset, 0, with_runtime_context=False)
    return result
|
35
|
+
|
36
|
+
|
37
|
+
def shfl_up_f32(mask, val, offset):
    """Invoke the ``cuda_shfl_up_sync_f32`` internal op with ``mask``, ``val``, ``offset`` and 0."""
    # lane offset is 0 for warp size 32
    result = impl.call_internal("cuda_shfl_up_sync_f32", mask, val, offset, 0, with_runtime_context=False)
    return result
|
40
|
+
|
41
|
+
|
42
|
+
def shfl_down_i32(mask, val, offset):
    """Invoke the ``cuda_shfl_down_sync_i32`` internal op with ``mask``, ``val``, ``offset`` and 31."""
    # lane offset is 31 for warp size 32
    result = impl.call_internal("cuda_shfl_down_sync_i32", mask, val, offset, 31, with_runtime_context=False)
    return result
|
45
|
+
|
46
|
+
|
47
|
+
def shfl_down_f32(mask, val, offset):
    """Invoke the ``cuda_shfl_down_sync_f32`` internal op with ``mask``, ``val``, ``offset`` and 31."""
    # lane offset is 31 for warp size 32
    result = impl.call_internal("cuda_shfl_down_sync_f32", mask, val, offset, 31, with_runtime_context=False)
    return result
|
50
|
+
|
51
|
+
|
52
|
+
def shfl_xor_i32(mask, val, offset):
    """Invoke the ``cuda_shfl_xor_sync_i32`` internal op with ``mask``, ``val``, ``offset`` and 31."""
    result = impl.call_internal("cuda_shfl_xor_sync_i32", mask, val, offset, 31, with_runtime_context=False)
    return result
|
54
|
+
|
55
|
+
|
56
|
+
def match_any(mask, value):
    """Invoke the ``cuda_match_any_sync_i32`` internal op with ``mask`` and ``value``.

    Raises:
        AssertionError: if ``impl.get_cuda_compute_capability()`` is below 70.
    """
    # These intrinsics are only available on compute_70 or higher
    # https://docs.nvidia.com/cuda/pdf/NVVM_IR_Specification.pdf
    capability = impl.get_cuda_compute_capability()
    if capability < 70:
        raise AssertionError("match_any intrinsic only available on compute_70 or higher")
    result = impl.call_internal("cuda_match_any_sync_i32", mask, value, with_runtime_context=False)
    return result
|
62
|
+
|
63
|
+
|
64
|
+
def match_all(mask, val):
    """Invoke the ``cuda_match_all_sync_i32`` internal op with ``mask`` and ``val``.

    Raises:
        AssertionError: if ``impl.get_cuda_compute_capability()`` is below 70.
    """
    # These intrinsics are only available on compute_70 or higher
    # https://docs.nvidia.com/cuda/pdf/NVVM_IR_Specification.pdf
    capability = impl.get_cuda_compute_capability()
    if capability < 70:
        raise AssertionError("match_all intrinsic only available on compute_70 or higher")
    result = impl.call_internal("cuda_match_all_sync_i32", mask, val, with_runtime_context=False)
    return result
|
70
|
+
|
71
|
+
|
72
|
+
def active_mask():
    """Invoke the ``cuda_active_mask`` internal op (no runtime context) and return its result."""
    result = impl.call_internal("cuda_active_mask", with_runtime_context=False)
    return result
|
74
|
+
|
75
|
+
|
76
|
+
def sync(mask):
    """Invoke the ``warp_barrier`` internal op with ``mask`` and return its result."""
    result = impl.call_internal("warp_barrier", mask, with_runtime_context=False)
    return result
|
78
|
+
|
79
|
+
|
80
|
+
# Public API of this module, in definition order.
__all__ = [
    # vote / ballot helpers
    "all_nonzero",
    "any_nonzero",
    "unique",
    "ballot",
    # shuffle helpers
    "shfl_sync_i32",
    "shfl_sync_f32",
    "shfl_up_i32",
    "shfl_up_f32",
    "shfl_down_i32",
    "shfl_down_f32",
    "shfl_xor_i32",
    # match helpers (guarded by a compute-capability check)
    "match_any",
    "match_all",
    # mask / barrier helpers
    "active_mask",
    "sync",
]
|