gstaichi 1.0.1__cp313-cp313-win_amd64.whl → 2.1.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. gstaichi/CHANGELOG.md +1 -3
  2. gstaichi/_lib/core/gstaichi_python.cp313-win_amd64.pyd +0 -0
  3. gstaichi/_lib/core/gstaichi_python.pyi +13 -41
  4. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  5. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  6. gstaichi/_lib/utils.py +1 -7
  7. gstaichi/_test_tools/__init__.py +18 -0
  8. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  9. gstaichi/_test_tools/textwrap2.py +6 -0
  10. gstaichi/_version.py +1 -1
  11. gstaichi/examples/lcg_python.py +26 -0
  12. gstaichi/examples/lcg_taichi.py +34 -0
  13. gstaichi/examples/minimal.py +1 -1
  14. gstaichi/lang/__init__.py +1 -1
  15. gstaichi/lang/_dataclass_util.py +31 -0
  16. gstaichi/lang/_fast_caching/__init__.py +3 -0
  17. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  18. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  19. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  20. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  21. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  22. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  23. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  24. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  25. gstaichi/lang/_template_mapper.py +16 -20
  26. gstaichi/lang/_wrap_inspect.py +27 -1
  27. gstaichi/lang/ast/ast_transformer.py +7 -2
  28. gstaichi/lang/ast/ast_transformer_utils.py +18 -13
  29. gstaichi/lang/ast/ast_transformers/call_transformer.py +73 -16
  30. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +102 -118
  31. gstaichi/lang/field.py +0 -38
  32. gstaichi/lang/impl.py +25 -24
  33. gstaichi/lang/kernel_arguments.py +28 -30
  34. gstaichi/lang/kernel_impl.py +154 -200
  35. gstaichi/lang/matrix.py +0 -46
  36. gstaichi/lang/struct.py +0 -45
  37. gstaichi/lang/util.py +11 -80
  38. gstaichi/types/annotations.py +10 -5
  39. gstaichi/types/compound_types.py +1 -20
  40. gstaichi/types/ndarray_type.py +33 -11
  41. gstaichi/types/utils.py +0 -2
  42. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/bin/SPIRV-Tools-shared.dll +0 -0
  43. gstaichi-2.1.0.data/data/include/GLFW/glfw3.h +6389 -0
  44. gstaichi-2.1.0.data/data/include/GLFW/glfw3native.h +594 -0
  45. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-diff.lib +0 -0
  46. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-link.lib +0 -0
  47. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-lint.lib +0 -0
  48. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-opt.lib +0 -0
  49. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-reduce.lib +0 -0
  50. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-shared.lib +0 -0
  51. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools.lib +0 -0
  52. gstaichi-2.1.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  53. gstaichi-2.1.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  54. gstaichi-2.1.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  55. gstaichi-2.1.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  56. gstaichi-2.1.0.data/data/lib/glfw3.lib +0 -0
  57. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/METADATA +4 -3
  58. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/RECORD +84 -64
  59. gstaichi/lang/argpack.py +0 -411
  60. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +0 -0
  61. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +0 -0
  62. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +0 -0
  63. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +0 -0
  64. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +0 -0
  65. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +0 -0
  66. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +0 -0
  67. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +0 -0
  68. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +0 -0
  69. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +0 -0
  70. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +0 -0
  71. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +0 -0
  72. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +0 -0
  73. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +0 -0
  74. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +0 -0
  75. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +0 -0
  76. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +0 -0
  77. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +0 -0
  78. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/instrument.hpp +0 -0
  79. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/libspirv.h +0 -0
  80. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
  81. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/linker.hpp +0 -0
  82. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
  83. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/WHEEL +0 -0
  84. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/licenses/LICENSE +0 -0
  85. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/top_level.txt +0 -0
gstaichi/lang/impl.py CHANGED
@@ -8,6 +8,7 @@ from gstaichi._lib import core as _ti_core
8
8
  from gstaichi._lib.core.gstaichi_python import (
9
9
  DataTypeCxx,
10
10
  Function,
11
+ KernelCxx,
11
12
  Program,
12
13
  )
13
14
  from gstaichi._snode.fields_builder import FieldsBuilder
@@ -70,17 +71,14 @@ from gstaichi.types.primitive_types import (
70
71
 
71
72
  @gstaichi_scope
72
73
  def expr_init_shared_array(shape, element_type):
73
- compiling_callable = get_runtime().compiling_callable
74
- assert compiling_callable is not None
75
- return compiling_callable.ast_builder().expr_alloca_shared_array(
76
- shape, element_type, _ti_core.DebugInfo(get_runtime().get_current_src_info())
77
- )
74
+ ast_builder = get_runtime().compiling_callable.ast_builder()
75
+ debug_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
76
+ return ast_builder.expr_alloca_shared_array(shape, element_type, debug_info)
78
77
 
79
78
 
80
79
  @gstaichi_scope
81
80
  def expr_init(rhs):
82
81
  compiling_callable = get_runtime().compiling_callable
83
- assert compiling_callable is not None
84
82
  if rhs is None:
85
83
  return Expr(
86
84
  compiling_callable.ast_builder().expr_alloca(_ti_core.DebugInfo(get_runtime().get_current_src_info()))
@@ -167,7 +165,7 @@ def _calc_slice(index, default_stop):
167
165
  "GsTaichi does not support variables in slice now, please use constant instead of it."
168
166
  )
169
167
 
170
- check_validity(start), check_validity(stop), check_validity(step)
168
+ _ = check_validity(start), check_validity(stop), check_validity(step)
171
169
  return [_ for _ in range(start, stop, step)]
172
170
 
173
171
 
@@ -194,9 +192,7 @@ def validate_subscript_index(value, index):
194
192
  @gstaichi_scope
195
193
  def subscript(ast_builder, value, *_indices, skip_reordered=False):
196
194
  dbg_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
197
- compiling_callable = get_runtime().compiling_callable
198
- assert compiling_callable is not None
199
- ast_builder = compiling_callable.ast_builder()
195
+ ast_builder = get_runtime().compiling_callable.ast_builder()
200
196
  # Directly evaluate in Python for non-GsTaichi types
201
197
  if not isinstance(
202
198
  value,
@@ -337,8 +333,8 @@ class PyGsTaichi:
337
333
  self._prog: Program | None = None
338
334
  self.src_info_stack = []
339
335
  self.inside_kernel: bool = False
340
- self.compiling_callable: Kernel | Function | None = None # pointer to instance of lang::Kernel/Function
341
- self._current_kernel: Kernel | None = None
336
+ self._compiling_callable: KernelCxx | Kernel | Function | None = None
337
+ self._current_kernel: "Kernel | None" = None
342
338
  self.global_vars = []
343
339
  self.grad_vars = []
344
340
  self.dual_vars = []
@@ -350,10 +346,18 @@ class PyGsTaichi:
350
346
  self.target_tape = None
351
347
  self.fwd_mode_manager = None
352
348
  self.grad_replaced = False
353
- self.kernels = kernels or []
349
+ self.kernels: list[Kernel] = kernels or []
354
350
  self._signal_handler_registry = None
355
351
  self.unfinalized_fields_builder = {}
356
352
 
353
+ @property
354
+ def compiling_callable(self) -> KernelCxx | Kernel | Function:
355
+ if self._compiling_callable is None:
356
+ raise GsTaichiRuntimeError(
357
+ "_compiling_callable attribute not initialized. Maybe you forgot to call `ti.init()` first?"
358
+ )
359
+ return self._compiling_callable
360
+
357
361
  @property
358
362
  def prog(self) -> Program:
359
363
  if self._prog is None:
@@ -364,7 +368,7 @@ class PyGsTaichi:
364
368
  def current_kernel(self) -> Kernel:
365
369
  if self._current_kernel is None:
366
370
  raise GsTaichiRuntimeError(
367
- "_pr_current_kernelog attribute not initialized. Maybe you forgot to call `ti.init()` first?"
371
+ "_current_kernel attribute not initialized. Maybe you forgot to call `ti.init()` first?"
368
372
  )
369
373
  return self._current_kernel
370
374
 
@@ -373,7 +377,7 @@ class PyGsTaichi:
373
377
 
374
378
  def clear_compiled_functions(self):
375
379
  for k in self.kernels:
376
- k.compiled_kernels.clear()
380
+ k.materialized_kernels.clear()
377
381
 
378
382
  def finalize_fields_builder(self, builder):
379
383
  self.unfinalized_fields_builder.pop(builder)
@@ -390,7 +394,7 @@ class PyGsTaichi:
390
394
  def get_num_compiled_functions(self):
391
395
  count = 0
392
396
  for k in self.kernels:
393
- count += len(k.compiled_kernels)
397
+ count += len(k.materialized_kernels)
394
398
  return count
395
399
 
396
400
  def src_info_guard(self, info):
@@ -962,11 +966,9 @@ def ti_print(*_vars, sep=" ", end="\n"):
962
966
 
963
967
  _vars = add_separators(_vars)
964
968
  contents, formats = ti_format_list_to_content_entries(_vars)
965
- compiling_callable = get_runtime().compiling_callable
966
- assert compiling_callable is not None
967
- compiling_callable.ast_builder().create_print(
968
- contents, formats, _ti_core.DebugInfo(get_runtime().get_current_src_info())
969
- )
969
+ ast_builder = get_runtime().compiling_callable.ast_builder()
970
+ debug_info = _ti_core.DebugInfo(get_runtime().get_current_src_info())
971
+ ast_builder.create_print(contents, formats, debug_info)
970
972
 
971
973
 
972
974
  @gstaichi_scope
@@ -996,9 +998,8 @@ def ti_format(*args):
996
998
  def ti_assert(cond, msg, extra_args, dbg_info):
997
999
  # Mostly a wrapper to help us convert from Expr (defined in Python) to
998
1000
  # _ti_core.Expr (defined in C++)
999
- compiling_callable = get_runtime().compiling_callable
1000
- assert compiling_callable is not None
1001
- compiling_callable.ast_builder().create_assert_stmt(Expr(cond).ptr, msg, extra_args, dbg_info)
1001
+ ast_builder = get_runtime().compiling_callable.ast_builder()
1002
+ ast_builder.create_assert_stmt(Expr(cond).ptr, msg, extra_args, dbg_info)
1002
1003
 
1003
1004
 
1004
1005
  @gstaichi_scope
@@ -4,6 +4,10 @@ import inspect
4
4
 
5
5
  import gstaichi.lang
6
6
  from gstaichi._lib import core as _ti_core
7
+ from gstaichi._lib.core.gstaichi_python import (
8
+ BoundaryMode,
9
+ DataTypeCxx,
10
+ )
7
11
  from gstaichi.lang import impl, ops
8
12
  from gstaichi.lang._texture import RWTextureAccessor, TextureSampler
9
13
  from gstaichi.lang.any_array import AnyArray
@@ -15,11 +19,18 @@ from gstaichi.types.compound_types import CompoundType
15
19
  from gstaichi.types.primitive_types import RefType, u64
16
20
 
17
21
 
18
- class KernelArgument:
19
- def __init__(self, _annotation, _name, _default=inspect.Parameter.empty):
20
- self.annotation = _annotation
21
- self.name = _name
22
- self.default = _default
22
+ class ArgMetadata:
23
+ """
24
+ Metadata about an argument to a function
25
+ """
26
+
27
+ def __init__(self, annotation, name, default=inspect.Parameter.empty):
28
+ self.annotation = annotation
29
+ self.name = name
30
+ self.default = default
31
+
32
+ def __repr__(self) -> str:
33
+ return f"{self.__class__.__name__}(annotation={self.annotation}, name={self.name}, default={self.default})"
23
34
 
24
35
 
25
36
  class SparseMatrixEntry:
@@ -48,7 +59,7 @@ class SparseMatrixProxy:
48
59
  return SparseMatrixEntry(self.ptr, i, j, self.dtype)
49
60
 
50
61
 
51
- def decl_scalar_arg(dtype, name, arg_depth):
62
+ def decl_scalar_arg(dtype, name):
52
63
  is_ref = False
53
64
  if isinstance(dtype, RefType):
54
65
  is_ref = True
@@ -60,9 +71,7 @@ def decl_scalar_arg(dtype, name, arg_depth):
60
71
  arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(dtype, name)
61
72
 
62
73
  argload_di = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
63
- return Expr(
64
- _ti_core.make_arg_load_expr(arg_id, dtype, is_ref, create_load=True, arg_depth=arg_depth, dbg_info=argload_di)
65
- )
74
+ return Expr(_ti_core.make_arg_load_expr(arg_id, dtype, is_ref, create_load=True, dbg_info=argload_di))
66
75
 
67
76
 
68
77
  def get_type_for_kernel_args(dtype, name):
@@ -86,35 +95,22 @@ def get_type_for_kernel_args(dtype, name):
86
95
  return dtype
87
96
 
88
97
 
89
- def decl_matrix_arg(matrixtype, name, arg_depth):
98
+ def decl_matrix_arg(matrixtype, name):
90
99
  arg_type = get_type_for_kernel_args(matrixtype, name)
91
100
  arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(arg_type, name)
92
101
  argload_di = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
93
- arg_load = Expr(
94
- _ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False, arg_depth=arg_depth, dbg_info=argload_di)
95
- )
102
+ arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False, dbg_info=argload_di))
96
103
  return matrixtype.from_gstaichi_object(arg_load)
97
104
 
98
105
 
99
- def decl_struct_arg(structtype, name, arg_depth):
106
+ def decl_struct_arg(structtype, name):
100
107
  arg_type = get_type_for_kernel_args(structtype, name)
101
108
  arg_id = impl.get_runtime().compiling_callable.insert_scalar_param(arg_type, name)
102
109
  argload_di = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
103
- arg_load = Expr(
104
- _ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False, arg_depth=arg_depth, dbg_info=argload_di)
105
- )
110
+ arg_load = Expr(_ti_core.make_arg_load_expr(arg_id, arg_type, create_load=False, dbg_info=argload_di))
106
111
  return structtype.from_gstaichi_object(arg_load)
107
112
 
108
113
 
109
- def push_argpack_arg(name):
110
- impl.get_runtime().compiling_callable.insert_argpack_param_and_push(name)
111
-
112
-
113
- def decl_argpack_arg(argpacktype, member_dict):
114
- impl.get_runtime().compiling_callable.pop_argpack_stack()
115
- return argpacktype.from_gstaichi_object(member_dict)
116
-
117
-
118
114
  def decl_sparse_matrix(dtype, name):
119
115
  value_type = cook_dtype(dtype)
120
116
  ptr_type = cook_dtype(u64)
@@ -126,16 +122,18 @@ def decl_sparse_matrix(dtype, name):
126
122
  )
127
123
 
128
124
 
129
- def decl_ndarray_arg(element_type, ndim, name, needs_grad, boundary):
125
+ def decl_ndarray_arg(
126
+ element_type: DataTypeCxx, ndim: int, name: str, needs_grad: bool, boundary: BoundaryMode
127
+ ) -> AnyArray:
130
128
  arg_id = impl.get_runtime().compiling_callable.insert_ndarray_param(element_type, ndim, name, needs_grad)
131
- return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, 0, boundary))
129
+ return AnyArray(_ti_core.make_external_tensor_expr(element_type, ndim, arg_id, needs_grad, boundary))
132
130
 
133
131
 
134
132
  def decl_texture_arg(num_dimensions, name):
135
133
  # FIXME: texture_arg doesn't have element_shape so better separate them
136
134
  arg_id = impl.get_runtime().compiling_callable.insert_texture_param(num_dimensions, name)
137
135
  dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
138
- return TextureSampler(_ti_core.make_texture_ptr_expr(arg_id, num_dimensions, 0, dbg_info), num_dimensions)
136
+ return TextureSampler(_ti_core.make_texture_ptr_expr(arg_id, num_dimensions, dbg_info), num_dimensions)
139
137
 
140
138
 
141
139
  def decl_rw_texture_arg(num_dimensions, buffer_format, lod, name):
@@ -143,7 +141,7 @@ def decl_rw_texture_arg(num_dimensions, buffer_format, lod, name):
143
141
  arg_id = impl.get_runtime().compiling_callable.insert_rw_texture_param(num_dimensions, buffer_format, name)
144
142
  dbg_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
145
143
  return RWTextureAccessor(
146
- _ti_core.make_rw_texture_ptr_expr(arg_id, num_dimensions, 0, buffer_format, lod, dbg_info), num_dimensions
144
+ _ti_core.make_rw_texture_ptr_expr(arg_id, num_dimensions, buffer_format, lod, dbg_info), num_dimensions
147
145
  )
148
146
 
149
147