ninetoothed 0.7.0__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
- from ninetoothed.jit import jit
1
+ from ninetoothed.jit import jit, make
2
2
  from ninetoothed.symbol import Symbol
3
3
  from ninetoothed.tensor import Tensor
4
4
 
5
- __all__ = ["Symbol", "Tensor", "jit"]
5
+ __all__ = ["Symbol", "Tensor", "jit", "make"]
ninetoothed/jit.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import ast
2
2
  import collections
3
+ import copy
3
4
  import functools
4
5
  import importlib.util
5
6
  import inspect
@@ -18,6 +19,15 @@ from ninetoothed.tensor import Tensor
18
19
  from ninetoothed.torchifier import Torchifier
19
20
 
20
21
 
22
+ def make(arrangement, application, tensors):
23
+ params = inspect.signature(application).parameters
24
+ types = arrangement(*tensors)
25
+ annotations = {param: type for param, type in zip(params, types)}
26
+ application.__annotations__ = annotations
27
+
28
+ return jit(application)
29
+
30
+
21
31
  def jit(_func=None, *, _prettify=False):
22
32
  def wrapper(func):
23
33
  return JIT(func, _prettify=_prettify)()
@@ -141,30 +151,12 @@ class CodeGenerator(ast.NodeTransformer):
141
151
  def visit_FunctionDef(self, node):
142
152
  self._func_def = node
143
153
 
144
- self.generic_visit(node)
145
-
146
- for arg in self._args:
147
- if not isinstance(arg, Tensor) or arg.ndim == 0:
148
- continue
154
+ self._invariants = {}
149
155
 
150
- offsets = arg.offsets()
151
-
152
- initializations = {
153
- type(self)._name_for_offsets(arg, dim): offs
154
- for dim, offs in enumerate(offsets)
155
- } | {
156
- type(self)._name_for_pointers(arg): arg.original.pointer_string()
157
- + sum(
158
- type(self)._name_for_offsets(arg, dim)[
159
- type(self)._generate_slices(arg, dim)
160
- ]
161
- * stride
162
- for dim, stride in enumerate(arg.original.strides)
163
- )
164
- }
156
+ self.generic_visit(node)
165
157
 
166
- for target, value in reversed(initializations.items()):
167
- node.body.insert(0, ast.Assign(targets=[target.node], value=value.node))
158
+ for target, value in reversed(self._invariants.items()):
159
+ node.body.insert(0, ast.Assign(targets=[target.node], value=value.node))
168
160
 
169
161
  return node
170
162
 
@@ -192,24 +184,20 @@ class CodeGenerator(ast.NodeTransformer):
192
184
  ]
193
185
 
194
186
  autotune = self._generate_autotune(non_meta_names, meta_names)
195
- self._func_def.decorator_list.insert(0, autotune)
187
+ self._func_def.decorator_list = [autotune, Symbol("triton.jit").node]
196
188
 
197
189
  self._launch = self._generate_launch(non_meta_names, meta_names)
198
190
 
199
191
  return node
200
192
 
201
193
  def visit_Subscript(self, node):
202
- if (
203
- isinstance(node.value, ast.Name)
204
- and node.value.id in self._context
205
- and isinstance(node.ctx, ast.Load)
206
- ):
194
+ if self._in_context(node.value) and isinstance(node.ctx, ast.Load):
207
195
  value = self._context[node.value.id]
208
196
 
209
197
  if isinstance(value, Tensor):
210
- return type(self)._generate_load(
198
+ return self._generate_load(
211
199
  value,
212
- intermediate_indices=node.slice.elts
200
+ indices=node.slice.elts
213
201
  if isinstance(node.slice, ast.Tuple)
214
202
  else (node.slice,),
215
203
  )
@@ -219,7 +207,7 @@ class CodeGenerator(ast.NodeTransformer):
219
207
  return node
220
208
 
221
209
  def visit_Attribute(self, node):
222
- if isinstance(node.value, ast.Name) and node.value.id in self._context:
210
+ if self._in_context(node.value):
223
211
  value = self._context[node.value.id]
224
212
 
225
213
  if isinstance(value, Tensor):
@@ -234,8 +222,8 @@ class CodeGenerator(ast.NodeTransformer):
234
222
  def visit_Name(self, node):
235
223
  self.generic_visit(node)
236
224
 
237
- if node.id in self._context and isinstance(node.ctx, ast.Load):
238
- return type(self)._generate_load(self._context[node.id])
225
+ if self._in_context(node) and isinstance(node.ctx, ast.Load):
226
+ return self._generate_load(self._context[node.id])
239
227
 
240
228
  return node
241
229
 
@@ -243,16 +231,15 @@ class CodeGenerator(ast.NodeTransformer):
243
231
  if len(node.targets) == 1:
244
232
  target = node.targets[0]
245
233
 
246
- if isinstance(target, ast.Name) and target.id in self._context:
234
+ if self._in_context(target):
247
235
  self.generic_visit(node)
248
236
 
249
237
  return ast.Expr(
250
- type(self)._generate_store(self._context[target.id], node.value)
238
+ self._generate_store(self._context[target.id], node.value)
251
239
  )
252
240
  elif (
253
241
  isinstance(target, ast.Subscript)
254
- and isinstance(target.value, ast.Name)
255
- and target.value.id in self._context
242
+ and self._in_context(target.value)
256
243
  and isinstance(target.ctx, ast.Store)
257
244
  ):
258
245
  value = self._context[target.value.id]
@@ -261,10 +248,10 @@ class CodeGenerator(ast.NodeTransformer):
261
248
  self.generic_visit(node)
262
249
 
263
250
  return ast.Expr(
264
- type(self)._generate_store(
251
+ self._generate_store(
265
252
  value,
266
253
  node.value,
267
- intermediate_indices=target.slice.elts
254
+ indices=target.slice.elts
268
255
  if isinstance(target.slice, ast.Tuple)
269
256
  else (target.slice,),
270
257
  )
@@ -274,6 +261,11 @@ class CodeGenerator(ast.NodeTransformer):
274
261
 
275
262
  return node
276
263
 
264
+ _NAME_FOR_PID = Symbol("ninetoothed_pid")
265
+
266
+ def _in_context(self, node):
267
+ return isinstance(node, ast.Name) and node.id in self._context
268
+
277
269
  def _generate_autotune(self, params, meta):
278
270
  device = triton.runtime.driver.active.get_current_device()
279
271
  properties = triton.runtime.driver.active.utils.get_device_properties(device)
@@ -354,7 +346,7 @@ class CodeGenerator(ast.NodeTransformer):
354
346
  name=f"launch_{self._func_def.name}",
355
347
  args=ast.arguments(
356
348
  posonlyargs=[],
357
- args=[ast.arg(arg=arg.original.name) for arg in self._args]
349
+ args=[ast.arg(arg=arg.source.name) for arg in self._args]
358
350
  + [
359
351
  ast.arg(arg=param)
360
352
  for param in non_next_power_of_2_constexpr_params_without_prefixes
@@ -427,51 +419,105 @@ class CodeGenerator(ast.NodeTransformer):
427
419
 
428
420
  return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
429
421
 
430
- @staticmethod
431
- def _generate_load(tensor, intermediate_indices=()):
422
+ def _generate_load(self, tensor, indices=()):
432
423
  if tensor.ndim == 0:
433
- return Symbol(tensor.original.name).node
424
+ return Symbol(tensor.source.name).node
434
425
 
435
- pointers, mask = CodeGenerator._generate_pointers_and_mask(
436
- tensor, intermediate_indices
437
- )
438
- other = CodeGenerator._generate_other(tensor)
426
+ pointers, mask = self._generate_pointers_and_mask(tensor, indices)
427
+ other = type(self)._generate_other(tensor)
439
428
 
440
429
  return call("load", pointers, mask=mask, other=other).node
441
430
 
442
- @staticmethod
443
- def _generate_store(tensor, value, intermediate_indices=()):
444
- pointers, mask = CodeGenerator._generate_pointers_and_mask(
445
- tensor, intermediate_indices
446
- )
431
+ def _generate_store(self, tensor, value, indices=()):
432
+ pointers, mask = self._generate_pointers_and_mask(tensor, indices)
447
433
 
448
434
  return call("store", pointers, value, mask=mask).node
449
435
 
450
- @staticmethod
451
- def _generate_pointers_and_mask(tensor, intermediate_indices):
452
- intermediate_offsets = CodeGenerator._generate_intermediate_offsets(
453
- tensor, intermediate_indices
454
- )
455
- offsets = [
456
- CodeGenerator._name_for_offsets(tensor, dim) + intermediate_offsets[dim]
457
- for dim in range(tensor.original.ndim)
458
- ]
459
- pointers = CodeGenerator._name_for_pointers(tensor) + sum(
460
- map(lambda x, y: x * y, intermediate_offsets, tensor.original.strides)
436
+ def _generate_pointers_and_mask(self, tensor, indices):
437
+ invariant_target_dims = type(self)._find_invariant_target_dims(tensor)
438
+
439
+ indices = self._complete_indices(tensor, indices)
440
+ offsets = type(self)._generate_offsets(tensor, indices)
441
+
442
+ for source_dim in range(tensor.source.ndim):
443
+ for target_dim in range(tensor.target.ndim):
444
+ if target_dim not in invariant_target_dims:
445
+ continue
446
+
447
+ name = type(self)._name_for_offsets(tensor, source_dim, target_dim)
448
+ self._invariants[name] = offsets[source_dim][target_dim]
449
+ offsets[source_dim][target_dim] = name
450
+
451
+ name_for_pointers = type(self)._name_for_pointers(tensor)
452
+ self._invariants[name_for_pointers] = Symbol(tensor.source.pointer_string())
453
+
454
+ for source_dim in range(tensor.source.ndim):
455
+ for target_dim in range(tensor.target.ndim):
456
+ if target_dim not in invariant_target_dims:
457
+ continue
458
+
459
+ self._invariants[name_for_pointers] += (
460
+ offsets[source_dim][target_dim][
461
+ type(self)._generate_slices(tensor, target_dim)
462
+ ]
463
+ * tensor.source.strides[source_dim]
464
+ )
465
+
466
+ pointers = name_for_pointers + sum(
467
+ offsets[source_dim][target_dim][
468
+ type(self)._generate_slices(tensor, target_dim)
469
+ ]
470
+ * tensor.source.strides[source_dim]
471
+ for source_dim in range(tensor.source.ndim)
472
+ for target_dim in range(tensor.target.ndim)
473
+ if target_dim not in invariant_target_dims
474
+ and offsets[source_dim][target_dim] != 0
461
475
  )
462
476
  mask = functools.reduce(
463
477
  lambda x, y: x & y,
464
478
  (
465
- offs[CodeGenerator._generate_slices(tensor, dim)] < size
466
- for dim, (offs, size) in enumerate(zip(offsets, tensor.original.shape))
479
+ offsets[source_dim][target_dim][
480
+ type(self)._generate_slices(tensor, target_dim)
481
+ ]
482
+ < tensor.source.shape[source_dim]
483
+ for source_dim in range(tensor.source.ndim)
484
+ for target_dim in range(tensor.target.ndim)
485
+ if offsets[source_dim][target_dim] != 0
467
486
  ),
468
487
  )
469
488
 
470
489
  return pointers, mask
471
490
 
491
+ def _complete_indices(self, tensor, indices):
492
+ indices = list(self._generate_pid_indices(tensor) + indices)
493
+
494
+ for size in tensor.inmost().shape:
495
+ if Symbol.is_name(size):
496
+ name = size.node.id
497
+ if not naming.is_meta(name):
498
+ size = naming.make_next_power_of_2(name)
499
+
500
+ indices.append(call("arange", 0, size))
501
+
502
+ return tuple(indices)
503
+
504
+ def _generate_pid_indices(self, tensor):
505
+ self._invariants[type(self)._NAME_FOR_PID] = call("program_id", 0)
506
+
507
+ indices = list(
508
+ type(self)._unravel_index(type(self)._NAME_FOR_PID, tensor.shape)
509
+ )
510
+
511
+ for dim, index in enumerate(indices):
512
+ name = type(self)._name_for_index(tensor, dim)
513
+ self._invariants[name] = index
514
+ indices[dim] = name
515
+
516
+ return tuple(indices)
517
+
472
518
  @staticmethod
473
519
  def _generate_other(tensor):
474
- other = tensor.original.other
520
+ other = tensor.source.other
475
521
 
476
522
  if isinstance(other, float) and not math.isfinite(other):
477
523
  return f"float('{other}')"
@@ -483,23 +529,86 @@ class CodeGenerator(ast.NodeTransformer):
483
529
  return tuple(slice(None) if i == dim else None for i in range(tensor.ndim))
484
530
 
485
531
  @staticmethod
486
- def _generate_intermediate_offsets(tensor, intermediate_indices):
487
- return tuple(
488
- offs
489
- for offs in tensor.offsets(
490
- [0 for _ in range(tensor.ndim)]
491
- + list(intermediate_indices)
492
- + [0 for _ in range(tensor.inmost().ndim)]
493
- )
532
+ def _generate_offsets(tensor, indices):
533
+ offsets = collections.defaultdict(
534
+ lambda: collections.defaultdict(lambda: Symbol(0))
494
535
  )
495
536
 
537
+ curr = tensor
538
+ start = 0
539
+
540
+ while isinstance(curr, type(tensor)):
541
+ stop = start + curr.ndim
542
+ curr_indices = indices[start:stop]
543
+
544
+ for index, stride, source_dim, target_dim in zip(
545
+ curr_indices, curr.strides, curr.source_dims, curr.target_dims
546
+ ):
547
+ offsets[source_dim][target_dim] += index * stride
548
+
549
+ start = stop
550
+ curr = curr.dtype
551
+
552
+ for source_dim in tuple(offsets):
553
+ for target_dim in tuple(offsets[source_dim]):
554
+ if not isinstance(source_dim, tuple):
555
+ continue
556
+
557
+ unraveled = CodeGenerator._unravel_index(
558
+ offsets[source_dim][target_dim],
559
+ tuple(tensor.source.shape[dim] for dim in source_dim),
560
+ )
561
+
562
+ for offs, dim in zip(unraveled, source_dim):
563
+ offsets[dim][target_dim] = offs
564
+
565
+ for source_dim in range(tensor.source.ndim):
566
+ for target_dim in range(tensor.target.ndim):
567
+ offsets[source_dim][target_dim] = copy.deepcopy(
568
+ offsets[source_dim][target_dim]
569
+ )
570
+ offsets[source_dim][target_dim].find_and_replace(
571
+ Symbol(tensor.source.strides[source_dim]), Symbol(1)
572
+ )
573
+
574
+ return offsets
575
+
576
+ @staticmethod
577
+ def _find_invariant_target_dims(tensor):
578
+ invariant_target_dims = set()
579
+
580
+ curr = tensor.dtype
581
+
582
+ while isinstance(curr.dtype, Tensor):
583
+ for target_dim in range(curr.target.ndim):
584
+ if target_dim not in curr.target_dims:
585
+ invariant_target_dims.add(target_dim)
586
+
587
+ curr = curr.dtype
588
+
589
+ return invariant_target_dims
590
+
496
591
  @staticmethod
497
592
  def _name_for_pointers(tensor):
498
- return Symbol(f"{tensor.original.name}_pointers")
593
+ return Symbol(f"{tensor.source.name}_pointers")
594
+
595
+ @staticmethod
596
+ def _name_for_offsets(tensor, source_dim, target_dim):
597
+ return Symbol(f"{tensor.source.name}_offsets_{source_dim}_{target_dim}")
499
598
 
500
599
  @staticmethod
501
- def _name_for_offsets(tensor, dim):
502
- return Symbol(f"{tensor.original.name}_offsets_{dim}")
600
+ def _name_for_index(tensor, dim):
601
+ return Symbol(f"{tensor.source.name}_index_{dim}")
602
+
603
+ @staticmethod
604
+ def _unravel_index(index, shape):
605
+ indices = []
606
+
607
+ for stride in Tensor(shape=shape).strides:
608
+ indices.append(index // stride)
609
+ index %= stride
610
+
611
+ return tuple(indices)
503
612
 
504
613
 
505
614
  class Tritonizer(ast.NodeTransformer):
ninetoothed/tensor.py CHANGED
@@ -1,7 +1,7 @@
1
1
  import itertools
2
+ import math
2
3
  import re
3
4
 
4
- import ninetoothed.naming as naming
5
5
  from ninetoothed.language import call
6
6
  from ninetoothed.symbol import Symbol
7
7
 
@@ -17,7 +17,10 @@ class Tensor:
17
17
  strides=None,
18
18
  other=None,
19
19
  name=None,
20
- original=None,
20
+ source=None,
21
+ source_dims=None,
22
+ target=None,
23
+ target_dims=None,
21
24
  ):
22
25
  self.dtype = dtype
23
26
 
@@ -39,10 +42,25 @@ class Tensor:
39
42
 
40
43
  self.other = other
41
44
 
42
- if original is not None:
43
- self.original = original
45
+ if source is not None:
46
+ self.source = source
44
47
  else:
45
- self.original = self
48
+ self.source = self
49
+
50
+ if source_dims is not None:
51
+ self.source_dims = source_dims
52
+ else:
53
+ self.source_dims = (dim for dim in range(self.source.ndim))
54
+
55
+ if target is not None:
56
+ self.target = target
57
+ else:
58
+ self.target = self
59
+
60
+ if target_dims is not None:
61
+ self.target_dims = target_dims
62
+ else:
63
+ self.target_dims = (dim for dim in range(self.target.ndim))
46
64
 
47
65
  type(self).num_instances += 1
48
66
 
@@ -87,10 +105,16 @@ class Tensor:
87
105
  shape=inner_shape,
88
106
  dtype=self.dtype,
89
107
  strides=inner_strides,
90
- original=self.original,
108
+ source=self.source,
109
+ source_dims=self.source_dims,
110
+ target=self.target,
111
+ target_dims=self.target_dims,
91
112
  ),
92
113
  strides=outer_strides,
93
- original=self.original,
114
+ source=self.source,
115
+ source_dims=self.source_dims,
116
+ target=self.target,
117
+ target_dims=self.target_dims,
94
118
  )
95
119
 
96
120
  def expand(self, shape):
@@ -105,7 +129,10 @@ class Tensor:
105
129
  stride if new_size == -1 else 0
106
130
  for new_size, stride in zip(shape, self.strides)
107
131
  ],
108
- original=self.original,
132
+ source=self.source,
133
+ source_dims=self.source_dims,
134
+ target=self.target,
135
+ target_dims=self.target_dims,
109
136
  )
110
137
 
111
138
  def squeeze(self, dim):
@@ -114,81 +141,109 @@ class Tensor:
114
141
  shape=[size for i, size in enumerate(self.shape) if dim != i],
115
142
  dtype=self.dtype,
116
143
  strides=[stride for i, stride in enumerate(self.strides) if dim != i],
117
- original=self.original,
144
+ source=self.source,
145
+ source_dims=[
146
+ source_dim for i, source_dim in enumerate(self.source_dims) if dim != i
147
+ ],
148
+ target=self.target,
149
+ target_dims=[
150
+ target_dim for i, target_dim in enumerate(self.target_dims) if dim != i
151
+ ],
118
152
  )
119
153
 
120
- def names(self):
121
- if self.ndim == 0:
122
- return {self.original.name}
123
-
124
- return (
125
- {self.original.pointer_string()}
126
- | {
127
- name
128
- for value in itertools.chain(self.shape, self.strides)
129
- if isinstance(value, Symbol)
130
- for name in value.names()
131
- }
132
- | (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
133
- | (self.original.names() if self.original is not self else set())
134
- )
154
+ def permute(self, dims):
155
+ # TODO: Add error handling.
156
+ new_shape = [None for _ in range(self.ndim)]
157
+ new_strides = [None for _ in range(self.ndim)]
158
+ new_source_dims = [None for _ in range(self.ndim)]
135
159
 
136
- def offsets(self, indices=None):
137
- if indices is None:
138
- indices = self.indices()
160
+ for original_dim, permuted_dim in enumerate(dims):
161
+ new_shape[original_dim] = self.shape[permuted_dim]
162
+ new_strides[original_dim] = self.strides[permuted_dim]
163
+ new_source_dims[original_dim] = self.source_dims[permuted_dim]
139
164
 
140
- offsets = [[] for _ in range(self.original.ndim)]
165
+ return type(self)(
166
+ shape=new_shape,
167
+ dtype=self.dtype,
168
+ strides=new_strides,
169
+ source=self.source,
170
+ source_dims=new_source_dims,
171
+ target=self.target,
172
+ target_dims=self.target_dims,
173
+ )
141
174
 
142
- curr = self
143
- start = 0
175
+ def flatten(self, start_dim=None, end_dim=None):
176
+ # TODO: Add error handling.
177
+ if start_dim is None:
178
+ start_dim = 0
179
+ if end_dim is None:
180
+ end_dim = self.ndim
144
181
 
145
- while isinstance(curr, type(self)):
146
- stop = start + curr.ndim
147
- curr_indices = indices[start:stop]
182
+ leading_sizes = self.shape[:start_dim]
183
+ flattening_sizes = self.shape[start_dim:end_dim]
184
+ trailing_sizes = self.shape[end_dim:]
148
185
 
149
- for index, stride in zip(curr_indices, curr.strides):
150
- for dim in self._dims_of(stride):
151
- offsets[dim].append(index * stride)
186
+ new_shape = leading_sizes + (math.prod(flattening_sizes),) + trailing_sizes
152
187
 
153
- start = stop
154
- curr = curr.dtype
188
+ leading_strides = self.strides[:start_dim]
189
+ flattening_strides = self.strides[start_dim:end_dim]
190
+ trailing_strides = self.strides[end_dim:]
155
191
 
156
- for dim in range(self.original.ndim):
157
- offsets[dim] = sum(offsets[dim])
158
- offsets[dim].find_and_replace(Symbol(self.original.strides[dim]), Symbol(1))
192
+ new_strides = leading_strides + (flattening_strides[-1],) + trailing_strides
159
193
 
160
- return offsets
194
+ leading_source_dims = self.source_dims[:start_dim]
195
+ flattening_source_dims = self.source_dims[start_dim:end_dim]
196
+ trailing_source_dims = self.source_dims[end_dim:]
161
197
 
162
- def indices(self, index=None):
163
- if index is None:
164
- index = call("program_id", 0)
198
+ new_source_dims = (
199
+ leading_source_dims + (flattening_source_dims,) + trailing_source_dims
200
+ )
165
201
 
166
- indices = []
202
+ return type(self)(
203
+ shape=new_shape,
204
+ dtype=self.dtype,
205
+ strides=new_strides,
206
+ source=self.source,
207
+ source_dims=new_source_dims,
208
+ target=self.target,
209
+ target_dims=self.target_dims,
210
+ )
167
211
 
168
- for stride in type(self)(shape=self.shape, original=self.original).strides:
169
- indices.append(index // stride)
170
- index %= stride
212
+ def ravel(self):
213
+ # TODO: Add error handling.
214
+ new_shape = []
215
+ new_strides = []
171
216
 
172
- curr = self.dtype
217
+ curr = self
173
218
 
174
- while isinstance(curr.dtype, type(self)):
175
- for _ in range(curr.ndim):
176
- indices.append(0)
219
+ while isinstance(curr, type(self)):
220
+ new_shape.extend(curr.shape)
221
+ new_strides.extend(curr.strides)
177
222
 
178
223
  curr = curr.dtype
179
224
 
180
- if isinstance(curr, type(self)):
181
- for dim in range(curr.ndim):
182
- size = curr.shape[dim]
183
-
184
- if Symbol.is_name(size):
185
- name = size.node.id
186
- if not naming.is_meta(name):
187
- size = naming.make_next_power_of_2(name)
225
+ return type(self)(
226
+ shape=new_shape,
227
+ strides=new_strides,
228
+ other=self.source.other,
229
+ name=self.source.name,
230
+ )
188
231
 
189
- indices.append(call("arange", 0, size))
232
+ def names(self):
233
+ if self.ndim == 0:
234
+ return {self.source.name}
190
235
 
191
- return tuple(indices)
236
+ return (
237
+ {self.source.pointer_string()}
238
+ | {
239
+ name
240
+ for value in itertools.chain(self.shape, self.strides)
241
+ if isinstance(value, Symbol)
242
+ for name in value.names()
243
+ }
244
+ | (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
245
+ | (self.source.names() if self.source is not self else set())
246
+ )
192
247
 
193
248
  def inmost(self):
194
249
  if not isinstance(self.dtype, type(self)):
@@ -237,6 +292,22 @@ class Tensor:
237
292
  def ndim(self):
238
293
  return len(self.shape)
239
294
 
295
+ @property
296
+ def source_dims(self):
297
+ return self._source_dims
298
+
299
+ @source_dims.setter
300
+ def source_dims(self, value):
301
+ self._source_dims = tuple(value)
302
+
303
+ @property
304
+ def target_dims(self):
305
+ return self._target_dims
306
+
307
+ @target_dims.setter
308
+ def target_dims(self, value):
309
+ self._target_dims = tuple(value)
310
+
240
311
  @staticmethod
241
312
  def pointer_pattern():
242
313
  return re.compile(rf"({_identifier_pattern_raw_string()})_(pointer)")
@@ -249,16 +320,6 @@ class Tensor:
249
320
  def stride_pattern():
250
321
  return re.compile(rf"({_identifier_pattern_raw_string()})_(stride)_(.+)")
251
322
 
252
- def _dims_of(self, stride):
253
- dims = set()
254
- names = stride.names() if isinstance(stride, Symbol) else {stride}
255
-
256
- for dim, original_stride in enumerate(self.original.strides):
257
- if str(original_stride) in names:
258
- dims.add(dim)
259
-
260
- return dims
261
-
262
323
  @staticmethod
263
324
  def _calculate_default_strides(shape):
264
325
  strides = [1]
@@ -1,10 +1,11 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.7.0
3
+ Version: 0.8.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
7
7
  Author-email: Jiacheng Huang <huangjiacheng0709@outlook.com>
8
+ License-File: LICENSE
8
9
  Classifier: License :: OSI Approved :: Apache Software License
9
10
  Classifier: Operating System :: OS Independent
10
11
  Classifier: Programming Language :: Python :: 3
@@ -0,0 +1,11 @@
1
+ ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
+ ninetoothed/jit.py,sha256=z70hQEsogfQu0cLxq5m3cOsWsVANcMRJaVv5di9vk1c,23741
3
+ ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
+ ninetoothed/naming.py,sha256=3FBnC-S3dAZRcBcob9SrcVpVEYE5IXRacwkCiA3vIGU,891
5
+ ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
6
+ ninetoothed/tensor.py,sha256=_jM0tVgqIwZd3MJJsGVTaLCsSxpPO8JfF4qkMShhQvQ,9429
7
+ ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
8
+ ninetoothed-0.8.0.dist-info/METADATA,sha256=gPWYhTBH5EdeOyGnArZIEw82aFmoQchD6pxtLi6LGMA,7054
9
+ ninetoothed-0.8.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
+ ninetoothed-0.8.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
+ ninetoothed-0.8.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.26.3
2
+ Generator: hatchling 1.27.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,11 +0,0 @@
1
- ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
- ninetoothed/jit.py,sha256=y0y3gfcBNeRLhozIzLuDHLLzBAGrL8tLk8Rcvhw_uec,20068
3
- ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
- ninetoothed/naming.py,sha256=3FBnC-S3dAZRcBcob9SrcVpVEYE5IXRacwkCiA3vIGU,891
5
- ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
6
- ninetoothed/tensor.py,sha256=KL6Iw2nwaRnZsNTxOeghDeXUQUgphnVh_1fsmAObOtI,7391
7
- ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
8
- ninetoothed-0.7.0.dist-info/METADATA,sha256=gRLAzPSdxWYff8SqPuwJ3ji3Z_kPq_vtEupnSe7dIQ0,7032
9
- ninetoothed-0.7.0.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
10
- ninetoothed-0.7.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
- ninetoothed-0.7.0.dist-info/RECORD,,