gstaichi 1.0.1__cp311-cp311-win_amd64.whl → 2.1.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. gstaichi/CHANGELOG.md +1 -3
  2. gstaichi/_lib/core/gstaichi_python.cp311-win_amd64.pyd +0 -0
  3. gstaichi/_lib/core/gstaichi_python.pyi +13 -41
  4. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  5. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  6. gstaichi/_lib/utils.py +1 -7
  7. gstaichi/_test_tools/__init__.py +18 -0
  8. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  9. gstaichi/_test_tools/textwrap2.py +6 -0
  10. gstaichi/_version.py +1 -1
  11. gstaichi/examples/lcg_python.py +26 -0
  12. gstaichi/examples/lcg_taichi.py +34 -0
  13. gstaichi/examples/minimal.py +1 -1
  14. gstaichi/lang/__init__.py +1 -1
  15. gstaichi/lang/_dataclass_util.py +31 -0
  16. gstaichi/lang/_fast_caching/__init__.py +3 -0
  17. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  18. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  19. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  20. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  21. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  22. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  23. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  24. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  25. gstaichi/lang/_template_mapper.py +16 -20
  26. gstaichi/lang/_wrap_inspect.py +27 -1
  27. gstaichi/lang/ast/ast_transformer.py +7 -2
  28. gstaichi/lang/ast/ast_transformer_utils.py +18 -13
  29. gstaichi/lang/ast/ast_transformers/call_transformer.py +73 -16
  30. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +102 -118
  31. gstaichi/lang/field.py +0 -38
  32. gstaichi/lang/impl.py +25 -24
  33. gstaichi/lang/kernel_arguments.py +28 -30
  34. gstaichi/lang/kernel_impl.py +154 -200
  35. gstaichi/lang/matrix.py +0 -46
  36. gstaichi/lang/struct.py +0 -45
  37. gstaichi/lang/util.py +11 -80
  38. gstaichi/types/annotations.py +10 -5
  39. gstaichi/types/compound_types.py +1 -20
  40. gstaichi/types/ndarray_type.py +33 -11
  41. gstaichi/types/utils.py +0 -2
  42. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/bin/SPIRV-Tools-shared.dll +0 -0
  43. gstaichi-2.1.0.data/data/include/GLFW/glfw3.h +6389 -0
  44. gstaichi-2.1.0.data/data/include/GLFW/glfw3native.h +594 -0
  45. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-diff.lib +0 -0
  46. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-link.lib +0 -0
  47. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-lint.lib +0 -0
  48. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-opt.lib +0 -0
  49. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-reduce.lib +0 -0
  50. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools-shared.lib +0 -0
  51. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/lib/SPIRV-Tools.lib +0 -0
  52. gstaichi-2.1.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  53. gstaichi-2.1.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  54. gstaichi-2.1.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  55. gstaichi-2.1.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  56. gstaichi-2.1.0.data/data/lib/glfw3.lib +0 -0
  57. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/METADATA +4 -3
  58. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/RECORD +84 -64
  59. gstaichi/lang/argpack.py +0 -411
  60. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +0 -0
  61. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +0 -0
  62. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +0 -0
  63. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +0 -0
  64. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +0 -0
  65. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +0 -0
  66. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +0 -0
  67. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +0 -0
  68. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +0 -0
  69. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +0 -0
  70. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +0 -0
  71. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +0 -0
  72. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +0 -0
  73. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +0 -0
  74. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +0 -0
  75. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +0 -0
  76. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +0 -0
  77. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +0 -0
  78. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/instrument.hpp +0 -0
  79. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/libspirv.h +0 -0
  80. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/libspirv.hpp +0 -0
  81. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/linker.hpp +0 -0
  82. {gstaichi-1.0.1.data → gstaichi-2.1.0.data}/data/include/spirv-tools/optimizer.hpp +0 -0
  83. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/WHEEL +0 -0
  84. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/licenses/LICENSE +0 -0
  85. {gstaichi-1.0.1.dist-info → gstaichi-2.1.0.dist-info}/top_level.txt +0 -0
gstaichi/lang/matrix.py CHANGED
@@ -27,7 +27,6 @@ from gstaichi.lang.util import (
27
27
  in_python_scope,
28
28
  python_scope,
29
29
  to_numpy_type,
30
- to_paddle_type,
31
30
  to_pytorch_type,
32
31
  warning,
33
32
  )
@@ -1301,33 +1300,6 @@ class MatrixField(Field):
1301
1300
  runtime_ops.sync()
1302
1301
  return arr
1303
1302
 
1304
- def to_paddle(self, place=None, keep_dims=False):
1305
- """Converts the field instance to a Paddle tensor.
1306
-
1307
- Args:
1308
- place (paddle.CPUPlace()/CUDAPlace(n), optional): The desired place of returned tensor.
1309
- keep_dims (bool, optional): Whether to keep the dimension after conversion.
1310
- See :meth:`~gstaichi.lang.field.MatrixField.to_numpy` for more detailed explanation.
1311
-
1312
- Returns:
1313
- paddle.Tensor: The result paddle tensor.
1314
- """
1315
- import paddle # pylint: disable=C0415
1316
-
1317
- as_vector = self.m == 1 and not keep_dims and self.ndim == 1
1318
- shape_ext = (self.n,) if as_vector else (self.n, self.m)
1319
- # pylint: disable=E1101
1320
- # paddle.empty() doesn't support argument `place``
1321
- arr = paddle.to_tensor(
1322
- paddle.empty(self.shape + shape_ext, to_paddle_type(self.dtype)),
1323
- place=place,
1324
- )
1325
- from gstaichi._kernels import matrix_to_ext_arr # pylint: disable=C0415
1326
-
1327
- matrix_to_ext_arr(self, arr, as_vector)
1328
- runtime_ops.sync()
1329
- return arr
1330
-
1331
1303
  @python_scope
1332
1304
  def _from_external_arr(self, arr):
1333
1305
  if len(arr.shape) == len(self.shape) + 1:
@@ -1500,24 +1472,6 @@ class MatrixType(CompoundType):
1500
1472
  for j in range(self.m):
1501
1473
  set_arg_func(ret_index + (i * self.m + j,), mat[i, j])
1502
1474
 
1503
- def set_argpack_struct_args(self, mat, argpack, ret_index=()):
1504
- if self.dtype in primitive_types.integer_types:
1505
- if is_signed(cook_dtype(self.dtype)):
1506
- set_arg_func = argpack.set_arg_int
1507
- else:
1508
- set_arg_func = argpack.set_arg_uint
1509
- elif self.dtype in primitive_types.real_types:
1510
- set_arg_func = argpack.set_arg_float
1511
- else:
1512
- raise GsTaichiRuntimeTypeError(f"Invalid return type on index={ret_index}")
1513
- if self.ndim == 1:
1514
- for i in range(self.n):
1515
- set_arg_func(ret_index + (i,), mat[i])
1516
- else:
1517
- for i in range(self.n):
1518
- for j in range(self.m):
1519
- set_arg_func(ret_index + (i * self.m + j,), mat[i, j])
1520
-
1521
1475
  def _instantiate_in_python_scope(self, entries):
1522
1476
  entries = [[entries[k * self.m + i] for i in range(self.m)] for k in range(self.n)]
1523
1477
  return Matrix(
gstaichi/lang/struct.py CHANGED
@@ -528,17 +528,6 @@ class StructField(Field):
528
528
  for k, v in self._items:
529
529
  v.from_torch(array_dict[k])
530
530
 
531
- @python_scope
532
- def from_paddle(self, array_dict):
533
- """Copies the data from a set of `paddle.Tensor` into this field.
534
-
535
- The argument `array_dict` must be a dictionay-like object, it
536
- contains all the keys in this field and the copying process
537
- between corresponding items can be performed.
538
- """
539
- for k, v in self._items:
540
- v.from_paddle(array_dict[k])
541
-
542
531
  @python_scope
543
532
  def to_numpy(self):
544
533
  """Converts the Struct field instance to a dictionary of NumPy arrays.
@@ -566,22 +555,6 @@ class StructField(Field):
566
555
  """
567
556
  return {k: v.to_torch(device=device) for k, v in self._items}
568
557
 
569
- @python_scope
570
- def to_paddle(self, place=None):
571
- """Converts the Struct field instance to a dictionary of Paddle tensors.
572
-
573
- The dictionary may be nested when converting nested structs.
574
-
575
- Args:
576
- place (paddle.CPUPlace()/CUDAPlace(n), optional): The
577
- desired place of returned tensor.
578
-
579
- Returns:
580
- Dict[str, Union[paddle.Tensor, Dict]]: The result
581
- Paddle tensor.
582
- """
583
- return {k: v.to_paddle(place=place) for k, v in self._items}
584
-
585
558
  @python_scope
586
559
  def __setitem__(self, indices, element):
587
560
  self._initialize_host_accessors()
@@ -742,24 +715,6 @@ class StructType(CompoundType):
742
715
  else:
743
716
  raise GsTaichiRuntimeTypeError(f"Invalid argument type on index={ret_index + (index, )}")
744
717
 
745
- def set_argpack_struct_args(self, struct, argpack, ret_index=()):
746
- # TODO: move this to class Struct after we add dtype to Struct
747
- items = self.members.items()
748
- for index, pair in enumerate(items):
749
- name, dtype = pair
750
- if isinstance(dtype, CompoundType):
751
- dtype.set_kernel_struct_args(struct[name], argpack, ret_index + (index,))
752
- else:
753
- if dtype in primitive_types.integer_types:
754
- if is_signed(cook_dtype(dtype)):
755
- argpack.set_arg_int(ret_index + (index,), struct[name])
756
- else:
757
- argpack.set_arg_uint(ret_index + (index,), struct[name])
758
- elif dtype in primitive_types.real_types:
759
- argpack.set_arg_float(ret_index + (index,), struct[name])
760
- else:
761
- raise GsTaichiRuntimeTypeError(f"Invalid argument type on index={ret_index + (index, )}")
762
-
763
718
  def cast(self, struct):
764
719
  # sanity check members
765
720
  if self.members.keys() != struct._Struct__entries.keys():
gstaichi/lang/util.py CHANGED
@@ -1,9 +1,8 @@
1
- # type: ignore
2
-
3
1
  import functools
4
2
  import os
5
3
  import traceback
6
4
  import warnings
5
+ from typing import Any
7
6
 
8
7
  import numpy as np
9
8
  from colorama import Fore, Style
@@ -11,6 +10,7 @@ from colorama import Fore, Style
11
10
  from gstaichi._lib import core as _ti_core
12
11
  from gstaichi._logging import is_logging_effective
13
12
  from gstaichi.lang import impl
13
+ from gstaichi.types import Template
14
14
  from gstaichi.types.primitive_types import (
15
15
  f16,
16
16
  f32,
@@ -46,24 +46,6 @@ def has_pytorch():
46
46
  return _has_pytorch
47
47
 
48
48
 
49
- def has_paddle():
50
- """Whether has paddle in the current Python environment.
51
-
52
- Returns:
53
- bool: True if has paddle else False.
54
- """
55
- _has_paddle = False
56
- _env_paddle = os.environ.get("TI_ENABLE_PADDLE", "1")
57
- if not _env_paddle or int(_env_paddle):
58
- try:
59
- import paddle # pylint: disable=C0415
60
-
61
- _has_paddle = True
62
- except:
63
- pass
64
- return _has_paddle
65
-
66
-
67
49
  def get_clangpp():
68
50
  from distutils.spawn import find_executable # pylint: disable=C0415
69
51
 
@@ -183,43 +165,8 @@ def to_pytorch_type(dt):
183
165
  raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type.")
184
166
 
185
167
 
186
- def to_paddle_type(dt):
187
- """Convert gstaichi data type to its counterpart in paddle.
188
-
189
- Args:
190
- dt (DataType): The desired data type to convert.
191
-
192
- Returns:
193
- DataType: The counterpart data type in paddle.
194
-
195
- """
196
- import paddle # pylint: disable=C0415
197
-
198
- if dt == f32:
199
- return paddle.float32
200
- if dt == f64:
201
- return paddle.float64
202
- if dt == i32:
203
- return paddle.int32
204
- if dt == i64:
205
- return paddle.int64
206
- if dt == i8:
207
- return paddle.int8
208
- if dt == i16:
209
- return paddle.int16
210
- if dt == u1:
211
- return paddle.bool
212
- if dt == u8:
213
- return paddle.uint8
214
- if dt == f16:
215
- return paddle.float16
216
- if dt in (u16, u32, u64):
217
- raise RuntimeError(f"Paddle doesn't support {dt.to_string()} data type.")
218
- assert False
219
-
220
-
221
168
  def to_gstaichi_type(dt):
222
- """Convert numpy or torch or paddle data type to its counterpart in gstaichi.
169
+ """Convert numpy or torch data type to its counterpart in gstaichi.
223
170
 
224
171
  Args:
225
172
  dt (DataType): The desired data type to convert.
@@ -289,30 +236,6 @@ def to_gstaichi_type(dt):
289
236
 
290
237
  raise RuntimeError(f"PyTorch doesn't support {dt.to_string()} data type before version 2.3.0.")
291
238
 
292
- if has_paddle():
293
- import paddle # pylint: disable=C0415
294
-
295
- if dt == paddle.float32:
296
- return f32
297
- if dt == paddle.float64:
298
- return f64
299
- if dt == paddle.int32:
300
- return i32
301
- if dt == paddle.int64:
302
- return i64
303
- if dt == paddle.int8:
304
- return i8
305
- if dt == paddle.int16:
306
- return i16
307
- if dt == paddle.bool:
308
- return u1
309
- if dt == paddle.uint8:
310
- return u8
311
- if dt == paddle.float16:
312
- return f16
313
- if dt in (u16, u32, u64):
314
- raise RuntimeError(f"Paddle doesn't support {dt.to_string()} data type.")
315
-
316
239
  raise AssertionError(f"Unknown type {dt}")
317
240
 
318
241
 
@@ -378,4 +301,12 @@ def get_traceback(stacklevel=1):
378
301
  return "".join(traceback.format_list(s))
379
302
 
380
303
 
304
+ def is_data_oriented(obj: Any) -> bool:
305
+ return getattr(obj, "_data_oriented", False)
306
+
307
+
308
+ def is_ti_template(annotation: Any) -> bool:
309
+ return annotation == Template or isinstance(annotation, Template)
310
+
311
+
381
312
  __all__ = []
gstaichi/types/annotations.py CHANGED
@@ -1,7 +1,9 @@
1
- # type: ignore
1
+ from typing import Any, Generic, TypeVar
2
2
 
3
+ T = TypeVar("T")
3
4
 
4
- class Template:
5
+
6
+ class Template(Generic[T]):
5
7
  """Type annotation for template kernel parameter.
6
8
  Useful for passing parameters to kernels by reference.
7
9
 
@@ -30,9 +32,12 @@ class Template:
30
32
  >>> test_template(a) # will print 2
31
33
  """
32
34
 
33
- def __init__(self, tensor=None, dim=None):
34
- self.tensor = tensor
35
- self.dim = dim
35
+ def __init__(self, element_type: type[T] = object, ndim: int | None = None):
36
+ self.element_type = element_type
37
+ self.ndim = ndim
38
+
39
+ def __getitem__(self, i: Any) -> T:
40
+ raise NotImplemented
36
41
 
37
42
 
38
43
  template = Template
gstaichi/types/compound_types.py CHANGED
@@ -68,23 +68,4 @@ def struct(**kwargs):
68
68
  return gstaichi.lang.struct.StructType(**kwargs)
69
69
 
70
70
 
71
- def argpack(**kwargs):
72
- """Creates an argument pack type with given members.
73
-
74
- Args:
75
- kwargs (dict): a dictionary contains the names and types of the
76
- argument pack members.
77
-
78
- Returns:
79
- A argument pack type.
80
-
81
- Example::
82
-
83
- >>> vec3 = ti.types.vector(3, ti.f32)
84
- >>> sphere = ti.types.argpack(center=vec3, radius=float)
85
- >>> s = sphere(center=vec3([0., 0., 0.]), radius=1.0)
86
- """
87
- return gstaichi.lang.argpack.ArgPackType(**kwargs)
88
-
89
-
90
- __all__ = ["matrix", "vector", "struct", "argpack"]
71
+ __all__ = ["matrix", "vector", "struct"]
gstaichi/types/ndarray_type.py CHANGED
@@ -1,4 +1,4 @@
1
- # type: ignore
1
+ from typing import Any
2
2
 
3
3
  from gstaichi.types.compound_types import CompoundType, matrix, vector
4
4
  from gstaichi.types.enums import Layout, to_boundary_enum
@@ -32,7 +32,8 @@ def _make_matrix_dtype_from_element_shape(element_dim, element_shape, primitive_
32
32
  # Check dim consistency. The matrix dtype will be cooked later.
33
33
  if element_shape is not None and len(element_shape) != element_dim:
34
34
  raise ValueError(
35
- f"Both element_shape and element_dim are specified, but shape doesn't match specified dim: {len(element_shape)}!={element_dim}"
35
+ f"Both element_shape and element_dim are specified, but shape doesn't match specified dim: "
36
+ f"{len(element_shape)}!={element_dim}"
36
37
  )
37
38
  mat_dtype = vector(None, primitive_dtype) if element_dim == 1 else matrix(None, None, primitive_dtype)
38
39
  elif element_shape is not None:
@@ -50,14 +51,24 @@ class NdarrayType:
50
51
  """Type annotation for arbitrary arrays, including external arrays (numpy ndarrays and torch tensors) and GsTaichi ndarrays.
51
52
 
52
53
  For external arrays, we treat it as a GsTaichi data container with Scalar, Vector or Matrix elements.
53
- For GsTaichi vector/matrix ndarrays, we will automatically identify element dimension and their corresponding axis by the dimension of datatype, say scalars, matrices or vectors.
54
- For example, given type annotation `ti.types.ndarray(dtype=ti.math.vec3)`, a numpy array `np.zeros(10, 10, 3)` will be recognized as a 10x10 matrix composed of vec3 elements.
54
+ For GsTaichi vector/matrix ndarrays, we will automatically identify element dimension and their corresponding axis by the
55
+ dimension of datatype, say scalars, matrices or vectors.
56
+ For example, given type annotation `ti.types.ndarray(dtype=ti.math.vec3)`, a numpy array `np.zeros(10, 10, 3)` will be
57
+ recognized as a 10x10 matrix composed of vec3 elements.
55
58
 
56
59
  Args:
57
60
  dtype (Union[PrimitiveType, VectorType, MatrixType, NoneType], optional): None if not speicified.
58
- ndim (Union[Int, NoneType]): None if not specified, number of field dimensions. This argument is ignored for external arrays for now.
59
- element_dim (Union[Int, NoneType], optional): None if not specified (will be treated as 0 for external arrays), 0 if scalar elements, 1 if vector elements, and 2 if matrix elements.
60
- element_shape (Union[Tuple[Int], NoneType]): None if not specified, shapes of each element. For example, element_shape must be 1d for vector and 2d tuple for matrix. This argument is ignored for external arrays for now.
61
+ ndim (Union[Int, NoneType]): None if not specified, number of field dimensions. This argument is ignored for externa
62
+ arrays for now.
63
+ element_dim (Union[Int, NoneType], optional):
64
+ None if not specified (will be treated as 0 for external arrays),
65
+ 0 if scalar elements,
66
+ 1 if vector elements, and
67
+ 2 if matrix elements.
68
+ element_shape (Union[Tuple[Int], NoneType]):
69
+ None if not specified, shapes of each element.
70
+ For example, element_shape must be 1d for vector and 2d tuple for matrix.
71
+ This argument is ignored for external arrays for now.
61
72
  """
62
73
 
63
74
  def __init__(
@@ -93,9 +104,10 @@ class NdarrayType:
93
104
 
94
105
  # Check dtype match
95
106
  if isinstance(self.dtype, CompoundType):
96
- if not self.dtype.check_matched(ndarray_type.element_type):
107
+ if not self.dtype.check_matched(ndarray_type.element_type): # type: ignore
97
108
  raise ValueError(
98
- f"Invalid value for argument {arg_name} - required element type: {self.dtype.to_string()}, but {ndarray_type.element_type.to_string()} is provided"
109
+ f"Invalid value for argument {arg_name} - required element type: {self.dtype.to_string()}, " # type: ignore
110
+ f"but {ndarray_type.element_type.to_string()} is provided"
99
111
  )
100
112
  else:
101
113
  if self.dtype is not None:
@@ -110,14 +122,16 @@ class NdarrayType:
110
122
  # Check ndim match
111
123
  if self.ndim is not None and ndarray_type.shape is not None and self.ndim != len(ndarray_type.shape):
112
124
  raise ValueError(
113
- f"Invalid value for argument {arg_name} - required ndim={self.ndim}, but {len(ndarray_type.shape)}d ndarray with shape {ndarray_type.shape} is provided"
125
+ f"Invalid value for argument {arg_name} - required ndim={self.ndim}, but {len(ndarray_type.shape)}d "
126
+ f"ndarray with shape {ndarray_type.shape} is provided"
114
127
  )
115
128
 
116
129
  # Check needs_grad
117
130
  if self.needs_grad is not None and self.needs_grad > ndarray_type.needs_grad:
118
131
  # It's okay to pass a needs_grad=True ndarray at runtime to a need_grad=False arg but not vice versa.
119
132
  raise ValueError(
120
- f"Invalid value for argument {arg_name} - required needs_grad={self.needs_grad}, but {ndarray_type.needs_grad} is provided"
133
+ f"Invalid value for argument {arg_name} - required needs_grad={self.needs_grad}, but "
134
+ f"{ndarray_type.needs_grad} is provided"
121
135
  )
122
136
 
123
137
  def __repr__(self):
@@ -126,6 +140,14 @@ class NdarrayType:
126
140
  def __str__(self):
127
141
  return self.__repr__()
128
142
 
143
+ def __getitem__(self, i: Any) -> Any:
144
+ # needed for pyright
145
+ raise NotImplemented
146
+
147
+ def __setitem__(self, i: Any, v: Any) -> None:
148
+ # needed for pyright
149
+ raise NotImplemented
150
+
129
151
 
130
152
  ndarray = NdarrayType
131
153
  NDArray = NdarrayType
gstaichi/types/utils.py CHANGED
@@ -1,5 +1,3 @@
1
- # type: ignore
2
-
3
1
  from gstaichi._lib import core as ti_python_core
4
2
 
5
3
  is_signed = ti_python_core.is_signed