ninetoothed 0.9.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/jit.py +5 -2
- ninetoothed/tensor.py +22 -15
- {ninetoothed-0.9.0.dist-info → ninetoothed-0.11.0.dist-info}/METADATA +1 -1
- ninetoothed-0.11.0.dist-info/RECORD +11 -0
- ninetoothed-0.9.0.dist-info/RECORD +0 -11
- {ninetoothed-0.9.0.dist-info → ninetoothed-0.11.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.9.0.dist-info → ninetoothed-0.11.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/jit.py
CHANGED
@@ -479,7 +479,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
479
479
|
def _complete_indices(self, tensor, indices):
|
480
480
|
indices = list(self._generate_pid_indices(tensor) + tuple(indices))
|
481
481
|
|
482
|
-
for size in tensor.
|
482
|
+
for size in tensor.innermost().shape:
|
483
483
|
if Symbol.is_name(size):
|
484
484
|
name = size.node.id
|
485
485
|
if not naming.is_meta(name):
|
@@ -514,7 +514,10 @@ class CodeGenerator(ast.NodeTransformer):
|
|
514
514
|
|
515
515
|
@staticmethod
|
516
516
|
def _generate_slices(tensor, dim):
|
517
|
-
return tuple(
|
517
|
+
return tuple(
|
518
|
+
slice(None) if target_dim == dim else None
|
519
|
+
for target_dim in tensor.innermost().target_dims
|
520
|
+
)
|
518
521
|
|
519
522
|
@staticmethod
|
520
523
|
def _generate_offsets(tensor, indices):
|
ninetoothed/tensor.py
CHANGED
@@ -112,14 +112,10 @@ class Tensor:
|
|
112
112
|
strides=inner_strides,
|
113
113
|
source=self.source,
|
114
114
|
source_dims=self.source_dims,
|
115
|
-
target=self.target,
|
116
|
-
target_dims=self.target_dims,
|
117
115
|
),
|
118
116
|
strides=outer_strides,
|
119
117
|
source=self.source,
|
120
118
|
source_dims=self.source_dims,
|
121
|
-
target=self.target,
|
122
|
-
target_dims=self.target_dims,
|
123
119
|
)
|
124
120
|
|
125
121
|
def expand(self, shape):
|
@@ -136,23 +132,28 @@ class Tensor:
|
|
136
132
|
],
|
137
133
|
source=self.source,
|
138
134
|
source_dims=self.source_dims,
|
139
|
-
target=self.target,
|
140
135
|
target_dims=self.target_dims,
|
141
136
|
)
|
142
137
|
|
143
138
|
def squeeze(self, dim):
|
139
|
+
if not isinstance(dim, tuple):
|
140
|
+
dim = (dim,)
|
141
|
+
|
144
142
|
# TODO: Add error handling.
|
145
143
|
return type(self)(
|
146
|
-
shape=[size for i, size in enumerate(self.shape) if
|
144
|
+
shape=[size for i, size in enumerate(self.shape) if i not in dim],
|
147
145
|
dtype=self.dtype,
|
148
|
-
strides=[stride for i, stride in enumerate(self.strides) if
|
146
|
+
strides=[stride for i, stride in enumerate(self.strides) if i not in dim],
|
149
147
|
source=self.source,
|
150
148
|
source_dims=[
|
151
|
-
source_dim
|
149
|
+
source_dim
|
150
|
+
for i, source_dim in enumerate(self.source_dims)
|
151
|
+
if i not in dim
|
152
152
|
],
|
153
|
-
target=self.target,
|
154
153
|
target_dims=[
|
155
|
-
target_dim
|
154
|
+
target_dim
|
155
|
+
for i, target_dim in enumerate(self.target_dims)
|
156
|
+
if i not in dim
|
156
157
|
],
|
157
158
|
)
|
158
159
|
|
@@ -173,7 +174,6 @@ class Tensor:
|
|
173
174
|
strides=new_strides,
|
174
175
|
source=self.source,
|
175
176
|
source_dims=new_source_dims,
|
176
|
-
target=self.target,
|
177
177
|
target_dims=self.target_dims,
|
178
178
|
)
|
179
179
|
|
@@ -204,14 +204,21 @@ class Tensor:
|
|
204
204
|
leading_source_dims + (flattening_source_dims,) + trailing_source_dims
|
205
205
|
)
|
206
206
|
|
207
|
+
leading_target_dims = self.target_dims[:start_dim]
|
208
|
+
flattening_target_dims = self.target_dims[start_dim:end_dim]
|
209
|
+
trailing_target_dims = self.target_dims[end_dim:]
|
210
|
+
|
211
|
+
new_target_dims = (
|
212
|
+
leading_target_dims + (flattening_target_dims[-1],) + trailing_target_dims
|
213
|
+
)
|
214
|
+
|
207
215
|
return type(self)(
|
208
216
|
shape=new_shape,
|
209
217
|
dtype=self.dtype,
|
210
218
|
strides=new_strides,
|
211
219
|
source=self.source,
|
212
220
|
source_dims=new_source_dims,
|
213
|
-
|
214
|
-
target_dims=self.target_dims,
|
221
|
+
target_dims=new_target_dims,
|
215
222
|
)
|
216
223
|
|
217
224
|
def ravel(self):
|
@@ -250,11 +257,11 @@ class Tensor:
|
|
250
257
|
| (self.source.names() if self.source is not self else set())
|
251
258
|
)
|
252
259
|
|
253
|
-
def
|
260
|
+
def innermost(self):
|
254
261
|
if not isinstance(self.dtype, type(self)):
|
255
262
|
return self
|
256
263
|
|
257
|
-
return self.dtype.
|
264
|
+
return self.dtype.innermost()
|
258
265
|
|
259
266
|
def pointer_string(self):
|
260
267
|
return f"{self.name}_pointer"
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.9.0
|
3
|
+
Version: 0.11.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -0,0 +1,11 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
|
2
|
+
ninetoothed/jit.py,sha256=0LeDBpSYFgPx4hatP_ZsvElsj0d9d552OKRc__L1Jvc,23460
|
3
|
+
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
+
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
5
|
+
ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
|
6
|
+
ninetoothed/tensor.py,sha256=ehgbHrxL1gc9iOI9rGUp5k-nI_daLmCfOdLI9hE-GLw,9756
|
7
|
+
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
8
|
+
ninetoothed-0.11.0.dist-info/METADATA,sha256=NlC4oj1R7gNoCdA1AxYc_7UWY3uZPv3LeOq9Jo-9Q8w,7055
|
9
|
+
ninetoothed-0.11.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
10
|
+
ninetoothed-0.11.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
11
|
+
ninetoothed-0.11.0.dist-info/RECORD,,
|
@@ -1,11 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
|
2
|
-
ninetoothed/jit.py,sha256=B3q32ksKTRr7I4jLcoDXjwEx7A_Awz9DXGEmIkrtoBc,23393
|
3
|
-
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
-
ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
|
5
|
-
ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
|
6
|
-
ninetoothed/tensor.py,sha256=OU6lVjzKU614mk3EN1AAgTDradbvJkyl22AXwdhxcfs,9577
|
7
|
-
ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
|
8
|
-
ninetoothed-0.9.0.dist-info/METADATA,sha256=LKtSbgc_mWKJ4L_8BYiRKWjoV-JKjr4hefhrkdPmrHs,7054
|
9
|
-
ninetoothed-0.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
10
|
-
ninetoothed-0.9.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
11
|
-
ninetoothed-0.9.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|