gstaichi 0.1.25.dev0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138) hide show
  1. gstaichi/CHANGELOG.md +9 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/__main__.py +5 -0
  4. gstaichi/_funcs.py +706 -0
  5. gstaichi/_kernels.py +420 -0
  6. gstaichi/_lib/__init__.py +3 -0
  7. gstaichi/_lib/core/__init__.py +0 -0
  8. gstaichi/_lib/core/gstaichi_python.cp311-win_amd64.pyd +0 -0
  9. gstaichi/_lib/core/gstaichi_python.pyi +2937 -0
  10. gstaichi/_lib/core/py.typed +0 -0
  11. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  12. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  13. gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  14. gstaichi/_lib/utils.py +249 -0
  15. gstaichi/_logging.py +131 -0
  16. gstaichi/_main.py +545 -0
  17. gstaichi/_snode/__init__.py +5 -0
  18. gstaichi/_snode/fields_builder.py +187 -0
  19. gstaichi/_snode/snode_tree.py +34 -0
  20. gstaichi/_test_tools/__init__.py +0 -0
  21. gstaichi/_test_tools/load_kernel_string.py +30 -0
  22. gstaichi/_version.py +1 -0
  23. gstaichi/_version_check.py +103 -0
  24. gstaichi/ad/__init__.py +3 -0
  25. gstaichi/ad/_ad.py +530 -0
  26. gstaichi/algorithms/__init__.py +3 -0
  27. gstaichi/algorithms/_algorithms.py +117 -0
  28. gstaichi/assets/.git +1 -0
  29. gstaichi/assets/Go-Regular.ttf +0 -0
  30. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_ndarray.py +352 -0
  35. gstaichi/lang/_ndrange.py +152 -0
  36. gstaichi/lang/_template_mapper.py +199 -0
  37. gstaichi/lang/_texture.py +172 -0
  38. gstaichi/lang/_wrap_inspect.py +189 -0
  39. gstaichi/lang/any_array.py +99 -0
  40. gstaichi/lang/argpack.py +411 -0
  41. gstaichi/lang/ast/__init__.py +5 -0
  42. gstaichi/lang/ast/ast_transformer.py +1318 -0
  43. gstaichi/lang/ast/ast_transformer_utils.py +341 -0
  44. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  45. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  46. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  47. gstaichi/lang/ast/checkers.py +106 -0
  48. gstaichi/lang/ast/symbol_resolver.py +57 -0
  49. gstaichi/lang/ast/transform.py +9 -0
  50. gstaichi/lang/common_ops.py +310 -0
  51. gstaichi/lang/exception.py +80 -0
  52. gstaichi/lang/expr.py +180 -0
  53. gstaichi/lang/field.py +466 -0
  54. gstaichi/lang/impl.py +1241 -0
  55. gstaichi/lang/kernel_arguments.py +157 -0
  56. gstaichi/lang/kernel_impl.py +1382 -0
  57. gstaichi/lang/matrix.py +1881 -0
  58. gstaichi/lang/matrix_ops.py +341 -0
  59. gstaichi/lang/matrix_ops_utils.py +190 -0
  60. gstaichi/lang/mesh.py +687 -0
  61. gstaichi/lang/misc.py +778 -0
  62. gstaichi/lang/ops.py +1494 -0
  63. gstaichi/lang/runtime_ops.py +13 -0
  64. gstaichi/lang/shell.py +35 -0
  65. gstaichi/lang/simt/__init__.py +5 -0
  66. gstaichi/lang/simt/block.py +94 -0
  67. gstaichi/lang/simt/grid.py +7 -0
  68. gstaichi/lang/simt/subgroup.py +191 -0
  69. gstaichi/lang/simt/warp.py +96 -0
  70. gstaichi/lang/snode.py +489 -0
  71. gstaichi/lang/source_builder.py +150 -0
  72. gstaichi/lang/struct.py +855 -0
  73. gstaichi/lang/util.py +381 -0
  74. gstaichi/linalg/__init__.py +8 -0
  75. gstaichi/linalg/matrixfree_cg.py +310 -0
  76. gstaichi/linalg/sparse_cg.py +59 -0
  77. gstaichi/linalg/sparse_matrix.py +303 -0
  78. gstaichi/linalg/sparse_solver.py +123 -0
  79. gstaichi/math/__init__.py +11 -0
  80. gstaichi/math/_complex.py +205 -0
  81. gstaichi/math/mathimpl.py +886 -0
  82. gstaichi/profiler/__init__.py +6 -0
  83. gstaichi/profiler/kernel_metrics.py +260 -0
  84. gstaichi/profiler/kernel_profiler.py +586 -0
  85. gstaichi/profiler/memory_profiler.py +15 -0
  86. gstaichi/profiler/scoped_profiler.py +36 -0
  87. gstaichi/sparse/__init__.py +3 -0
  88. gstaichi/sparse/_sparse_grid.py +77 -0
  89. gstaichi/tools/__init__.py +12 -0
  90. gstaichi/tools/diagnose.py +117 -0
  91. gstaichi/tools/np2ply.py +364 -0
  92. gstaichi/tools/vtk.py +38 -0
  93. gstaichi/types/__init__.py +19 -0
  94. gstaichi/types/annotations.py +47 -0
  95. gstaichi/types/compound_types.py +90 -0
  96. gstaichi/types/enums.py +49 -0
  97. gstaichi/types/ndarray_type.py +147 -0
  98. gstaichi/types/primitive_types.py +206 -0
  99. gstaichi/types/quant.py +88 -0
  100. gstaichi/types/texture_type.py +85 -0
  101. gstaichi/types/utils.py +13 -0
  102. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  103. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  104. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  105. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  106. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  107. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  108. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  109. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  110. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  111. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  112. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  113. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  114. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  115. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  116. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  117. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  118. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  119. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  120. gstaichi-0.1.25.dev0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  121. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
  122. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
  123. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  124. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
  125. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  126. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  127. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-link.lib +0 -0
  128. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  129. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  130. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  131. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  132. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools.lib +0 -0
  133. gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
  134. gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
  135. gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
  136. gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
  137. gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
  138. gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,199 @@
1
+ import dataclasses
2
+ import weakref
3
+ from typing import Any, Union
4
+
5
+ import gstaichi.lang
6
+ import gstaichi.lang._ndarray
7
+ import gstaichi.lang._texture
8
+ import gstaichi.lang.expr
9
+ import gstaichi.lang.snode
10
+ from gstaichi._lib import core as _ti_core
11
+ from gstaichi.lang.any_array import AnyArray
12
+ from gstaichi.lang.argpack import ArgPack, ArgPackType
13
+ from gstaichi.lang.exception import (
14
+ GsTaichiRuntimeTypeError,
15
+ )
16
+ from gstaichi.lang.kernel_arguments import KernelArgument
17
+ from gstaichi.lang.matrix import MatrixType
18
+ from gstaichi.lang.util import to_gstaichi_type
19
+ from gstaichi.types import (
20
+ ndarray_type,
21
+ sparse_matrix_builder,
22
+ template,
23
+ texture_type,
24
+ )
25
+
26
# Annotation kinds accepted for the `annotation` parameter of
# `GsTaichiCallableTemplateMapper.extract_arg`. The two texture types are written
# as string forward references.
AnnotationType = Union[
    template,
    ArgPackType,
    "texture_type.TextureType",
    "texture_type.RWTextureType",
    ndarray_type.NdarrayType,
    sparse_matrix_builder,
    Any,
]
35
+
36
+
37
class GsTaichiCallableTemplateMapper:
    """
    Maps the runtime arguments of a kernel call to a hashable feature key and
    assigns each distinct key a small integer id (see `lookup`).

    This should probably be renamed to something like FeatureMapper, or
    FeatureExtractor, since:
    - it's not specific to templates
    - it extracts what are later called 'features', for example for ndarray this includes:
        - element type
        - number dimensions
        - needs grad (or not)
    - these are returned as a heterogeneous tuple, whose contents depends on the type
    """

    def __init__(self, arguments: list[KernelArgument], template_slot_locations: list[int]) -> None:
        # Declared kernel arguments; `extract` reads `.annotation` and `.name` from each.
        self.arguments: list[KernelArgument] = arguments
        self.num_args: int = len(arguments)
        # Stored as given; not read by any method of this class.
        self.template_slot_locations: list[int] = template_slot_locations
        # Feature key -> id, ids assigned in first-seen order by `lookup`.
        self.mapping: dict[tuple[Any, ...], int] = {}

    @staticmethod
    def extract_arg(arg, annotation: AnnotationType, arg_name: str) -> Any:
        """Return the feature key describing a single argument value.

        Args:
            arg: Value passed for the argument.
            annotation: Declared annotation of the argument; selects the extraction rule.
            arg_name: Name of the argument, used in error messages.

        Returns:
            Depending on `annotation`: the value itself, an underlying pointer
            address, a weak reference, or a tuple of features. Annotations that do
            not take part in template instantiation yield the placeholder "#".

        Raises:
            GsTaichiRuntimeTypeError: if `arg` does not fit `annotation` (an Ndarray
                passed as a template, a non-ArgPack for an ArgPackType, a texture of
                the wrong kind/dimension/format, or an object without `.shape` for an
                ndarray annotation).
            ValueError: if an external array's dimensions, element shape or dtype do
                not match an ndarray annotation.
        """
        if annotation == template or isinstance(annotation, template):
            if isinstance(arg, gstaichi.lang.snode.SNode):
                return arg.ptr
            if isinstance(arg, gstaichi.lang.expr.Expr):
                return arg.ptr.get_underlying_ptr_address()
            if isinstance(arg, _ti_core.ExprCxx):
                return arg.get_underlying_ptr_address()
            if isinstance(arg, tuple):
                return tuple(GsTaichiCallableTemplateMapper.extract_arg(item, annotation, arg_name) for item in arg)
            if isinstance(arg, gstaichi.lang._ndarray.Ndarray):
                raise GsTaichiRuntimeTypeError(
                    "Ndarray shouldn't be passed in via `ti.template()`, please annotate your kernel using `ti.types.ndarray(...)` instead"
                )

            if isinstance(arg, (list, tuple, dict, set)) or hasattr(arg, "_data_oriented"):
                # [Composite arguments] Return weak reference to the object
                # GsTaichi kernel will cache the extracted arguments, thus we can't simply return the original argument.
                # Instead, a weak reference to the original value is returned to avoid memory leak.

                # TODO(zhanlue): replacing "tuple(args)" with "hash of argument values"
                # This can resolve the following issues:
                # 1. Invalid weak-ref will leave a dead(dangling) entry in both caches: "self.mapping" and "self.compiled_functions"
                # 2. Different argument instances with same type and same value, will get templatized into separate kernels.

                # NOTE(review): plain `list` and `dict` instances do not support weak
                # references, so `weakref.ref` raises TypeError for them. Tuples never
                # reach this line because they are handled above. Confirm whether
                # lists/dicts are expected here.
                return weakref.ref(arg)

            # [Primitive arguments] Return the value
            return arg
        if isinstance(annotation, ArgPackType):
            if not isinstance(arg, ArgPack):
                raise GsTaichiRuntimeTypeError(f"Argument {arg_name} must be a argument pack, got {type(arg)}")
            return tuple(
                GsTaichiCallableTemplateMapper.extract_arg(arg[name], dtype, arg_name)
                for name, dtype in annotation.members.items()
            )
        if dataclasses.is_dataclass(annotation):
            extracted_fields = []
            for field in dataclasses.fields(annotation):
                field_value = getattr(arg, field.name)
                # Derive each field's name from the parent argument's name. Rebinding
                # `arg_name` here would make every later field's name nest on the
                # previous one (e.g. "__ti___ti_x_a_b").
                field_arg_name = f"__ti_{arg_name}_{field.name}"
                extracted_fields.append(
                    GsTaichiCallableTemplateMapper.extract_arg(field_value, field.type, field_arg_name)
                )
            return tuple(extracted_fields)
        if isinstance(annotation, texture_type.TextureType):
            if not isinstance(arg, gstaichi.lang._texture.Texture):
                raise GsTaichiRuntimeTypeError(f"Argument {arg_name} must be a texture, got {type(arg)}")
            if arg.num_dims != annotation.num_dimensions:
                raise GsTaichiRuntimeTypeError(
                    f"TextureType dimension mismatch for argument {arg_name}: expected {annotation.num_dimensions}, got {arg.num_dims}"
                )
            return (arg.num_dims,)
        if isinstance(annotation, texture_type.RWTextureType):
            if not isinstance(arg, gstaichi.lang._texture.Texture):
                raise GsTaichiRuntimeTypeError(f"Argument {arg_name} must be a texture, got {type(arg)}")
            if arg.num_dims != annotation.num_dimensions:
                raise GsTaichiRuntimeTypeError(
                    f"RWTextureType dimension mismatch for argument {arg_name}: expected {annotation.num_dimensions}, got {arg.num_dims}"
                )
            if arg.fmt != annotation.fmt:
                raise GsTaichiRuntimeTypeError(
                    f"RWTextureType format mismatch for argument {arg_name}: expected {annotation.fmt}, got {arg.fmt}"
                )
            # (penguinliong) '0' is the assumed LOD level. We currently don't
            # support mip-mapping.
            return arg.num_dims, arg.fmt, 0
        if isinstance(annotation, ndarray_type.NdarrayType):
            if isinstance(arg, gstaichi.lang._ndarray.Ndarray):
                annotation.check_matched(arg.get_type(), arg_name)
                needs_grad = (arg.grad is not None) if annotation.needs_grad is None else annotation.needs_grad
                assert arg.shape is not None
                return arg.element_type, len(arg.shape), needs_grad, annotation.boundary
            if isinstance(arg, AnyArray):
                ty = arg.get_type()
                annotation.check_matched(ty, arg_name)
                return ty.element_type, len(arg.shape), ty.needs_grad, annotation.boundary
            # external arrays: any object exposing `.shape` and `.dtype`
            shape = getattr(arg, "shape", None)
            if shape is None:
                raise GsTaichiRuntimeTypeError(f"Invalid type for argument {arg_name}, got {arg}")
            shape = tuple(shape)
            element_shape: tuple[int, ...] = ()
            dtype = to_gstaichi_type(arg.dtype)
            if isinstance(annotation.dtype, MatrixType):
                # Trailing `annotation.dtype.ndim` axes form the matrix/vector element.
                if annotation.ndim is not None:
                    if len(shape) != annotation.dtype.ndim + annotation.ndim:
                        raise ValueError(
                            f"Invalid value for argument {arg_name} - required array has ndim={annotation.ndim} element_dim={annotation.dtype.ndim}, "
                            f"array with {len(shape)} dimensions is provided"
                        )
                else:
                    if len(shape) < annotation.dtype.ndim:
                        raise ValueError(
                            f"Invalid value for argument {arg_name} - required element_dim={annotation.dtype.ndim}, "
                            f"array with {len(shape)} dimensions is provided"
                        )
                element_shape = shape[-annotation.dtype.ndim :]
                anno_element_shape = annotation.dtype.get_shape()
                if None not in anno_element_shape and element_shape != anno_element_shape:
                    raise ValueError(
                        f"Invalid value for argument {arg_name} - required element_shape={anno_element_shape}, "
                        f"array with element shape of {element_shape} is provided"
                    )
            elif annotation.dtype is not None:
                # User specified scalar dtype
                if annotation.dtype != dtype:
                    raise ValueError(
                        f"Invalid value for argument {arg_name} - required array has dtype={annotation.dtype.to_string()}, "
                        f"array with dtype={dtype.to_string()} is provided"
                    )

                if annotation.ndim is not None and len(shape) != annotation.ndim:
                    raise ValueError(
                        f"Invalid value for argument {arg_name} - required array has ndim={annotation.ndim}, "
                        f"array with {len(shape)} dimensions is provided"
                    )
            needs_grad = (
                getattr(arg, "requires_grad", False) if annotation.needs_grad is None else annotation.needs_grad
            )
            # Scalar elements key on the array's own `dtype` attribute; matrix/vector
            # elements key on a tensor type built from the converted dtype.
            element_type = (
                _ti_core.get_type_factory_instance().get_tensor_type(element_shape, dtype)
                if len(element_shape) != 0
                else arg.dtype
            )
            return element_type, len(shape) - len(element_shape), needs_grad, annotation.boundary
        if isinstance(annotation, sparse_matrix_builder):
            return arg.dtype
        # Use '#' as a placeholder because other kinds of arguments are not involved in template instantiation
        return "#"

    def extract(self, args: tuple[Any, ...]) -> tuple[Any, ...]:
        """Return the feature key of a whole call: one `extract_arg` result per argument.

        `zip` stops at the shorter of `args` and `self.arguments`; `lookup` checks
        the argument count before calling this.
        """
        return tuple(
            self.extract_arg(arg, kernel_arg.annotation, kernel_arg.name)
            for arg, kernel_arg in zip(args, self.arguments)
        )

    def lookup(self, args: tuple[Any, ...]) -> tuple[int, tuple[Any, ...]]:
        """Return `(id, key)` for `args`, registering a new id when the key is unseen.

        Raises:
            TypeError: if the number of arguments differs from the declared count.
        """
        if len(args) != self.num_args:
            raise TypeError(f"{self.num_args} argument(s) needed but {len(args)} provided.")

        key = self.extract(args)
        if key not in self.mapping:
            # Ids are dense and assigned in order of first appearance.
            self.mapping[key] = len(self.mapping)
        return self.mapping[key], key
@@ -0,0 +1,172 @@
1
+ # type: ignore
2
+
3
+ import numpy as np
4
+
5
+ from gstaichi._lib import core as _ti_core
6
+ from gstaichi.lang import impl
7
+ from gstaichi.lang.expr import Expr, make_expr_group
8
+ from gstaichi.lang.matrix import Matrix
9
+ from gstaichi.lang.util import gstaichi_scope
10
+ from gstaichi.types import vector
11
+ from gstaichi.types.primitive_types import f32
12
+
13
+
14
def _get_entries(mat):
    """Flatten a texture-op operand: a Matrix yields its `entries`, anything else is wrapped in a one-item list."""
    return mat.entries if isinstance(mat, Matrix) else [mat]
18
+
19
+
20
class TextureSampler:
    """Kernel-scope handle for sampling or fetching texels of a texture.

    Args:
        ptr_expr: Texture pointer expression handed to the texture-op builder.
        num_dims (int): Number of texture dimensions.
    """

    def __init__(self, ptr_expr, num_dims) -> None:
        self.ptr_expr = ptr_expr
        self.num_dims = num_dims

    def _texel_op(self, op_type, operands):
        """Emit texture op `op_type` over `operands` and pack its four components into an f32 vec4."""
        runtime = impl.get_runtime()
        builder = runtime.compiling_callable.ast_builder()
        debug = _ti_core.DebugInfo(runtime.get_current_src_info())
        texel = builder.make_texture_op_expr(op_type, self.ptr_expr, make_expr_group(*operands), debug)
        channels = [
            impl.call_internal(intrinsic, texel, with_runtime_context=False)
            for intrinsic in (
                "composite_extract_0",
                "composite_extract_1",
                "composite_extract_2",
                "composite_extract_3",
            )
        ]
        return vector(4, f32)(channels)

    @gstaichi_scope
    def sample_lod(self, uv, lod):
        """Sample at coordinates `uv` with level of detail `lod`; returns an f32 vec4."""
        return self._texel_op(_ti_core.TextureOpType.kSampleLod, [*_get_entries(uv), lod])

    @gstaichi_scope
    def fetch(self, index, lod):
        """Fetch the texel at integer `index` on level `lod`; returns an f32 vec4."""
        return self._texel_op(_ti_core.TextureOpType.kFetchTexel, [*_get_entries(index), lod])
48
+
49
+
50
class RWTextureAccessor:
    """Kernel-scope handle for loading and storing texels of a read-write texture.

    Args:
        ptr_expr: gstaichi_python.TexturePtrExpression for the texture.
        num_dims (int): Number of texture dimensions.
    """

    def __init__(self, ptr_expr, num_dims) -> None:
        # gstaichi_python.TexturePtrExpression.
        self.ptr_expr = ptr_expr
        self.num_dims = num_dims

    @gstaichi_scope
    def load(self, index):
        """Load the texel at `index`; returns its four components as an f32 vec4."""
        runtime = impl.get_runtime()
        builder = runtime.compiling_callable.ast_builder()
        debug = _ti_core.DebugInfo(runtime.get_current_src_info())
        operands = make_expr_group(*_get_entries(index))
        texel = builder.make_texture_op_expr(_ti_core.TextureOpType.kLoad, self.ptr_expr, operands, debug)
        components = [
            impl.call_internal(intrinsic, texel, with_runtime_context=False)
            for intrinsic in (
                "composite_extract_0",
                "composite_extract_1",
                "composite_extract_2",
                "composite_extract_3",
            )
        ]
        return vector(4, f32)(components)

    @gstaichi_scope
    def store(self, index, value):
        """Emit a store of `value` into the texel at `index`."""
        runtime = impl.get_runtime()
        builder = runtime.compiling_callable.ast_builder()
        debug = _ti_core.DebugInfo(runtime.get_current_src_info())
        operands = make_expr_group(*_get_entries(index), *_get_entries(value))
        store_expr = builder.make_texture_op_expr(_ti_core.TextureOpType.kStore, self.ptr_expr, operands, debug)
        impl.expr_init(store_expr)

    @property
    @gstaichi_scope
    def shape(self):
        """A list containing sizes for each dimension. Note that element shape will be excluded.

        Returns:
            List[Int]: One Expr per axis, queried from `ptr_expr`.
        """
        ndim = _ti_core.get_external_tensor_dim(self.ptr_expr)
        debug = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        return [
            Expr(_ti_core.get_external_tensor_shape_along_axis(self.ptr_expr, axis, debug))
            for axis in range(ndim)
        ]

    @gstaichi_scope
    def _loop_range(self):
        """Gets the corresponding gstaichi_python.Expr to serve as loop range.

        Returns:
            gstaichi_python.Expr: The texture pointer expression itself.
        """
        return self.ptr_expr
98
+
99
+
100
class Texture:
    """GsTaichi Texture class.

    Args:
        fmt (ti.Format): Color format of the texture.
        arr_shape (Tuple[int]): Shape of the Texture.
    """

    def __init__(self, fmt, arr_shape):
        # Backing texture object created by the runtime program.
        self.tex = impl.get_runtime().prog.create_texture(fmt, arr_shape)
        self.fmt = fmt
        self.num_dims = len(arr_shape)
        # Stored exactly as passed in (not normalised to a tuple).
        self.shape = arr_shape

    def from_ndarray(self, ndarray):
        """Loads an ndarray to texture.

        Args:
            ndarray (ti.Ndarray): Source ndarray to load from.
        """
        self.tex.from_ndarray(ndarray.arr)

    def from_field(self, field):
        """Loads a field to texture.

        Args:
            field (ti.Field): Source field to load from.
        """
        self.tex.from_snode(field.snode.ptr)

    def _device_allocation_ptr(self):
        """Return the device allocation pointer of the backing texture."""
        return self.tex.device_allocation_ptr()

    def from_image(self, image):
        """Loads a PIL image to texture. Intended for a 2D texture with `ti.Format.rgba8`.

        The image is converted to RGB if needed and must have the same size as the
        texture. Only the size and the 2D requirement are checked here (via assert).

        Args:
            image (PIL.Image.Image): Source PIL image to load from.
        """
        from PIL import Image  # pylint: disable=import-outside-toplevel

        assert isinstance(image, Image.Image)
        if image.mode != "RGB":
            image = image.convert("RGB")
        assert image.size == tuple(self.shape)

        assert self.num_dims == 2
        # Don't use transpose method since its enums are too new
        image = image.rotate(90, expand=True)
        arr = np.asarray(image)
        from gstaichi._kernels import (  # pylint: disable=import-outside-toplevel
            load_texture_from_numpy,
        )

        load_texture_from_numpy(self, arr)

    def to_image(self):
        """Saves a texture to a PIL image in RGB mode. Intended for a 2D texture with `ti.Format.rgba8`.

        Returns:
            img (PIL.Image.Image): a PIL image in RGB mode, with the same size as source texture.
        """
        assert self.num_dims == 2
        from PIL import Image  # pylint: disable=import-outside-toplevel

        # `self.shape` is stored as given to the constructor; normalise to a tuple so
        # appending the channel axis also works when the shape was passed as a list
        # (`list + tuple` raises TypeError).
        res = np.zeros(tuple(self.shape) + (3,), np.uint8)
        from gstaichi._kernels import (  # pylint: disable=import-outside-toplevel
            save_texture_to_numpy,
        )

        save_texture_to_numpy(self, res)
        # Undo the 90-degree rotation applied on the load path.
        return Image.fromarray(res).rotate(270, expand=True)
@@ -0,0 +1,189 @@
1
+ # type: ignore
2
+
3
+ # GsTaichi's custom inspect module.
4
+ # This module is used by GsTaichi's ast transformer to parse the source code.
5
+ # Currently this module is designed to work in the following modes:
6
+ # 1. Usual Python/IPython mode, e.g. python script.py
7
+ # In this case we mainly rely on the built-in `inspect` module, except
8
+ # we need some hacks when we are in IPython mode and there is a cell magic.
9
+ # 2. Blender's scripting mode, e.g. Users write GsTaichi code in the scripting
10
+ # window in Blender and press the run button. In this case we need to
11
+ # retrieve the source using Blender's `bpy.data.texts` and write it to a temp
12
+ # file so that the inspect module can parse.
13
+ # 3. The interactive shell mode, e.g. Users directly type their code in the
14
+ # interactive shell. In this case we use `dill` to get the source.
15
+ #
16
+ # NB: Running GsTaichi in other modes is likely not supported.
17
+
18
+ import atexit
19
+ import inspect
20
+ import os
21
+ import tempfile
22
+
23
+ import dill
24
+
25
# Keep references to the original `inspect` functions: they are called directly
# below and used to restore `inspect` after it is temporarily patched.
_builtin_getfile = inspect.getfile
_builtin_findsource = inspect.findsource
27
+
28
+
29
def _find_source_with_custom_getfile_func(func, obj):
    """Run `inspect.findsource(obj)` with `inspect.getfile` temporarily replaced by `func`.

    `inspect.getfile` is reset to the builtin (`_builtin_getfile`) afterwards,
    including when `findsource` raises. Otherwise a failed lookup would leave the
    custom function installed globally.

    Returns:
        Whatever `inspect.findsource` returns for `obj`.
    """
    inspect.getfile = func  # replace with our custom func
    try:
        return inspect.findsource(obj)
    finally:
        inspect.getfile = _builtin_getfile  # restore, even on error
37
+
38
+
39
+ def _blender_get_text_name(filename: str):
40
+ """Extract filename from path in the Blender mode."""
41
+ # In Blender's scripting mode, unsaved files are named
42
+ # like `/Text`, `/Text.001`, `/test.py`, etc.
43
+ # We simply remove this path seperator.
44
+ if filename.startswith(os.path.sep) and filename.count(os.path.sep) == 1:
45
+ return filename[1:] # "/Text.001" --> "Text.001"
46
+
47
+ # Saved text files are named like `some-path/xxx.blend/Text` or
48
+ # `some-path/xxx.blend/test.py`
49
+ # We drop the path and extract the filename with extension.
50
+ index = filename.rfind(".blend" + os.path.sep)
51
+ if index != -1:
52
+ return filename[index + 7 :] # "xxx.blend/test.py" --> "test.py"
53
+
54
+ return None
55
+
56
+
57
+ def _blender_findsource(obj):
58
+ try:
59
+ import bpy # pylint: disable=import-outside-toplevel
60
+ except:
61
+ raise ImportError("Not in Blender environment!")
62
+
63
+ # Inspect's built-in `getfile` returns the filename like
64
+ # `/Text`, `/Text.001`, `some-path/xxx.blend/test.py`
65
+ # This filename may not be a full valid path.
66
+ filename = _builtin_getfile(obj)
67
+ # Extract the text name without path
68
+ text_name = _blender_get_text_name(filename)
69
+ if text_name is None:
70
+ raise IOError("Object `{obj.__name__}` is not defined in a .blend file!")
71
+ # Get the lines of code via text_name
72
+ lines = bpy.data.texts[text_name].as_string()
73
+ # Now we have found the lines of code.
74
+ # We first check if they are already cached, to avoid file io in each query.
75
+ try:
76
+ filename = _blender_findsource._saved_inspect_cache[lines] # pylint: disable=no-member
77
+ except KeyError:
78
+ # Save the code to a valid path.
79
+ fd, filename = tempfile.mkstemp(prefix="_Blender_", suffix=f"_{text_name}.py")
80
+ os.close(fd)
81
+
82
+ with open(filename, "w") as f:
83
+ f.write(lines)
84
+
85
+ _blender_findsource._saved_inspect_cache[lines] = filename # pylint: disable=no-member
86
+ atexit.register(os.unlink, filename) # Remove file when program exits
87
+
88
+ # Our custom getfile function
89
+ def wrapped_getfile(ob):
90
+ if id(ob) == id(obj):
91
+ return filename
92
+
93
+ return _builtin_getfile(ob)
94
+
95
+ return _find_source_with_custom_getfile_func(wrapped_getfile, obj)
96
+
97
+
98
# Cache used by `_blender_findsource`: maps a text block's source string to the
# temporary file it was written to, so identical contents are written only once.
_blender_findsource._saved_inspect_cache = {}
99
+
100
+
101
def _Python_IPython_findsource(obj):
    """Find the source of `obj` in plain Python or IPython.

    The builtin `inspect.findsource` is tried first. If it raises IOError and the
    reported file is `<timed exec>` or `<magic-timeit>` (an IPython cell run under
    a %time/%timeit magic), then:

    1. The latest cell input is read from IPython's history.
    2. The magic command is stripped off.
    3. The rest is written to a temporary ``.py`` file (deleted at exit).
    4. The source is looked up in that file.

    Raises:
        IOError: if the source cannot be found by either route.
    """
    try:
        # In Python and IPython the builtin inspect would suffice in most cases
        return _builtin_findsource(obj)
    except IOError:
        # Except when the cell has a magic command like %%time or %%timeit.
        # In that case the filename returned by the built-in getfile is wrong:
        # it becomes something like `<timed exec>` or `<magic-timeit>`.
        filename = _builtin_getfile(obj)
        if filename in {"<timed exec>", "<magic-timeit>"}:
            try:
                # `get_ipython` only exists (as a builtin) when running under IPython.
                ip = get_ipython()  # noqa: F821
                if ip is not None:
                    # The latest lines of code can be retrieved from here
                    lines = ip.history_manager._i00
                    # `lines` still contains the cell magic command; drop it (and the
                    # whitespace around it) to obtain a valid Python snippet.
                    magic_index = lines.find("%time")
                    parts = lines[magic_index:].split(maxsplit=1) if magic_index != -1 else []
                    if len(parts) == 2:
                        session_id = ip.history_manager.get_last_session_id()
                        fd, snippet_path = tempfile.mkstemp(prefix="_IPython_", suffix=f"_{session_id}.py")
                        with os.fdopen(fd, "w", encoding="utf-8") as f:
                            f.write(parts[1])
                        atexit.register(os.unlink, snippet_path)  # Remove the file after the program exits
                        return _find_source_with_custom_getfile_func(lambda _ob: snippet_path, obj)
            except (ImportError, NameError):
                # Not running inside IPython.
                pass
        raise IOError(
            f"Cannot find source code for Object: {obj}, it's likely "
            "you are not running GsTaichi from command line or IPython."
        )
142
+
143
+
144
def _REPL_findsource(obj):
    """Locate the source of `obj` typed into an interactive shell, delegating to dill."""
    dill_source = dill.source
    return dill_source.findsource(obj)
147
+
148
+
149
def _custom_findsource(obj):
    """Find the source of `obj`, trying each supported environment in turn.

    Strategies, in order:

    1. Plain Python/IPython. Only an IOError moves on to the next strategy; other
       exceptions propagate.
    2. Interactive shell (via dill).
    3. Blender.

    An `Exception` from either fallback moves on; KeyboardInterrupt/SystemExit are
    no longer swallowed.

    Raises:
        IOError: if no strategy locates the source (chained to the last failure).
    """
    try:
        return _Python_IPython_findsource(obj)
    except IOError:
        pass

    try:
        return _REPL_findsource(obj)
    except Exception:  # best-effort fallback
        pass

    try:
        return _blender_findsource(obj)
    except Exception as exc:
        raise IOError(
            f"Cannot find source code for Object: {obj}, this "
            "is possibly because of you are running GsTaichi in an environment that GsTaichi's own "
            "inspect module cannot find the source. Please report an issue to help us fix: "
            "https://github.com/taichi-dev/gstaichi/issues"
        ) from exc
165
+
166
+
167
class _InspectContextManager:
    """Install `_custom_findsource` as `inspect.findsource` for the duration of a `with` block.

    On exit the builtin captured at import time (`_builtin_findsource`) is always
    reinstalled, regardless of what was installed on entry. This mutates the global
    `inspect` module without any locking. Exceptions are not suppressed.
    """

    def __enter__(self):
        setattr(inspect, "findsource", _custom_findsource)
        return self

    def __exit__(self, *_):
        setattr(inspect, "findsource", _builtin_findsource)
174
+
175
+
176
def getsourcelines(obj):
    """Like `inspect.getsourcelines`, but with GsTaichi's environment-aware source lookup."""
    patch = _InspectContextManager()
    with patch:
        source_lines = inspect.getsourcelines(obj)
    return source_lines
179
+
180
+
181
def getsourcefile(obj):
    """Like `inspect.getsourcefile`, falling back to `inspect.getfile` when no source file is reported.

    Both calls run with GsTaichi's custom `findsource` installed.
    """
    with _InspectContextManager():
        path = inspect.getsourcefile(obj)
        if path is None:
            path = inspect.getfile(obj)
    return path
187
+
188
+
189
# Only these two wrappers are exported by `from ... import *`.
__all__ = ["getsourcelines", "getsourcefile"]
@@ -0,0 +1,99 @@
1
+ # type: ignore
2
+
3
+ from gstaichi._lib import core as _ti_core
4
+ from gstaichi.lang import impl
5
+ from gstaichi.lang.expr import Expr, make_expr_group
6
+ from gstaichi.lang.util import gstaichi_scope
7
+ from gstaichi.types.enums import Layout
8
+ from gstaichi.types.ndarray_type import NdarrayTypeMetadata
9
+
10
+
11
class AnyArray:
    """Class for arbitrary arrays in Python AST.

    Args:
        ptr (gstaichi_python.Expr): A gstaichi_python.Expr wrapping a gstaichi_python.ExternalTensorExpression.

    Element shape and memory layout are not constructor arguments. They are queried
    from `ptr` through `element_shape()` and `layout()`.
    """

    def __init__(self, ptr):
        assert ptr.is_external_tensor_expr()
        self.ptr = ptr
        self.ptr.type_check(impl.get_runtime().prog.config())

    def element_shape(self):
        """Return the element shape reported by the core for `ptr`."""
        return _ti_core.get_external_tensor_element_shape(self.ptr)

    def layout(self):
        """Return Layout.SOA when the reported element_dim is 1 or 2, otherwise Layout.AOS."""
        # 0: scalar; 1: vector (SOA); 2: matrix (SOA); -1: vector
        # (AOS); -2: matrix (AOS)
        element_dim = _ti_core.get_external_tensor_element_dim(self.ptr)
        return Layout.SOA if element_dim in (1, 2) else Layout.AOS

    def get_type(self):
        """Return NdarrayTypeMetadata with the element type and needs_grad of `ptr`; shape is None."""
        return NdarrayTypeMetadata(
            _ti_core.get_external_tensor_element_type(self.ptr), None, _ti_core.get_external_tensor_needs_grad(self.ptr)
        )  # AnyArray can take any shape

    @property
    @gstaichi_scope
    def grad(self):
        """Returns the gradient of this array."""
        return AnyArray(_ti_core.make_external_tensor_grad_expr(self.ptr))

    @property
    @gstaichi_scope
    def shape(self):
        """A list containing sizes for each dimension. Note that element shape will be excluded.

        Returns:
            List[Int]: The result list.
        """
        dim = _ti_core.get_external_tensor_dim(self.ptr)
        dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        return [Expr(_ti_core.get_external_tensor_shape_along_axis(self.ptr, i, dbg_info)) for i in range(dim)]

    @gstaichi_scope
    def _loop_range(self):
        """Gets the corresponding gstaichi_python.Expr to serve as loop range.

        Returns:
            gstaichi_python.Expr: See above.
        """
        return self.ptr
67
+
68
+
69
class AnyArrayAccess:
    """First-level access into an AnyArray whose elements are vectors or matrices.

    Args:
        arr (AnyArray): The array being indexed.
        indices_first (Tuple[Int]): Indices of the first-level (outer) access.
    """

    def __init__(self, arr, indices_first):
        self.arr = arr
        self.indices_first = indices_first

    @gstaichi_scope
    def subscript(self, i, j):
        """Build the subscript Expr for component `i` (vector) or `(i, j)` (matrix)."""
        builder = impl.get_runtime().compiling_callable.ast_builder()

        is_vector = len(self.arr.element_shape()) == 1
        element_indices = (i,) if is_vector else (i, j)
        # SOA places the element indices before the outer indices; AOS places them after.
        if self.arr.layout() == Layout.SOA:
            full_indices = element_indices + self.indices_first
        else:
            full_indices = self.indices_first + element_indices

        indices_group = make_expr_group(*full_indices)
        debug = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
        return Expr(builder.expr_subscript(self.arr.ptr, indices_group, debug))
97
+
98
+
99
# This module exports nothing via `from ... import *`; import names explicitly.
__all__ = []