gstaichi 2.1.1rc3__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179)
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,71 @@
1
+ # type: ignore
2
+
3
+ import gstaichi
4
+ from gstaichi._lib.utils import ti_python_core as _ti_python_core
5
+
6
# Handle returned by the native ``get_type_factory_instance()``, fetched once at import time.
# NOTE(review): nothing visible in this module uses it — confirm whether other modules
# import it from here before removing.
_type_factory = _ti_python_core.get_type_factory_instance()
7
+
8
+
9
class CompoundType:
    """Common base for compound (non-primitive) GsTaichi types.

    Code elsewhere uses ``isinstance(x, CompoundType)`` to tell vector/matrix/struct
    dtypes apart from primitive ones.
    """

    def from_kernel_struct_ret(self, launch_ctx, index: tuple):
        """Read a value of this type out of a kernel's struct return.

        The base class has no implementation and always raises
        ``NotImplementedError``; concrete compound types are expected to override it.
        """
        raise NotImplementedError
12
+
13
+
14
# TODO: maybe move MatrixType, StructType here to avoid the circular import?
def matrix(n=None, m=None, dtype=None):
    """Create a matrix type with the given shape and element data type.

    Args:
        n (int): Number of rows of the matrix.
        m (int): Number of columns of the matrix.
        dtype (:mod:`~gstaichi.types.primitive_types`): Data type of the matrix entries.

    Returns:
        A matrix type (``gstaichi.lang.matrix.MatrixType``).

    Example::

        >>> mat2x2 = ti.types.matrix(2, 2, ti.f32)  # 2x2 matrix type
        >>> M = mat2x2([[1., 2.], [3., 4.]])  # an instance of this type
    """
    # Resolved through the package at call time instead of a module-level import.
    # The TODO above refers to a circular import between these modules.
    matrix_cls = gstaichi.lang.matrix.MatrixType
    # NOTE(review): the literal 2 appears to be the tensor rank of a matrix.
    # Confirm against MatrixType's signature.
    return matrix_cls(n, m, 2, dtype)
32
+
33
+
34
def vector(n=None, dtype=None):
    """Create a vector type with the given length and element data type.

    Args:
        n (int): Dimension (number of components) of the vector.
        dtype (:mod:`~gstaichi.types.primitive_types`): Data type of the components.

    Returns:
        A vector type (``gstaichi.lang.matrix.VectorType``).

    Example::

        >>> vec3 = ti.types.vector(3, ti.f32)  # 3d vector type
        >>> v = vec3([1., 2., 3.])  # an instance of this type
    """
    # Looked up lazily through the package, like ``matrix`` above, rather than imported.
    vector_cls = gstaichi.lang.matrix.VectorType
    return vector_cls(n, dtype)
50
+
51
+
52
def struct(**kwargs):
    """Create a struct type from named member types.

    Args:
        kwargs (dict): Keyword arguments mapping each member name to its type.

    Returns:
        A struct type (``gstaichi.lang.struct.StructType``).

    Example::

        >>> vec3 = ti.types.vector(3, ti.f32)
        >>> sphere = ti.types.struct(center=vec3, radius=float)
        >>> s = sphere(center=vec3([0., 0., 0.]), radius=1.0)
    """
    # All member declarations are forwarded unchanged to the StructType constructor.
    struct_cls = gstaichi.lang.struct.StructType
    return struct_cls(**kwargs)
69
+
70
+
71
# Names exported by ``import *``. ``CompoundType`` and ``_type_factory`` are not in
# this list and must be imported explicitly.
__all__ = ["matrix", "vector", "struct"]
@@ -0,0 +1,49 @@
1
+ # type: ignore
2
+
3
+ from gstaichi._lib import core as _ti_core
4
+
5
# Re-export enum types implemented in the native core module under this module's namespace.
Layout = _ti_core.Layout
AutodiffMode = _ti_core.AutodiffMode
SNodeGradType = _ti_core.SNodeGradType
Format = _ti_core.Format
# Consumed by ``to_boundary_enum`` below; not listed in ``__all__``.
BoundaryMode = _ti_core.BoundaryMode
10
+
11
+
12
def to_boundary_enum(boundary):
    """Convert a boundary-mode string into a ``BoundaryMode`` value.

    Args:
        boundary (str): Either ``"clamp"`` or ``"unsafe"``.

    Returns:
        ``BoundaryMode.CLAMP`` or ``BoundaryMode.UNSAFE``, respectively.

    Raises:
        ValueError: For any other value.
    """
    if boundary == "clamp":
        mode = BoundaryMode.CLAMP
    elif boundary == "unsafe":
        mode = BoundaryMode.UNSAFE
    else:
        raise ValueError(f"Invalid boundary argument: {boundary}")
    return mode
18
+
19
+
20
class DeviceCapability:
    """String identifiers for optional SPIR-V device capabilities.

    Every attribute is a plain string.

    The ``spirv_version_*`` entries have the form ``"spirv_version=<n>"``, where ``n``
    is the packed version word ``(major << 16) | (minor << 8)``:

    - 66304 == 0x10300 is 1.3.
    - 66560 == 0x10400 is 1.4.
    - 66816 == 0x10500 is 1.5.
    """

    # SPIR-V version
    spirv_version_1_3 = "spirv_version=66304"
    spirv_version_1_4 = "spirv_version=66560"
    spirv_version_1_5 = "spirv_version=66816"

    # Scalar type widths
    spirv_has_int8 = "spirv_has_int8"
    spirv_has_int16 = "spirv_has_int16"
    spirv_has_int64 = "spirv_has_int64"
    spirv_has_float16 = "spirv_has_float16"
    spirv_has_float64 = "spirv_has_float64"

    # Atomics
    spirv_has_atomic_int64 = "spirv_has_atomic_int64"
    spirv_has_atomic_float16 = "spirv_has_atomic_float16"
    spirv_has_atomic_float16_add = "spirv_has_atomic_float16_add"
    spirv_has_atomic_float16_minmax = "spirv_has_atomic_float16_minmax"
    spirv_has_atomic_float = "spirv_has_atomic_float"
    spirv_has_atomic_float_add = "spirv_has_atomic_float_add"
    spirv_has_atomic_float_minmax = "spirv_has_atomic_float_minmax"
    spirv_has_atomic_float64 = "spirv_has_atomic_float64"
    spirv_has_atomic_float64_add = "spirv_has_atomic_float64_add"
    spirv_has_atomic_float64_minmax = "spirv_has_atomic_float64_minmax"

    # Pointers / addressing
    spirv_has_variable_ptr = "spirv_has_variable_ptr"
    spirv_has_physical_storage_buffer = "spirv_has_physical_storage_buffer"

    # Subgroup operations
    spirv_has_subgroup_basic = "spirv_has_subgroup_basic"
    spirv_has_subgroup_vote = "spirv_has_subgroup_vote"
    spirv_has_subgroup_arithmetic = "spirv_has_subgroup_arithmetic"
    spirv_has_subgroup_ballot = "spirv_has_subgroup_ballot"

    # Miscellaneous
    spirv_has_non_semantic_info = "spirv_has_non_semantic_info"
    spirv_has_no_integer_wrap_decoration = "spirv_has_no_integer_wrap_decoration"
47
+
48
+
49
# Names exported by ``import *``. ``BoundaryMode`` and ``to_boundary_enum`` are not in
# this list and must be imported explicitly.
__all__ = ["Layout", "AutodiffMode", "SNodeGradType", "Format", "DeviceCapability"]
@@ -0,0 +1,169 @@
1
+ from typing import Any
2
+
3
+ from gstaichi.types.compound_types import CompoundType, matrix, vector
4
+ from gstaichi.types.enums import Layout, to_boundary_enum
5
+
6
+
7
class NdarrayTypeMetadata:
    """Element type, shape and gradient flag describing an ndarray argument.

    ``NdarrayType.check_matched`` compares an ``NdarrayType`` annotation against an
    instance of this class.

    Args:
        element_type: Element dtype of the ndarray.
        shape: Shape of the ndarray, or None when unknown. The ndim check in
            ``check_matched`` is skipped when this is None.
        needs_grad (bool): Whether the ndarray has gradients attached.
    """

    def __init__(self, element_type, shape=None, needs_grad=False):
        self.element_type = element_type
        self.shape = shape
        # Not configurable: the layout is always AOS.
        self.layout = Layout.AOS
        self.needs_grad = needs_grad
13
+
14
+
15
# TODO(Haidong): This is a helper function that creates a MatrixType
# with respect to element_dim and element_shape.
# Remove this function when the two args are totally deprecated.
def _make_matrix_dtype_from_element_shape(element_dim, element_shape, primitive_dtype):
    """Build an element dtype from the legacy ``element_dim`` / ``element_shape`` arguments.

    Args:
        element_dim (Union[int, None]): 0 for scalar, 1 for vector or 2 for matrix elements.
        element_shape (Union[Tuple[int], None]): Shape of each element. When both arguments
            are given, its length must equal ``element_dim``.
        primitive_dtype: Dtype of the element entries. Must not itself be a
            :class:`CompoundType`.

    Returns:
        - ``primitive_dtype`` for scalar elements.
        - A ``vector``/``matrix`` type otherwise. When ``element_dim`` is given, sizes are
          left as ``None``. When only ``element_shape`` is given, sizes are taken from it.
        - ``None`` if neither argument is given.

    Raises:
        TypeError: If ``primitive_dtype`` is already a compound type.
        ValueError: If the element rank is outside 0..2, or if ``element_dim`` and
            ``element_shape`` disagree.
    """
    if isinstance(primitive_dtype, CompoundType):
        raise TypeError(f'Cannot specify matrix dtype "{primitive_dtype}" and element shape or dim at the same time.')

    # Validate before the scalar shortcut below. Otherwise inconsistent combinations
    # (e.g. element_dim=1 with element_shape=(), or element_dim=3 with element_shape=())
    # silently produce a scalar dtype.
    if element_dim is not None:
        # TODO: expand use case with arbitrary tensor dims!
        if element_dim < 0 or element_dim > 2:
            raise ValueError("Only scalars, vectors, and matrices are allowed as elements of ti.types.ndarray()")
        if element_shape is not None and len(element_shape) != element_dim:
            raise ValueError(
                f"Both element_shape and element_dim are specified, but shape doesn't match specified dim: "
                f"{len(element_shape)}!={element_dim}"
            )
    elif element_shape is not None and len(element_shape) > 2:
        raise ValueError("Only scalars, vectors, and matrices are allowed as elements of ti.types.ndarray()")

    # Scalars
    if element_dim == 0 or (element_shape is not None and len(element_shape) == 0):
        return primitive_dtype

    # Cook element dim and shape into matrix type.
    if element_dim is not None:
        # Sizes are left unspecified here (the matrix dtype is cooked later).
        return vector(None, primitive_dtype) if element_dim == 1 else matrix(None, None, primitive_dtype)
    if element_shape is not None:
        if len(element_shape) == 1:
            return vector(element_shape[0], primitive_dtype)
        return matrix(element_shape[0], element_shape[1], primitive_dtype)
    return None
48
+
49
+
50
class NdarrayType:
    """Type annotation for arbitrary arrays, including external arrays (numpy ndarrays and torch tensors) and GsTaichi ndarrays.

    For external arrays, we treat it as a GsTaichi data container with Scalar, Vector or Matrix elements.
    For GsTaichi vector/matrix ndarrays, we will automatically identify element dimension and their corresponding axis by the
    dimension of datatype, say scalars, matrices or vectors.
    For example, given type annotation `ti.types.ndarray(dtype=ti.math.vec3)`, a numpy array `np.zeros((10, 10, 3))` will be
    recognized as a 10x10 array composed of vec3 elements.

    Subscript syntax forwards the subscript items positionally to the constructor.
    ``ti.types.ndarray[ti.f32, 2]`` is the same as ``ti.types.ndarray(ti.f32, 2)``, and
    ``ti.types.ndarray[ti.f32]`` is the same as ``ti.types.ndarray(ti.f32)``.

    Args:
        dtype (Union[PrimitiveType, VectorType, MatrixType, NoneType], optional): None if not specified.
        ndim (Union[Int, NoneType]): None if not specified, number of field dimensions. This argument is ignored for external
            arrays for now.
        element_dim (Union[Int, NoneType], optional):
            None if not specified (will be treated as 0 for external arrays),
            0 if scalar elements,
            1 if vector elements, and
            2 if matrix elements.
        element_shape (Union[Tuple[Int], NoneType]):
            None if not specified, shapes of each element.
            For example, element_shape must be 1d for vector and 2d tuple for matrix.
            This argument is ignored for external arrays for now.
        field_dim: Deprecated. Any value other than None raises ValueError; use ``ndim``.
        needs_grad (Union[bool, NoneType]): None to skip the check. When True, the ndarray passed at
            runtime must also have needs_grad=True (see :meth:`check_matched`).
        boundary (str): ``"unsafe"`` (default) or ``"clamp"``. Any other value raises ValueError.
    """

    def __init__(
        self,
        dtype=None,
        ndim=None,
        element_dim=None,
        element_shape=None,
        field_dim=None,
        needs_grad=None,
        boundary="unsafe",
    ):
        if field_dim is not None:
            raise ValueError("The field_dim argument for ndarray type is already deprecated. Please use ndim instead.")
        if element_dim is not None or element_shape is not None:
            self.dtype = _make_matrix_dtype_from_element_shape(element_dim, element_shape, dtype)
        else:
            self.dtype = dtype

        self.ndim = ndim
        self.layout = Layout.AOS
        self.needs_grad = needs_grad
        self.boundary = to_boundary_enum(boundary)

    @classmethod
    def __class_getitem__(cls, args, **kwargs):
        # ``A[x]`` passes ``x`` bare, while ``A[x, y]`` passes the tuple ``(x, y)``.
        # Wrap the single-item form so it is not itself unpacked by ``*args``.
        # Otherwise e.g. ``ndarray[ti.f32]`` would fail.
        if not isinstance(args, tuple):
            args = (args,)
        return cls(*args, **kwargs)

    def check_matched(self, ndarray_type: "NdarrayTypeMetadata", arg_name: str):
        """Check that the ndarray passed for ``arg_name`` satisfies this annotation.

        Args:
            ndarray_type (NdarrayTypeMetadata): Element type, shape and grad flag of the passed ndarray.
            arg_name (str): Argument name, used in error messages.

        Raises:
            ValueError: On a compound-dtype, ndim or needs_grad mismatch.
            TypeError: On a scalar dtype mismatch.
        """
        # FIXME(Haidong) Cannot use Vector/MatrixType due to circular import.
        # Use CompoundType instead to determine the specific types.
        # TODO Replace CompoundType with MatrixType and VectorType

        # Check dtype match
        if isinstance(self.dtype, CompoundType):
            if not self.dtype.check_matched(ndarray_type.element_type):  # type: ignore
                raise ValueError(
                    f"Invalid value for argument {arg_name} - required element type: {self.dtype.to_string()}, "  # type: ignore
                    f"but {ndarray_type.element_type.to_string()} is provided"
                )
        else:
            if self.dtype is not None:
                # Check dtype match for scalar.
                from gstaichi.lang import util  # pylint: disable=C0415

                if not util.cook_dtype(self.dtype) == ndarray_type.element_type:
                    raise TypeError(
                        f"Expect element type {self.dtype} for argument {arg_name}, but get {ndarray_type.element_type}"
                    )

        # Check ndim match (skipped when either side leaves it unspecified).
        if self.ndim is not None and ndarray_type.shape is not None and self.ndim != len(ndarray_type.shape):
            raise ValueError(
                f"Invalid value for argument {arg_name} - required ndim={self.ndim}, but {len(ndarray_type.shape)}d "
                f"ndarray with shape {ndarray_type.shape} is provided"
            )

        # Check needs_grad
        if self.needs_grad is not None and self.needs_grad > ndarray_type.needs_grad:
            # A needs_grad=True ndarray may be passed at runtime to a needs_grad=False arg,
            # but not vice versa.
            raise ValueError(
                f"Invalid value for argument {arg_name} - required needs_grad={self.needs_grad}, but "
                f"{ndarray_type.needs_grad} is provided"
            )

    def __repr__(self):
        return f"NdarrayType(dtype={self.dtype}, ndim={self.ndim}, layout={self.layout}, needs_grad={self.needs_grad})"

    def __str__(self):
        return self.__repr__()

    def __getitem__(self, i: Any) -> Any:
        # Declared only so static checkers (pyright) accept indexing parameters annotated
        # with this type. The annotation object itself cannot be indexed.
        # ``NotImplemented`` is not an exception, so raise NotImplementedError.
        raise NotImplementedError

    def __setitem__(self, i: Any, v: Any) -> None:
        # Declared only for static checkers (pyright); see __getitem__.
        raise NotImplementedError
150
+
151
+
152
# Lowercase and CamelCase spellings both refer to the same annotation class.
ndarray = NdarrayType
NDArray = NdarrayType
"""Alias for :class:`~gstaichi.types.ndarray_type.NdarrayType`.

Example::

    >>> @ti.kernel
    >>> def to_numpy(x: ti.types.ndarray(), y: ti.types.ndarray()):
    >>>     for i in range(n):
    >>>         x[i] = y[i]
    >>>
    >>> y = ti.ndarray(ti.f64, shape=n)
    >>> ... # calculate y
    >>> x = numpy.zeros(n)
    >>> to_numpy(x, y)  # `x` will be filled with `y`'s data.
"""

# Only the aliases are exported by ``import *``; NdarrayType itself must be imported by name.
__all__ = ["ndarray", "NDArray"]
@@ -0,0 +1,206 @@
1
+ from typing import Union
2
+
3
+ from gstaichi._lib import core as ti_python_core
4
+
5
+ # ========================================
6
+ # real types
7
+
8
+ # ----------------------------------------
9
+
10
float16 = ti_python_core.DataType_f16
"""16-bit (half precision) floating point data type."""

f16 = float16
"""Alias for :const:`~gstaichi.types.primitive_types.float16`."""

float32 = ti_python_core.DataType_f32
"""32-bit (single precision) floating point data type."""

f32 = float32
"""Alias for :const:`~gstaichi.types.primitive_types.float32`."""

float64 = ti_python_core.DataType_f64
"""64-bit (double precision) floating point data type."""

f64 = float64
"""Alias for :const:`~gstaichi.types.primitive_types.float64`."""

# ========================================
# Signed integer types
# ========================================

int8 = ti_python_core.DataType_i8
"""8-bit signed integer data type."""

i8 = int8
"""Alias for :const:`~gstaichi.types.primitive_types.int8`."""

int16 = ti_python_core.DataType_i16
"""16-bit signed integer data type."""

i16 = int16
"""Alias for :const:`~gstaichi.types.primitive_types.int16`."""

int32 = ti_python_core.DataType_i32
"""32-bit signed integer data type."""

i32 = int32
"""Alias for :const:`~gstaichi.types.primitive_types.int32`."""

int64 = ti_python_core.DataType_i64
"""64-bit signed integer data type."""

i64 = int64
"""Alias for :const:`~gstaichi.types.primitive_types.int64`."""

# ========================================
# Unsigned integer types
# ========================================

uint1 = ti_python_core.DataType_u1
"""1-bit unsigned integer data type. Same as booleans."""

u1 = uint1
"""Alias for :const:`~gstaichi.types.primitive_types.uint1`."""

uint8 = ti_python_core.DataType_u8
"""8-bit unsigned integer data type."""

u8 = uint8
"""Alias for :const:`~gstaichi.types.primitive_types.uint8`."""

uint16 = ti_python_core.DataType_u16
"""16-bit unsigned integer data type."""

u16 = uint16
"""Alias for :const:`~gstaichi.types.primitive_types.uint16`."""

uint32 = ti_python_core.DataType_u32
"""32-bit unsigned integer data type."""

u32 = uint32
"""Alias for :const:`~gstaichi.types.primitive_types.uint32`."""

uint64 = ti_python_core.DataType_u64
"""64-bit unsigned integer data type."""

u64 = uint64
"""Alias for :const:`~gstaichi.types.primitive_types.uint64`."""
155
+
156
+ # ----------------------------------------
157
+
158
+
159
class RefType:
    """Wrapper recording a type under the attribute ``tp``.

    Instances are produced by :func:`ref`. NOTE(review): presumably used to mark an
    annotated parameter as passed by reference — confirm with the kernel-argument handling.
    """

    def __init__(self, tp):
        self.tp = tp


def ref(tp):
    """Wrap ``tp`` in a :class:`RefType` and return the wrapper."""
    wrapped = RefType(tp)
    return wrapped
166
+
167
+
168
# Groupings of the scalar types defined above. The Python builtins ``float``, ``int`` and
# ``bool`` are included alongside the GsTaichi dtypes.
real_types = [f16, f32, f64, float]
# ``id()`` of each entry, allowing membership to be tested by object identity.
real_type_ids = [id(t) for t in real_types]

integer_types = [i8, i16, i32, i64, u1, u8, u16, u32, u64, int, bool]
integer_type_ids = [id(t) for t in integer_types]

all_types = real_types + integer_types
type_ids = [id(t) for t in all_types]

# Typing alias covering plain Python scalar values (and None).
_python_primitive_types = Union[int, float, bool, str, None]

# Names exported by ``import *``. The ``*_types`` / ``*_ids`` lists and ``RefType`` are not
# exported.
__all__ = [
    "float32",
    "f32",
    "float64",
    "f64",
    "float16",
    "f16",
    "int8",
    "i8",
    "int16",
    "i16",
    "int32",
    "i32",
    "int64",
    "i64",
    "uint1",
    "u1",
    "uint8",
    "u8",
    "uint16",
    "u16",
    "uint32",
    "u32",
    "uint64",
    "u64",
    "ref",
    "_python_primitive_types",
]
@@ -0,0 +1,88 @@
1
+ # type: ignore
2
+
3
+ """
4
+ This module defines generators of quantized types.
5
+ For more details, read https://yuanming.taichi.graphics/publication/2021-quantaichi/quantaichi.pdf.
6
+ """
7
+
8
+ from gstaichi._lib.utils import ti_python_core as _ti_python_core
9
+ from gstaichi.types.primitive_types import i32
10
+
11
# Handle returned by the native ``get_type_factory_instance()``, fetched once at import time.
# The quantized int/fixed/float constructors below all go through it.
_type_factory = _ti_python_core.get_type_factory_instance()
12
+
13
+
14
def int(bits, signed=True, compute=None):  # pylint: disable=W0622
    """Generate a quantized integer type.

    Args:
        bits (int): Number of storage bits.
        signed (bool): Whether the type is signed.
        compute (DataType): Type used for computation. When None, defaults to the
            runtime's ``default_ip`` if ``signed`` is true, otherwise ``default_up``.

    Returns:
        DataType: The quantized integer type created by the native type factory.
    """
    if compute is None:
        # Imported inside the function rather than at module level.
        from gstaichi.lang import impl  # pylint: disable=C0415

        runtime = impl.get_runtime()
        compute = runtime.default_ip if signed else runtime.default_up
    # ``DataTypeCxx`` values are converted with ``get_ptr()`` before reaching the factory.
    if isinstance(compute, _ti_python_core.DataTypeCxx):
        compute = compute.get_ptr()
    return _type_factory.get_quant_int_type(bits, signed, compute)
32
+
33
+
34
def fixed(bits, signed=True, max_value=1.0, compute=None, scale=None):
    """Generate a quantized fixed-point type for real numbers.

    Args:
        bits (int): Number of storage bits.
        signed (bool): Whether the type is signed.
        max_value (float): Maximum value of the number. Only used to derive ``scale``
            when ``scale`` is None.
        compute (DataType): Type used for computation. Defaults to the runtime's ``default_fp``.
        scale (float): Scaling factor. Takes precedence over ``max_value``. When None it is
            ``max_value / 2**(bits - 1)`` for signed types, or ``max_value / 2**bits`` for
            unsigned types.

    Returns:
        DataType: The quantized fixed-point type created by the native type factory.
    """
    if compute is None:
        from gstaichi.lang import impl  # pylint: disable=C0415

        compute = impl.get_runtime().default_fp
    if isinstance(compute, _ti_python_core.DataTypeCxx):
        compute = compute.get_ptr()
    # Storage is a quantized integer of the same width, computed in i32.
    # TODO: handle cases with bits > 32
    underlying_type = int(bits=bits, signed=signed, compute=i32)
    if scale is None:
        # One bit of a signed type holds the sign, leaving bits - 1 for the magnitude.
        magnitude_bits = bits - 1 if signed else bits
        scale = max_value / 2**magnitude_bits
    return _type_factory.get_quant_fixed_type(underlying_type, compute, scale)
61
+
62
+
63
def float(exp, frac, signed=True, compute=None):  # pylint: disable=W0622
    """Generate a quantized floating-point type for real numbers.

    Args:
        exp (int): Number of exponent bits.
        frac (int): Number of fraction bits.
        signed (bool): Whether the type is signed. Applied to the fraction part only.
        compute (DataType): Type used for computation. Defaults to the runtime's ``default_fp``.

    Returns:
        DataType: The quantized floating-point type created by the native type factory.
    """
    if compute is None:
        from gstaichi.lang import impl  # pylint: disable=C0415

        compute = impl.get_runtime().default_fp
    if isinstance(compute, _ti_python_core.DataTypeCxx):
        compute = compute.get_ptr()
    # The exponent field is always an unsigned quantized integer.
    exponent_type = int(bits=exp, signed=False, compute=i32)
    # TODO: handle cases with frac > 32
    fraction_type = int(bits=frac, signed=signed, compute=i32)
    return _type_factory.get_quant_float_type(fraction_type, exponent_type, compute)
86
+
87
+
88
# NOTE: ``int`` and ``float`` here shadow the Python builtins for any module that does
# ``from gstaichi.types.quant import *``.
__all__ = ["int", "fixed", "float"]
@@ -0,0 +1,85 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang.exception import GsTaichiCompilationError
4
+ from gstaichi.types.enums import Format
5
+ from gstaichi.types.primitive_types import f16, f32, i8, i16, i32, u8, u16, u32
6
+
7
# Maps each texture ``Format`` to ``(channel dtype, number of channels)``.
# NOTE: entry order matters, because TY_CH2FORMAT below keeps the last format listed for
# each key.
FORMAT2TY_CH = {
    Format.r8: (u8, 1),
    Format.r8u: (u8, 1),
    Format.r8i: (i8, 1),
    Format.rg8: (u8, 2),
    Format.rg8u: (u8, 2),
    Format.rg8i: (i8, 2),
    Format.rgba8: (u8, 4),
    Format.rgba8u: (u8, 4),
    Format.rgba8i: (i8, 4),
    Format.r16: (u16, 1),
    Format.r16u: (u16, 1),
    Format.r16i: (i16, 1),
    Format.r16f: (f16, 1),
    Format.rg16: (u16, 2),
    Format.rg16u: (u16, 2),
    Format.rg16i: (i16, 2),
    Format.rg16f: (f16, 2),
    Format.rgb16: (u16, 3),
    Format.rgb16u: (u16, 3),
    Format.rgb16i: (i16, 3),
    Format.rgb16f: (f16, 3),
    Format.rgba16: (u16, 4),
    Format.rgba16u: (u16, 4),
    Format.rgba16i: (i16, 4),
    Format.rgba16f: (f16, 4),
    Format.r32u: (u32, 1),
    Format.r32i: (i32, 1),
    Format.r32f: (f32, 1),
    Format.rg32u: (u32, 2),
    Format.rg32i: (i32, 2),
    Format.rg32f: (f32, 2),
    Format.rgb32u: (u32, 3),
    Format.rgb32i: (i32, 3),
    Format.rgb32f: (f32, 3),
    Format.rgba32u: (u32, 4),
    Format.rgba32i: (i32, 4),
    Format.rgba32f: (f32, 4),
}

# Reverse lookup by (channel_format, num_channels).
# Some keys are shared, e.g. r8 and r8u both map to (u8, 1). For those keys the later
# entry wins, so they resolve to the ``u``-suffixed format.
TY_CH2FORMAT = dict(zip(FORMAT2TY_CH.values(), FORMAT2TY_CH.keys()))
49
+
50
+
51
class TextureType:
    """Type annotation for textures.

    Args:
        num_dimensions (int): Number of texture dimensions, e.g. ``2`` for a 2D texture.
    """

    def __init__(self, num_dimensions):
        # Stored as given; no validation is performed here.
        self.num_dimensions = num_dimensions
60
+
61
+
62
class RWTextureType:
    """Type annotation for read/write textures (image load/store).

    Args:
        num_dimensions (int): Number of texture dimensions, e.g. ``2`` for a 2D texture.
        lod: Explicit level of detail. Defaults to 0.
        fmt (ti.Format): Color format of the texture. Required.

    Raises:
        GsTaichiCompilationError: If ``fmt`` is None.
    """

    def __init__(self, num_dimensions, lod=0, fmt=None):
        self.num_dimensions = num_dimensions
        # Unlike sampled textures, a read/write texture needs an explicit format.
        if fmt is None:
            raise GsTaichiCompilationError("fmt is required for rw_texture type")
        self.fmt = fmt
        self.lod = lod
78
+
79
+
80
# Lowercase aliases for the annotation classes above.
texture = TextureType
rw_texture = RWTextureType
"""Alias for :class:`~gstaichi.types.texture_type.RWTextureType`.
"""

# Only the aliases are exported by ``import *``.
__all__ = ["texture", "rw_texture"]
@@ -0,0 +1,11 @@
1
+ from gstaichi._lib import core as ti_python_core
2
+
3
# Type-predicate helpers re-exported directly from the native core module.
is_signed = ti_python_core.is_signed

is_integral = ti_python_core.is_integral

is_real = ti_python_core.is_real

is_tensor = ti_python_core.is_tensor

__all__ = ["is_signed", "is_integral", "is_real", "is_tensor"]