gstaichi 2.1.1rc3__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1243 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +782 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
- gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
- gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
- gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,71 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import gstaichi
|
4
|
+
from gstaichi._lib.utils import ti_python_core as _ti_python_core
|
5
|
+
|
6
|
+
_type_factory = _ti_python_core.get_type_factory_instance()
|
7
|
+
|
8
|
+
|
9
|
+
class CompoundType:
    """Common base used to recognise compound (non-primitive) types via ``isinstance``.

    The base class carries no state; its only method is a placeholder that
    always raises (presumably overridden by concrete subclasses -- verify
    against the matrix/struct type implementations).
    """

    def from_kernel_struct_ret(self, launch_ctx, index: tuple):
        """Placeholder hook that unconditionally raises ``NotImplementedError``.

        Args:
            launch_ctx: not used by this base implementation.
            index (tuple): not used by this base implementation.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError
|
12
|
+
|
13
|
+
|
14
|
+
# TODO: maybe move MatrixType, StructType here to avoid the circular import?
|
15
|
+
def matrix(n=None, m=None, dtype=None):
    """Build a matrix type from a row count, a column count and an element type.

    Args:
        n (int): number of rows of the matrix.
        m (int): number of columns of the matrix.
        dtype (:mod:`~gstaichi.types.primitive_types`): matrix data type.

    Returns:
        A matrix type.

    Example::

        >>> mat2x2 = ti.types.matrix(2, 2, ti.f32)  # 2x2 matrix type
        >>> M = mat2x2([[1., 2.], [3., 4.]])  # an instance of this type
    """
    # Resolved at call time (not import time) through the package attribute chain.
    matrix_type_cls = gstaichi.lang.matrix.MatrixType
    # The literal 2 is MatrixType's third positional argument
    # (presumably the number of dimensions -- TODO confirm against MatrixType).
    return matrix_type_cls(n, m, 2, dtype)
|
32
|
+
|
33
|
+
|
34
|
+
def vector(n=None, dtype=None):
    """Build a vector type from a length and an element type.

    Args:
        n (int): dimension of the vector.
        dtype (:mod:`~gstaichi.types.primitive_types`): vector data type.

    Returns:
        A vector type.

    Example::

        >>> vec3 = ti.types.vector(3, ti.f32)  # 3d vector type
        >>> v = vec3([1., 2., 3.])  # an instance of this type
    """
    # Resolved at call time (not import time) through the package attribute chain.
    vector_type_cls = gstaichi.lang.matrix.VectorType
    return vector_type_cls(n, dtype)
|
50
|
+
|
51
|
+
|
52
|
+
def struct(**kwargs):
    """Build a struct type from named members.

    Args:
        kwargs (dict): mapping from member names to member types; forwarded
            unchanged to ``StructType``.

    Returns:
        A struct type.

    Example::

        >>> vec3 = ti.types.vector(3, ti.f32)
        >>> sphere = ti.types.struct(center=vec3, radius=float)
        >>> s = sphere(center=vec3([0., 0., 0.]), radius=1.0)
    """
    # Resolved at call time (not import time) through the package attribute chain.
    struct_type_cls = gstaichi.lang.struct.StructType
    return struct_type_cls(**kwargs)
|
69
|
+
|
70
|
+
|
71
|
+
__all__ = ["matrix", "vector", "struct"]
|
gstaichi/types/enums.py
ADDED
@@ -0,0 +1,49 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi._lib import core as _ti_core
|
4
|
+
|
5
|
+
# Re-export the enum types provided by the native core module so callers can
# import them from this package instead of reaching into ``_ti_core``.
Layout = _ti_core.Layout
AutodiffMode = _ti_core.AutodiffMode
SNodeGradType = _ti_core.SNodeGradType
Format = _ti_core.Format
# Used by ``to_boundary_enum`` in this module.
BoundaryMode = _ti_core.BoundaryMode
|
10
|
+
|
11
|
+
|
12
|
+
def to_boundary_enum(boundary):
    """Translate a boundary-mode name into the matching ``BoundaryMode`` member.

    Args:
        boundary: ``"clamp"`` or ``"unsafe"``.

    Returns:
        ``BoundaryMode.CLAMP`` or ``BoundaryMode.UNSAFE``.

    Raises:
        ValueError: if ``boundary`` is neither of the two accepted names.
    """
    if boundary == "clamp":
        mode = BoundaryMode.CLAMP
    elif boundary == "unsafe":
        mode = BoundaryMode.UNSAFE
    else:
        raise ValueError(f"Invalid boundary argument: {boundary}")
    return mode
|
18
|
+
|
19
|
+
|
20
|
+
class DeviceCapability:
    """String constants naming device capabilities.

    Every attribute is a plain ``str``. The ``spirv_version_*`` entries embed a
    decimal number after ``=``; the remaining entries equal their own
    attribute name.
    """

    # Version markers: 66304 == 0x10300, 66560 == 0x10400, 66816 == 0x10500.
    spirv_version_1_3 = "spirv_version=66304"
    spirv_version_1_4 = "spirv_version=66560"
    spirv_version_1_5 = "spirv_version=66816"

    # Scalar type support.
    spirv_has_int8 = "spirv_has_int8"
    spirv_has_int16 = "spirv_has_int16"
    spirv_has_int64 = "spirv_has_int64"
    spirv_has_float16 = "spirv_has_float16"
    spirv_has_float64 = "spirv_has_float64"

    # Atomic operation support.
    spirv_has_atomic_int64 = "spirv_has_atomic_int64"
    spirv_has_atomic_float16 = "spirv_has_atomic_float16"
    spirv_has_atomic_float16_add = "spirv_has_atomic_float16_add"
    spirv_has_atomic_float16_minmax = "spirv_has_atomic_float16_minmax"
    spirv_has_atomic_float = "spirv_has_atomic_float"
    spirv_has_atomic_float_add = "spirv_has_atomic_float_add"
    spirv_has_atomic_float_minmax = "spirv_has_atomic_float_minmax"
    spirv_has_atomic_float64 = "spirv_has_atomic_float64"
    spirv_has_atomic_float64_add = "spirv_has_atomic_float64_add"
    spirv_has_atomic_float64_minmax = "spirv_has_atomic_float64_minmax"

    # Pointer / storage features.
    spirv_has_variable_ptr = "spirv_has_variable_ptr"
    spirv_has_physical_storage_buffer = "spirv_has_physical_storage_buffer"

    # Subgroup features.
    spirv_has_subgroup_basic = "spirv_has_subgroup_basic"
    spirv_has_subgroup_vote = "spirv_has_subgroup_vote"
    spirv_has_subgroup_arithmetic = "spirv_has_subgroup_arithmetic"
    spirv_has_subgroup_ballot = "spirv_has_subgroup_ballot"

    # Miscellaneous.
    spirv_has_non_semantic_info = "spirv_has_non_semantic_info"
    spirv_has_no_integer_wrap_decoration = "spirv_has_no_integer_wrap_decoration"
|
47
|
+
|
48
|
+
|
49
|
+
__all__ = ["Layout", "AutodiffMode", "SNodeGradType", "Format", "DeviceCapability"]
|
@@ -0,0 +1,169 @@
|
|
1
|
+
from typing import Any
|
2
|
+
|
3
|
+
from gstaichi.types.compound_types import CompoundType, matrix, vector
|
4
|
+
from gstaichi.types.enums import Layout, to_boundary_enum
|
5
|
+
|
6
|
+
|
7
|
+
class NdarrayTypeMetadata:
    """Plain record describing an ndarray argument.

    Args:
        element_type: type of a single element of the array.
        shape: shape of the array, or None (the default).
        needs_grad (bool): gradient flag stored on the record. Defaults to False.
    """

    def __init__(self, element_type, shape=None, needs_grad=False):
        self.element_type, self.shape = element_type, shape
        # The layout is not configurable through the constructor; it is always AOS.
        self.layout = Layout.AOS
        self.needs_grad = needs_grad
|
13
|
+
|
14
|
+
|
15
|
+
# TODO(Haidong): This helper builds a MatrixType from the legacy
#                element_dim / element_shape arguments.
#                Remove it once those two arguments are fully deprecated.
def _make_matrix_dtype_from_element_shape(element_dim, element_shape, primitive_dtype):
    """Combine a primitive dtype with ``element_dim`` / ``element_shape`` into an element dtype.

    Args:
        element_dim: None, or 0 (scalar), 1 (vector) or 2 (matrix).
        element_shape: None, or a sequence with at most two entries.
        primitive_dtype: element dtype; must not be a ``CompoundType``.

    Returns:
        ``primitive_dtype`` itself for scalars, a vector/matrix type otherwise,
        or None when both ``element_dim`` and ``element_shape`` are None.

    Raises:
        TypeError: if ``primitive_dtype`` is already a ``CompoundType``.
        ValueError: if the requested rank is outside 0..2, or if
            ``element_shape`` disagrees with ``element_dim``.
    """
    if isinstance(primitive_dtype, CompoundType):
        raise TypeError(f'Cannot specifiy matrix dtype "{primitive_dtype}" and element shape or dim at the same time.')

    has_shape = element_shape is not None

    # Scalars
    if element_dim == 0 or (has_shape and len(element_shape) == 0):
        return primitive_dtype

    if element_dim is not None:
        # TODO: expand use case with arbitrary tensor dims!
        if element_dim < 0 or element_dim > 2:
            raise ValueError("Only scalars, vectors, and matrices are allowed as elements of ti.types.ndarray()")
        # Only the rank is validated here; the concrete extents in
        # element_shape are not baked into the returned type on this path.
        if has_shape and len(element_shape) != element_dim:
            raise ValueError(
                f"Both element_shape and element_dim are specified, but shape doesn't match specified dim: "
                f"{len(element_shape)}!={element_dim}"
            )
        if element_dim == 1:
            return vector(None, primitive_dtype)
        return matrix(None, None, primitive_dtype)

    if not has_shape:
        # Neither argument was supplied.
        return None

    if len(element_shape) > 2:
        raise ValueError("Only scalars, vectors, and matrices are allowed as elements of ti.types.ndarray()")
    if len(element_shape) == 1:
        return vector(element_shape[0], primitive_dtype)
    return matrix(element_shape[0], element_shape[1], primitive_dtype)
|
48
|
+
|
49
|
+
|
50
|
+
class NdarrayType:
    """Type annotation for arbitrary arrays, including external arrays (numpy ndarrays and torch tensors) and GsTaichi ndarrays.

    For external arrays, we treat it as a GsTaichi data container with Scalar, Vector or Matrix elements.
    For GsTaichi vector/matrix ndarrays, we will automatically identify element dimension and their corresponding axis by the
    dimension of datatype, say scalars, matrices or vectors.
    For example, given type annotation `ti.types.ndarray(dtype=ti.math.vec3)`, a numpy array `np.zeros((10, 10, 3))` will be
    recognized as a 10x10 array composed of vec3 elements.

    Args:
        dtype (Union[PrimitiveType, VectorType, MatrixType, NoneType], optional): None if not specified.
        ndim (Union[Int, NoneType]): None if not specified, number of field dimensions. This argument is ignored for external
            arrays for now.
        element_dim (Union[Int, NoneType], optional):
            None if not specified (will be treated as 0 for external arrays),
            0 if scalar elements,
            1 if vector elements, and
            2 if matrix elements.
        element_shape (Union[Tuple[Int], NoneType]):
            None if not specified, shapes of each element.
            For example, element_shape must be 1d for vector and 2d tuple for matrix.
            This argument is ignored for external arrays for now.
        field_dim: Deprecated. Any value other than None raises ``ValueError``; use ``ndim`` instead.
        needs_grad (Union[bool, NoneType]): None (the default) disables the needs_grad check in
            :meth:`check_matched`.
        boundary (str): ``"unsafe"`` (the default) or ``"clamp"``; converted with ``to_boundary_enum``.
    """

    def __init__(
        self,
        dtype=None,
        ndim=None,
        element_dim=None,
        element_shape=None,
        field_dim=None,
        needs_grad=None,
        boundary="unsafe",
    ):
        if field_dim is not None:
            raise ValueError("The field_dim argument for ndarray type is already deprecated. Please use ndim instead.")
        if element_dim is not None or element_shape is not None:
            # Legacy path: fold element_dim / element_shape into a vector or matrix dtype.
            self.dtype = _make_matrix_dtype_from_element_shape(element_dim, element_shape, dtype)
        else:
            self.dtype = dtype

        self.ndim = ndim
        self.layout = Layout.AOS
        self.needs_grad = needs_grad
        self.boundary = to_boundary_enum(boundary)

    @classmethod
    def __class_getitem__(cls, args, **kwargs):
        """Support ``NdarrayType[a, b, ...]`` by unpacking the subscript into constructor arguments.

        The subscript is star-unpacked, so it must be iterable (e.g. a tuple).
        """
        return cls(*args, **kwargs)

    def check_matched(self, ndarray_type: NdarrayTypeMetadata, arg_name: str):
        """Validate the metadata of a passed array against this annotation.

        Args:
            ndarray_type (NdarrayTypeMetadata): metadata of the array that was passed.
            arg_name (str): argument name, used only in error messages.

        Raises:
            ValueError: on a compound element type mismatch, an ndim mismatch, or when
                this annotation requires gradients and the array does not provide them.
            TypeError: on a scalar element type mismatch.
        """
        # FIXME(Haidong) Cannot use Vector/MatrixType due to circular import
        # Use the CompoundType instead to determine the specific types.
        # TODO Replace CompoundType with MatrixType and VectorType

        # Check dtype match
        if isinstance(self.dtype, CompoundType):
            if not self.dtype.check_matched(ndarray_type.element_type):  # type: ignore
                raise ValueError(
                    f"Invalid value for argument {arg_name} - required element type: {self.dtype.to_string()}, "  # type: ignore
                    f"but {ndarray_type.element_type.to_string()} is provided"
                )
        else:
            if self.dtype is not None:
                # Check dtype match for scalar.
                from gstaichi.lang import util  # pylint: disable=C0415

                if not util.cook_dtype(self.dtype) == ndarray_type.element_type:
                    raise TypeError(
                        f"Expect element type {self.dtype} for argument {arg_name}, but get {ndarray_type.element_type}"
                    )

        # Check ndim match (skipped when either side is unspecified)
        if self.ndim is not None and ndarray_type.shape is not None and self.ndim != len(ndarray_type.shape):
            raise ValueError(
                f"Invalid value for argument {arg_name} - required ndim={self.ndim}, but {len(ndarray_type.shape)}d "
                f"ndarray with shape {ndarray_type.shape} is provided"
            )

        # Check needs_grad
        if self.needs_grad is not None and self.needs_grad > ndarray_type.needs_grad:
            # It's okay to pass a needs_grad=True ndarray at runtime to a need_grad=False arg but not vice versa.
            raise ValueError(
                f"Invalid value for argument {arg_name} - required needs_grad={self.needs_grad}, but "
                f"{ndarray_type.needs_grad} is provided"
            )

    def __repr__(self):
        return f"NdarrayType(dtype={self.dtype}, ndim={self.ndim}, layout={self.layout}, needs_grad={self.needs_grad})"

    def __str__(self):
        return self.__repr__()

    def __getitem__(self, i: Any) -> Any:
        # needed for pyright
        # NotImplemented is a constant, not an exception; raising it would
        # surface as an unrelated TypeError, so raise NotImplementedError.
        raise NotImplementedError

    def __setitem__(self, i: Any, v: Any) -> None:
        # needed for pyright
        # See __getitem__: NotImplementedError is the raisable exception type.
        raise NotImplementedError
|
150
|
+
|
151
|
+
|
152
|
+
# Both public names are bound to the same NdarrayType class.
ndarray = NdarrayType
NDArray = NdarrayType
"""Alias for :class:`~gstaichi.types.ndarray_type.NdarrayType`.

Example::

    >>> @ti.kernel
    ... def to_numpy(x: ti.types.ndarray(), y: ti.types.ndarray()):
    ...     for i in range(n):
    ...         x[i] = y[i]
    ...
    >>> y = ti.ndarray(ti.f64, shape=n)
    >>> ...  # calculate y
    >>> x = numpy.zeros(n)
    >>> to_numpy(x, y)  # `x` will be filled with `y`'s data.
"""
|
168
|
+
|
169
|
+
__all__ = ["ndarray", "NDArray"]
|
@@ -0,0 +1,206 @@
|
|
1
|
+
from typing import Union
|
2
|
+
|
3
|
+
from gstaichi._lib import core as ti_python_core
|
4
|
+
|
5
|
+
# ========================================
# Real types
# ========================================

float16 = ti_python_core.DataType_f16
"""16-bit precision floating point data type."""

f16 = float16
"""Alias for :const:`~gstaichi.types.primitive_types.float16`"""

float32 = ti_python_core.DataType_f32
"""32-bit single precision floating point data type."""

f32 = float32
"""Alias for :const:`~gstaichi.types.primitive_types.float32`"""

float64 = ti_python_core.DataType_f64
"""64-bit double precision floating point data type."""

f64 = float64
"""Alias for :const:`~gstaichi.types.primitive_types.float64`"""

# ========================================
# Integer types
# ========================================

int8 = ti_python_core.DataType_i8
"""8-bit signed integer data type."""

i8 = int8
"""Alias for :const:`~gstaichi.types.primitive_types.int8`"""

int16 = ti_python_core.DataType_i16
"""16-bit signed integer data type."""

i16 = int16
"""Alias for :const:`~gstaichi.types.primitive_types.int16`"""

int32 = ti_python_core.DataType_i32
"""32-bit signed integer data type."""

i32 = int32
"""Alias for :const:`~gstaichi.types.primitive_types.int32`"""

int64 = ti_python_core.DataType_i64
"""64-bit signed integer data type."""

i64 = int64
"""Alias for :const:`~gstaichi.types.primitive_types.int64`"""

uint8 = ti_python_core.DataType_u8
"""8-bit unsigned integer data type."""

uint1 = ti_python_core.DataType_u1
"""1-bit unsigned integer data type. Same as booleans."""

u1 = uint1
"""Alias for :const:`~gstaichi.types.primitive_types.uint1`"""

u8 = uint8
"""Alias for :const:`~gstaichi.types.primitive_types.uint8`"""

uint16 = ti_python_core.DataType_u16
"""16-bit unsigned integer data type."""

u16 = uint16
"""Alias for :const:`~gstaichi.types.primitive_types.uint16`"""

uint32 = ti_python_core.DataType_u32
"""32-bit unsigned integer data type."""

u32 = uint32
"""Alias for :const:`~gstaichi.types.primitive_types.uint32`"""

uint64 = ti_python_core.DataType_u64
"""64-bit unsigned integer data type."""

u64 = uint64
"""Alias for :const:`~gstaichi.types.primitive_types.uint64`"""

# ----------------------------------------
|
157
|
+
|
158
|
+
|
159
|
+
class RefType:
    """Thin wrapper that records the type it was constructed with.

    Args:
        tp: the wrapped type, stored unchanged on the ``tp`` attribute.
    """

    def __init__(self, tp):
        self.tp = tp
|
162
|
+
|
163
|
+
|
164
|
+
def ref(tp):
    """Wrap ``tp`` in a :class:`RefType` and return the wrapper."""
    wrapped = RefType(tp)
    return wrapped
|
166
|
+
|
167
|
+
|
168
|
+
# Lookup tables of accepted dtypes. Each list mixes the core dtype objects
# defined in this module with the matching Python builtins, and each *_ids
# list holds id() of the entries of its companion list, in the same order.
real_types = [f16, f32, f64, float]
real_type_ids = list(map(id, real_types))

integer_types = [i8, i16, i32, i64, u1, u8, u16, u32, u64, int, bool]
integer_type_ids = list(map(id, integer_types))

all_types = [*real_types, *integer_types]
type_ids = list(map(id, all_types))

_python_primitive_types = Union[int, float, bool, str, None]

__all__ = [
    # real types
    "float32",
    "f32",
    "float64",
    "f64",
    "float16",
    "f16",
    # signed integer types
    "int8",
    "i8",
    "int16",
    "i16",
    "int32",
    "i32",
    "int64",
    "i64",
    # unsigned integer types
    "uint1",
    "u1",
    "uint8",
    "u8",
    "uint16",
    "u16",
    "uint32",
    "u32",
    "uint64",
    "u64",
    # helpers
    "ref",
    "_python_primitive_types",
]
|
gstaichi/types/quant.py
ADDED
@@ -0,0 +1,88 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
"""
|
4
|
+
This module defines generators of quantized types.
|
5
|
+
For more details, read https://yuanming.taichi.graphics/publication/2021-quantaichi/quantaichi.pdf.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from gstaichi._lib.utils import ti_python_core as _ti_python_core
|
9
|
+
from gstaichi.types.primitive_types import i32
|
10
|
+
|
11
|
+
_type_factory = _ti_python_core.get_type_factory_instance()
|
12
|
+
|
13
|
+
|
14
|
+
def int(bits, signed=True, compute=None):  # pylint: disable=W0622
    """Generates a quantized type for integers.

    Note that this function deliberately shadows the builtin ``int`` inside
    this module.

    Args:
        bits (int): Number of bits.
        signed (bool): Signed or unsigned.
        compute (DataType): Type for computation. When None, the runtime's
            ``default_ip`` (signed) or ``default_up`` (unsigned) is used.

    Returns:
        DataType: The specified type.
    """
    if compute is None:
        # Imported lazily, inside the function, as in the rest of this module.
        from gstaichi.lang import impl  # pylint: disable=C0415

        runtime = impl.get_runtime()
        compute = runtime.default_ip if signed else runtime.default_up
    if isinstance(compute, _ti_python_core.DataTypeCxx):
        # The type factory is handed the result of get_ptr(), not the wrapper.
        compute = compute.get_ptr()
    return _type_factory.get_quant_int_type(bits, signed, compute)
|
32
|
+
|
33
|
+
|
34
|
+
def fixed(bits, signed=True, max_value=1.0, compute=None, scale=None):
    """Generates a quantized type for fixed-point real numbers.

    Args:
        bits (int): Number of bits.
        signed (bool): Signed or unsigned.
        max_value (float): Maximum value of the number. Only used to derive
            ``scale`` when ``scale`` is None.
        compute (DataType): Type for computation. When None, the runtime's
            ``default_fp`` is used.
        scale (float): Scaling factor. When given, it takes priority over
            ``max_value``.

    Returns:
        DataType: The specified type.
    """
    if compute is None:
        from gstaichi.lang import impl  # pylint: disable=C0415

        compute = impl.get_runtime().default_fp
    if isinstance(compute, _ti_python_core.DataTypeCxx):
        compute = compute.get_ptr()
    # TODO: handle cases with bits > 32
    # ``int`` here is this module's quantized-int generator, not the builtin.
    underlying_type = int(bits=bits, signed=signed, compute=i32)
    if scale is None:
        # Signed types divide by 2**(bits - 1), unsigned types by 2**bits.
        exponent = bits - 1 if signed else bits
        scale = max_value / 2**exponent
    return _type_factory.get_quant_fixed_type(underlying_type, compute, scale)
|
61
|
+
|
62
|
+
|
63
|
+
def float(exp, frac, signed=True, compute=None):  # pylint: disable=W0622
    """Generates a quantized type for floating-point real numbers.

    Note that this function deliberately shadows the builtin ``float`` inside
    this module.

    Args:
        exp (int): Number of exponent bits.
        frac (int): Number of fraction bits.
        signed (bool): Signed or unsigned; applies to the fraction only.
        compute (DataType): Type for computation. When None, the runtime's
            ``default_fp`` is used.

    Returns:
        DataType: The specified type.
    """
    if compute is None:
        from gstaichi.lang import impl  # pylint: disable=C0415

        compute = impl.get_runtime().default_fp
    if isinstance(compute, _ti_python_core.DataTypeCxx):
        compute = compute.get_ptr()
    # ``int`` below is this module's quantized-int generator, not the builtin.
    # Exponent is always unsigned
    exponent_type = int(bits=exp, signed=False, compute=i32)
    # TODO: handle cases with frac > 32
    fraction_type = int(bits=frac, signed=signed, compute=i32)
    return _type_factory.get_quant_float_type(fraction_type, exponent_type, compute)
|
86
|
+
|
87
|
+
|
88
|
+
__all__ = ["int", "fixed", "float"]
|
@@ -0,0 +1,85 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi.lang.exception import GsTaichiCompilationError
|
4
|
+
from gstaichi.types.enums import Format
|
5
|
+
from gstaichi.types.primitive_types import f16, f32, i8, i16, i32, u8, u16, u32
|
6
|
+
|
7
|
+
# Maps each texture Format to a (channel dtype, number of channels) pair.
FORMAT2TY_CH = {
    # 8-bit channels
    Format.r8: (u8, 1),
    Format.r8u: (u8, 1),
    Format.r8i: (i8, 1),
    Format.rg8: (u8, 2),
    Format.rg8u: (u8, 2),
    Format.rg8i: (i8, 2),
    Format.rgba8: (u8, 4),
    Format.rgba8u: (u8, 4),
    Format.rgba8i: (i8, 4),
    # 16-bit channels
    Format.r16: (u16, 1),
    Format.r16u: (u16, 1),
    Format.r16i: (i16, 1),
    Format.r16f: (f16, 1),
    Format.rg16: (u16, 2),
    Format.rg16u: (u16, 2),
    Format.rg16i: (i16, 2),
    Format.rg16f: (f16, 2),
    Format.rgb16: (u16, 3),
    Format.rgb16u: (u16, 3),
    Format.rgb16i: (i16, 3),
    Format.rgb16f: (f16, 3),
    Format.rgba16: (u16, 4),
    Format.rgba16u: (u16, 4),
    Format.rgba16i: (i16, 4),
    Format.rgba16f: (f16, 4),
    # 32-bit channels
    Format.r32u: (u32, 1),
    Format.r32i: (i32, 1),
    Format.r32f: (f32, 1),
    Format.rg32u: (u32, 2),
    Format.rg32i: (i32, 2),
    Format.rg32f: (f32, 2),
    Format.rgb32u: (u32, 3),
    Format.rgb32i: (i32, 3),
    Format.rgb32f: (f32, 3),
    Format.rgba32u: (u32, 4),
    Format.rgba32i: (i32, 4),
    Format.rgba32f: (f32, 4),
}

# Reverse lookup by (channel_format, num_channels).
# Several formats share the same pair (e.g. r8 and r8u are both (u8, 1));
# for those, the entry listed later in FORMAT2TY_CH wins.
TY_CH2FORMAT = {ty_ch: fmt for fmt, ty_ch in FORMAT2TY_CH.items()}
|
49
|
+
|
50
|
+
|
51
|
+
class TextureType:
    """Type annotation for Textures.

    Args:
        num_dimensions (int): Number of dimensions. For example, for a 2D texture this should be `2`.
    """

    def __init__(self, num_dimensions):
        # Stored as given; no validation is performed here.
        self.num_dimensions = num_dimensions
|
60
|
+
|
61
|
+
|
62
|
+
class RWTextureType:
    """Type annotation for RW Textures (image load store).

    Args:
        num_dimensions (int): Number of dimensions. For example, for a 2D texture this should be `2`.
        lod (float): Specifies the explicit level-of-detail. Defaults to 0.
        fmt (ti.Format): Color format of texture. Required: leaving it as None raises
            ``GsTaichiCompilationError``.
    """

    def __init__(self, num_dimensions, lod=0, fmt=None):
        self.num_dimensions = num_dimensions
        # Guard clause: fmt has a None default only so that lod can precede it.
        if fmt is None:
            raise GsTaichiCompilationError("fmt is required for rw_texture type")
        self.fmt = fmt
        self.lod = lod
|
78
|
+
|
79
|
+
|
80
|
+
texture = TextureType
"""Alias for :class:`~gstaichi.types.texture_type.TextureType`.
"""

rw_texture = RWTextureType
"""Alias for :class:`~gstaichi.types.texture_type.RWTextureType`.
"""
|
84
|
+
|
85
|
+
__all__ = ["texture", "rw_texture"]
|
gstaichi/types/utils.py
ADDED
@@ -0,0 +1,11 @@
|
|
1
|
+
from gstaichi._lib import core as ti_python_core
|
2
|
+
|
3
|
+
# Re-export the type predicates implemented by the native core module.
# NOTE(review): their exact semantics live in the core bindings; the names
# suggest dtype classification checks -- confirm against gstaichi_python.pyi.
is_signed = ti_python_core.is_signed
is_integral = ti_python_core.is_integral
is_real = ti_python_core.is_real
is_tensor = ti_python_core.is_tensor
|
10
|
+
|
11
|
+
__all__ = ["is_signed", "is_integral", "is_real", "is_tensor"]
|