ninetoothed 0.9.0__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -479,7 +479,7 @@ class CodeGenerator(ast.NodeTransformer):
479
479
  def _complete_indices(self, tensor, indices):
480
480
  indices = list(self._generate_pid_indices(tensor) + tuple(indices))
481
481
 
482
- for size in tensor.inmost().shape:
482
+ for size in tensor.innermost().shape:
483
483
  if Symbol.is_name(size):
484
484
  name = size.node.id
485
485
  if not naming.is_meta(name):
@@ -514,7 +514,10 @@ class CodeGenerator(ast.NodeTransformer):
514
514
 
515
515
  @staticmethod
516
516
  def _generate_slices(tensor, dim):
517
- return tuple(slice(None) if i == dim else None for i in range(tensor.ndim))
517
+ return tuple(
518
+ slice(None) if target_dim == dim else None
519
+ for target_dim in tensor.innermost().target_dims
520
+ )
518
521
 
519
522
  @staticmethod
520
523
  def _generate_offsets(tensor, indices):
ninetoothed/tensor.py CHANGED
@@ -112,14 +112,10 @@ class Tensor:
112
112
  strides=inner_strides,
113
113
  source=self.source,
114
114
  source_dims=self.source_dims,
115
- target=self.target,
116
- target_dims=self.target_dims,
117
115
  ),
118
116
  strides=outer_strides,
119
117
  source=self.source,
120
118
  source_dims=self.source_dims,
121
- target=self.target,
122
- target_dims=self.target_dims,
123
119
  )
124
120
 
125
121
  def expand(self, shape):
@@ -136,23 +132,22 @@ class Tensor:
136
132
  ],
137
133
  source=self.source,
138
134
  source_dims=self.source_dims,
139
- target=self.target,
140
- target_dims=self.target_dims,
141
135
  )
142
136
 
143
137
  def squeeze(self, dim):
138
+ if not isinstance(dim, tuple):
139
+ dim = (dim,)
140
+
144
141
  # TODO: Add error handling.
145
142
  return type(self)(
146
- shape=[size for i, size in enumerate(self.shape) if dim != i],
143
+ shape=[size for i, size in enumerate(self.shape) if i not in dim],
147
144
  dtype=self.dtype,
148
- strides=[stride for i, stride in enumerate(self.strides) if dim != i],
145
+ strides=[stride for i, stride in enumerate(self.strides) if i not in dim],
149
146
  source=self.source,
150
147
  source_dims=[
151
- source_dim for i, source_dim in enumerate(self.source_dims) if dim != i
152
- ],
153
- target=self.target,
154
- target_dims=[
155
- target_dim for i, target_dim in enumerate(self.target_dims) if dim != i
148
+ source_dim
149
+ for i, source_dim in enumerate(self.source_dims)
150
+ if i not in dim
156
151
  ],
157
152
  )
158
153
 
@@ -173,8 +168,6 @@ class Tensor:
173
168
  strides=new_strides,
174
169
  source=self.source,
175
170
  source_dims=new_source_dims,
176
- target=self.target,
177
- target_dims=self.target_dims,
178
171
  )
179
172
 
180
173
  def flatten(self, start_dim=None, end_dim=None):
@@ -210,8 +203,6 @@ class Tensor:
210
203
  strides=new_strides,
211
204
  source=self.source,
212
205
  source_dims=new_source_dims,
213
- target=self.target,
214
- target_dims=self.target_dims,
215
206
  )
216
207
 
217
208
  def ravel(self):
@@ -250,11 +241,11 @@ class Tensor:
250
241
  | (self.source.names() if self.source is not self else set())
251
242
  )
252
243
 
253
- def inmost(self):
244
+ def innermost(self):
254
245
  if not isinstance(self.dtype, type(self)):
255
246
  return self
256
247
 
257
- return self.dtype.inmost()
248
+ return self.dtype.innermost()
258
249
 
259
250
  def pointer_string(self):
260
251
  return f"{self.name}_pointer"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.9.0
3
+ Version: 0.10.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -0,0 +1,11 @@
1
+ ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
+ ninetoothed/jit.py,sha256=0LeDBpSYFgPx4hatP_ZsvElsj0d9d552OKRc__L1Jvc,23460
3
+ ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
+ ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
5
+ ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
6
+ ninetoothed/tensor.py,sha256=LGS9wYmPKckZjIEXsMHdclTmVhzkfrc3avOPiCQY1tU,9153
7
+ ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
8
+ ninetoothed-0.10.0.dist-info/METADATA,sha256=nQWkQ--AceNN3DoD-kP9_aykZYVw8LOfqf2iO63v1Ek,7055
9
+ ninetoothed-0.10.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
+ ninetoothed-0.10.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
+ ninetoothed-0.10.0.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
- ninetoothed/jit.py,sha256=B3q32ksKTRr7I4jLcoDXjwEx7A_Awz9DXGEmIkrtoBc,23393
3
- ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
- ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
5
- ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
6
- ninetoothed/tensor.py,sha256=OU6lVjzKU614mk3EN1AAgTDradbvJkyl22AXwdhxcfs,9577
7
- ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
8
- ninetoothed-0.9.0.dist-info/METADATA,sha256=LKtSbgc_mWKJ4L_8BYiRKWjoV-JKjr4hefhrkdPmrHs,7054
9
- ninetoothed-0.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
- ninetoothed-0.9.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
- ninetoothed-0.9.0.dist-info/RECORD,,