ninetoothed 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -479,7 +479,7 @@ class CodeGenerator(ast.NodeTransformer):
479
479
  def _complete_indices(self, tensor, indices):
480
480
  indices = list(self._generate_pid_indices(tensor) + tuple(indices))
481
481
 
482
- for size in tensor.inmost().shape:
482
+ for size in tensor.innermost().shape:
483
483
  if Symbol.is_name(size):
484
484
  name = size.node.id
485
485
  if not naming.is_meta(name):
@@ -514,7 +514,10 @@ class CodeGenerator(ast.NodeTransformer):
514
514
 
515
515
  @staticmethod
516
516
  def _generate_slices(tensor, dim):
517
- return tuple(slice(None) if i == dim else None for i in range(tensor.ndim))
517
+ return tuple(
518
+ slice(None) if target_dim == dim else None
519
+ for target_dim in tensor.innermost().target_dims
520
+ )
518
521
 
519
522
  @staticmethod
520
523
  def _generate_offsets(tensor, indices):
ninetoothed/tensor.py CHANGED
@@ -112,14 +112,10 @@ class Tensor:
112
112
  strides=inner_strides,
113
113
  source=self.source,
114
114
  source_dims=self.source_dims,
115
- target=self.target,
116
- target_dims=self.target_dims,
117
115
  ),
118
116
  strides=outer_strides,
119
117
  source=self.source,
120
118
  source_dims=self.source_dims,
121
- target=self.target,
122
- target_dims=self.target_dims,
123
119
  )
124
120
 
125
121
  def expand(self, shape):
@@ -136,23 +132,28 @@ class Tensor:
136
132
  ],
137
133
  source=self.source,
138
134
  source_dims=self.source_dims,
139
- target=self.target,
140
135
  target_dims=self.target_dims,
141
136
  )
142
137
 
143
138
  def squeeze(self, dim):
139
+ if not isinstance(dim, tuple):
140
+ dim = (dim,)
141
+
144
142
  # TODO: Add error handling.
145
143
  return type(self)(
146
- shape=[size for i, size in enumerate(self.shape) if dim != i],
144
+ shape=[size for i, size in enumerate(self.shape) if i not in dim],
147
145
  dtype=self.dtype,
148
- strides=[stride for i, stride in enumerate(self.strides) if dim != i],
146
+ strides=[stride for i, stride in enumerate(self.strides) if i not in dim],
149
147
  source=self.source,
150
148
  source_dims=[
151
- source_dim for i, source_dim in enumerate(self.source_dims) if dim != i
149
+ source_dim
150
+ for i, source_dim in enumerate(self.source_dims)
151
+ if i not in dim
152
152
  ],
153
- target=self.target,
154
153
  target_dims=[
155
- target_dim for i, target_dim in enumerate(self.target_dims) if dim != i
154
+ target_dim
155
+ for i, target_dim in enumerate(self.target_dims)
156
+ if i not in dim
156
157
  ],
157
158
  )
158
159
 
@@ -173,7 +174,6 @@ class Tensor:
173
174
  strides=new_strides,
174
175
  source=self.source,
175
176
  source_dims=new_source_dims,
176
- target=self.target,
177
177
  target_dims=self.target_dims,
178
178
  )
179
179
 
@@ -204,14 +204,21 @@ class Tensor:
204
204
  leading_source_dims + (flattening_source_dims,) + trailing_source_dims
205
205
  )
206
206
 
207
+ leading_target_dims = self.target_dims[:start_dim]
208
+ flattening_target_dims = self.target_dims[start_dim:end_dim]
209
+ trailing_target_dims = self.target_dims[end_dim:]
210
+
211
+ new_target_dims = (
212
+ leading_target_dims + (flattening_target_dims[-1],) + trailing_target_dims
213
+ )
214
+
207
215
  return type(self)(
208
216
  shape=new_shape,
209
217
  dtype=self.dtype,
210
218
  strides=new_strides,
211
219
  source=self.source,
212
220
  source_dims=new_source_dims,
213
- target=self.target,
214
- target_dims=self.target_dims,
221
+ target_dims=new_target_dims,
215
222
  )
216
223
 
217
224
  def ravel(self):
@@ -250,11 +257,11 @@ class Tensor:
250
257
  | (self.source.names() if self.source is not self else set())
251
258
  )
252
259
 
253
- def inmost(self):
260
+ def innermost(self):
254
261
  if not isinstance(self.dtype, type(self)):
255
262
  return self
256
263
 
257
- return self.dtype.inmost()
264
+ return self.dtype.innermost()
258
265
 
259
266
  def pointer_string(self):
260
267
  return f"{self.name}_pointer"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.9.0
3
+ Version: 0.11.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
@@ -0,0 +1,11 @@
1
+ ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
+ ninetoothed/jit.py,sha256=0LeDBpSYFgPx4hatP_ZsvElsj0d9d552OKRc__L1Jvc,23460
3
+ ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
+ ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
5
+ ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
6
+ ninetoothed/tensor.py,sha256=ehgbHrxL1gc9iOI9rGUp5k-nI_daLmCfOdLI9hE-GLw,9756
7
+ ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
8
+ ninetoothed-0.11.0.dist-info/METADATA,sha256=NlC4oj1R7gNoCdA1AxYc_7UWY3uZPv3LeOq9Jo-9Q8w,7055
9
+ ninetoothed-0.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
+ ninetoothed-0.11.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
+ ninetoothed-0.11.0.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
- ninetoothed/jit.py,sha256=B3q32ksKTRr7I4jLcoDXjwEx7A_Awz9DXGEmIkrtoBc,23393
3
- ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
- ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
5
- ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
6
- ninetoothed/tensor.py,sha256=OU6lVjzKU614mk3EN1AAgTDradbvJkyl22AXwdhxcfs,9577
7
- ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
8
- ninetoothed-0.9.0.dist-info/METADATA,sha256=LKtSbgc_mWKJ4L_8BYiRKWjoV-JKjr4hefhrkdPmrHs,7054
9
- ninetoothed-0.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
- ninetoothed-0.9.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
- ninetoothed-0.9.0.dist-info/RECORD,,