ninetoothed-0.7.0-py3-none-any.whl → ninetoothed-0.9.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ninetoothed/__init__.py CHANGED
@@ -1,5 +1,5 @@
1
- from ninetoothed.jit import jit
1
+ from ninetoothed.jit import jit, make
2
2
  from ninetoothed.symbol import Symbol
3
3
  from ninetoothed.tensor import Tensor
4
4
 
5
- __all__ = ["Symbol", "Tensor", "jit"]
5
+ __all__ = ["Symbol", "Tensor", "jit", "make"]
ninetoothed/jit.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import ast
2
2
  import collections
3
+ import copy
3
4
  import functools
4
5
  import importlib.util
5
6
  import inspect
@@ -18,6 +19,15 @@ from ninetoothed.tensor import Tensor
18
19
  from ninetoothed.torchifier import Torchifier
19
20
 
20
21
 
22
+ def make(arrangement, application, tensors):
23
+ params = inspect.signature(application).parameters
24
+ types = arrangement(*tensors)
25
+ annotations = {param: type for param, type in zip(params, types)}
26
+ application.__annotations__ = annotations
27
+
28
+ return jit(application)
29
+
30
+
21
31
  def jit(_func=None, *, _prettify=False):
22
32
  def wrapper(func):
23
33
  return JIT(func, _prettify=_prettify)()
@@ -29,23 +39,12 @@ def jit(_func=None, *, _prettify=False):
29
39
 
30
40
 
31
41
  class JIT:
32
- handles = collections.defaultdict(dict)
33
-
34
42
  def __init__(self, func, _prettify=False):
35
43
  self.func = func
36
44
 
37
45
  self._prettify = _prettify
38
46
 
39
47
  def __call__(self):
40
- source_file = inspect.getsourcefile(self.func)
41
- source_line = inspect.getsourcelines(self.func)[1]
42
-
43
- if (
44
- source_file in type(self).handles
45
- and source_line in type(self).handles[source_file]
46
- ):
47
- return type(self).handles[source_file][source_line]
48
-
49
48
  tree = self._get_tree()
50
49
 
51
50
  CodeGenerator(inspect.get_annotations(self.func)).visit(tree)
@@ -83,8 +82,6 @@ class JIT:
83
82
  source,
84
83
  )
85
84
 
86
- type(self).handles[source_file][source_line] = handle
87
-
88
85
  return handle
89
86
 
90
87
  def _get_tree(self):
@@ -141,30 +138,12 @@ class CodeGenerator(ast.NodeTransformer):
141
138
  def visit_FunctionDef(self, node):
142
139
  self._func_def = node
143
140
 
144
- self.generic_visit(node)
145
-
146
- for arg in self._args:
147
- if not isinstance(arg, Tensor) or arg.ndim == 0:
148
- continue
149
-
150
- offsets = arg.offsets()
141
+ self._invariants = {}
151
142
 
152
- initializations = {
153
- type(self)._name_for_offsets(arg, dim): offs
154
- for dim, offs in enumerate(offsets)
155
- } | {
156
- type(self)._name_for_pointers(arg): arg.original.pointer_string()
157
- + sum(
158
- type(self)._name_for_offsets(arg, dim)[
159
- type(self)._generate_slices(arg, dim)
160
- ]
161
- * stride
162
- for dim, stride in enumerate(arg.original.strides)
163
- )
164
- }
143
+ self.generic_visit(node)
165
144
 
166
- for target, value in reversed(initializations.items()):
167
- node.body.insert(0, ast.Assign(targets=[target.node], value=value.node))
145
+ for target, value in reversed(self._invariants.items()):
146
+ node.body.insert(0, ast.Assign(targets=[target.node], value=value.node))
168
147
 
169
148
  return node
170
149
 
@@ -192,24 +171,20 @@ class CodeGenerator(ast.NodeTransformer):
192
171
  ]
193
172
 
194
173
  autotune = self._generate_autotune(non_meta_names, meta_names)
195
- self._func_def.decorator_list.insert(0, autotune)
174
+ self._func_def.decorator_list = [autotune, Symbol("triton.jit").node]
196
175
 
197
176
  self._launch = self._generate_launch(non_meta_names, meta_names)
198
177
 
199
178
  return node
200
179
 
201
180
  def visit_Subscript(self, node):
202
- if (
203
- isinstance(node.value, ast.Name)
204
- and node.value.id in self._context
205
- and isinstance(node.ctx, ast.Load)
206
- ):
181
+ if self._in_context(node.value) and isinstance(node.ctx, ast.Load):
207
182
  value = self._context[node.value.id]
208
183
 
209
184
  if isinstance(value, Tensor):
210
- return type(self)._generate_load(
185
+ return self._generate_load(
211
186
  value,
212
- intermediate_indices=node.slice.elts
187
+ indices=node.slice.elts
213
188
  if isinstance(node.slice, ast.Tuple)
214
189
  else (node.slice,),
215
190
  )
@@ -219,7 +194,7 @@ class CodeGenerator(ast.NodeTransformer):
219
194
  return node
220
195
 
221
196
  def visit_Attribute(self, node):
222
- if isinstance(node.value, ast.Name) and node.value.id in self._context:
197
+ if self._in_context(node.value):
223
198
  value = self._context[node.value.id]
224
199
 
225
200
  if isinstance(value, Tensor):
@@ -234,8 +209,8 @@ class CodeGenerator(ast.NodeTransformer):
234
209
  def visit_Name(self, node):
235
210
  self.generic_visit(node)
236
211
 
237
- if node.id in self._context and isinstance(node.ctx, ast.Load):
238
- return type(self)._generate_load(self._context[node.id])
212
+ if self._in_context(node) and isinstance(node.ctx, ast.Load):
213
+ return self._generate_load(self._context[node.id])
239
214
 
240
215
  return node
241
216
 
@@ -243,16 +218,15 @@ class CodeGenerator(ast.NodeTransformer):
243
218
  if len(node.targets) == 1:
244
219
  target = node.targets[0]
245
220
 
246
- if isinstance(target, ast.Name) and target.id in self._context:
221
+ if self._in_context(target):
247
222
  self.generic_visit(node)
248
223
 
249
224
  return ast.Expr(
250
- type(self)._generate_store(self._context[target.id], node.value)
225
+ self._generate_store(self._context[target.id], node.value)
251
226
  )
252
227
  elif (
253
228
  isinstance(target, ast.Subscript)
254
- and isinstance(target.value, ast.Name)
255
- and target.value.id in self._context
229
+ and self._in_context(target.value)
256
230
  and isinstance(target.ctx, ast.Store)
257
231
  ):
258
232
  value = self._context[target.value.id]
@@ -261,10 +235,10 @@ class CodeGenerator(ast.NodeTransformer):
261
235
  self.generic_visit(node)
262
236
 
263
237
  return ast.Expr(
264
- type(self)._generate_store(
238
+ self._generate_store(
265
239
  value,
266
240
  node.value,
267
- intermediate_indices=target.slice.elts
241
+ indices=target.slice.elts
268
242
  if isinstance(target.slice, ast.Tuple)
269
243
  else (target.slice,),
270
244
  )
@@ -274,6 +248,11 @@ class CodeGenerator(ast.NodeTransformer):
274
248
 
275
249
  return node
276
250
 
251
+ _NAME_FOR_PID = Symbol("ninetoothed_pid")
252
+
253
+ def _in_context(self, node):
254
+ return isinstance(node, ast.Name) and node.id in self._context
255
+
277
256
  def _generate_autotune(self, params, meta):
278
257
  device = triton.runtime.driver.active.get_current_device()
279
258
  properties = triton.runtime.driver.active.utils.get_device_properties(device)
@@ -354,10 +333,11 @@ class CodeGenerator(ast.NodeTransformer):
354
333
  name=f"launch_{self._func_def.name}",
355
334
  args=ast.arguments(
356
335
  posonlyargs=[],
357
- args=[ast.arg(arg=arg.original.name) for arg in self._args]
336
+ args=[ast.arg(arg=arg.source.name) for arg in self._args]
358
337
  + [
359
338
  ast.arg(arg=param)
360
339
  for param in non_next_power_of_2_constexpr_params_without_prefixes
340
+ if not Tensor.size_pattern().fullmatch(param)
361
341
  ],
362
342
  kwonlyargs=[],
363
343
  defaults=[],
@@ -427,51 +407,105 @@ class CodeGenerator(ast.NodeTransformer):
427
407
 
428
408
  return ast.parse(f"lambda meta: ({num_elements},)", mode="eval").body
429
409
 
430
- @staticmethod
431
- def _generate_load(tensor, intermediate_indices=()):
410
+ def _generate_load(self, tensor, indices=()):
432
411
  if tensor.ndim == 0:
433
- return Symbol(tensor.original.name).node
412
+ return Symbol(tensor.source.name).node
434
413
 
435
- pointers, mask = CodeGenerator._generate_pointers_and_mask(
436
- tensor, intermediate_indices
437
- )
438
- other = CodeGenerator._generate_other(tensor)
414
+ pointers, mask = self._generate_pointers_and_mask(tensor, indices)
415
+ other = type(self)._generate_other(tensor)
439
416
 
440
417
  return call("load", pointers, mask=mask, other=other).node
441
418
 
442
- @staticmethod
443
- def _generate_store(tensor, value, intermediate_indices=()):
444
- pointers, mask = CodeGenerator._generate_pointers_and_mask(
445
- tensor, intermediate_indices
446
- )
419
+ def _generate_store(self, tensor, value, indices=()):
420
+ pointers, mask = self._generate_pointers_and_mask(tensor, indices)
447
421
 
448
422
  return call("store", pointers, value, mask=mask).node
449
423
 
450
- @staticmethod
451
- def _generate_pointers_and_mask(tensor, intermediate_indices):
452
- intermediate_offsets = CodeGenerator._generate_intermediate_offsets(
453
- tensor, intermediate_indices
454
- )
455
- offsets = [
456
- CodeGenerator._name_for_offsets(tensor, dim) + intermediate_offsets[dim]
457
- for dim in range(tensor.original.ndim)
458
- ]
459
- pointers = CodeGenerator._name_for_pointers(tensor) + sum(
460
- map(lambda x, y: x * y, intermediate_offsets, tensor.original.strides)
424
+ def _generate_pointers_and_mask(self, tensor, indices):
425
+ invariant_target_dims = type(self)._find_invariant_target_dims(tensor)
426
+
427
+ indices = self._complete_indices(tensor, indices)
428
+ offsets = type(self)._generate_offsets(tensor, indices)
429
+
430
+ for source_dim in range(tensor.source.ndim):
431
+ for target_dim in range(tensor.target.ndim):
432
+ if target_dim not in invariant_target_dims:
433
+ continue
434
+
435
+ name = type(self)._name_for_offsets(tensor, source_dim, target_dim)
436
+ self._invariants[name] = offsets[source_dim][target_dim]
437
+ offsets[source_dim][target_dim] = name
438
+
439
+ name_for_pointers = type(self)._name_for_pointers(tensor)
440
+ self._invariants[name_for_pointers] = Symbol(tensor.source.pointer_string())
441
+
442
+ for source_dim in range(tensor.source.ndim):
443
+ for target_dim in range(tensor.target.ndim):
444
+ if target_dim not in invariant_target_dims:
445
+ continue
446
+
447
+ self._invariants[name_for_pointers] += (
448
+ offsets[source_dim][target_dim][
449
+ type(self)._generate_slices(tensor, target_dim)
450
+ ]
451
+ * tensor.source.strides[source_dim]
452
+ )
453
+
454
+ pointers = name_for_pointers + sum(
455
+ offsets[source_dim][target_dim][
456
+ type(self)._generate_slices(tensor, target_dim)
457
+ ]
458
+ * tensor.source.strides[source_dim]
459
+ for source_dim in range(tensor.source.ndim)
460
+ for target_dim in range(tensor.target.ndim)
461
+ if target_dim not in invariant_target_dims
462
+ and offsets[source_dim][target_dim] != 0
461
463
  )
462
464
  mask = functools.reduce(
463
465
  lambda x, y: x & y,
464
466
  (
465
- offs[CodeGenerator._generate_slices(tensor, dim)] < size
466
- for dim, (offs, size) in enumerate(zip(offsets, tensor.original.shape))
467
+ offsets[source_dim][target_dim][
468
+ type(self)._generate_slices(tensor, target_dim)
469
+ ]
470
+ < tensor.source.shape[source_dim]
471
+ for source_dim in range(tensor.source.ndim)
472
+ for target_dim in range(tensor.target.ndim)
473
+ if offsets[source_dim][target_dim] != 0
467
474
  ),
468
475
  )
469
476
 
470
477
  return pointers, mask
471
478
 
479
+ def _complete_indices(self, tensor, indices):
480
+ indices = list(self._generate_pid_indices(tensor) + tuple(indices))
481
+
482
+ for size in tensor.inmost().shape:
483
+ if Symbol.is_name(size):
484
+ name = size.node.id
485
+ if not naming.is_meta(name):
486
+ size = naming.make_next_power_of_2(name)
487
+
488
+ indices.append(call("arange", 0, size))
489
+
490
+ return tuple(indices)
491
+
492
+ def _generate_pid_indices(self, tensor):
493
+ self._invariants[type(self)._NAME_FOR_PID] = call("program_id", 0)
494
+
495
+ indices = list(
496
+ type(self)._unravel_index(type(self)._NAME_FOR_PID, tensor.shape)
497
+ )
498
+
499
+ for dim, index in enumerate(indices):
500
+ name = type(self)._name_for_index(tensor, dim)
501
+ self._invariants[name] = index
502
+ indices[dim] = name
503
+
504
+ return tuple(indices)
505
+
472
506
  @staticmethod
473
507
  def _generate_other(tensor):
474
- other = tensor.original.other
508
+ other = tensor.source.other
475
509
 
476
510
  if isinstance(other, float) and not math.isfinite(other):
477
511
  return f"float('{other}')"
@@ -483,23 +517,86 @@ class CodeGenerator(ast.NodeTransformer):
483
517
  return tuple(slice(None) if i == dim else None for i in range(tensor.ndim))
484
518
 
485
519
  @staticmethod
486
- def _generate_intermediate_offsets(tensor, intermediate_indices):
487
- return tuple(
488
- offs
489
- for offs in tensor.offsets(
490
- [0 for _ in range(tensor.ndim)]
491
- + list(intermediate_indices)
492
- + [0 for _ in range(tensor.inmost().ndim)]
493
- )
520
+ def _generate_offsets(tensor, indices):
521
+ offsets = collections.defaultdict(
522
+ lambda: collections.defaultdict(lambda: Symbol(0))
494
523
  )
495
524
 
525
+ curr = tensor
526
+ start = 0
527
+
528
+ while isinstance(curr, type(tensor)):
529
+ stop = start + curr.ndim
530
+ curr_indices = indices[start:stop]
531
+
532
+ for index, stride, source_dim, target_dim in zip(
533
+ curr_indices, curr.strides, curr.source_dims, curr.target_dims
534
+ ):
535
+ offsets[source_dim][target_dim] += index * stride
536
+
537
+ start = stop
538
+ curr = curr.dtype
539
+
540
+ for source_dim in tuple(offsets):
541
+ for target_dim in tuple(offsets[source_dim]):
542
+ if not isinstance(source_dim, tuple):
543
+ continue
544
+
545
+ unraveled = CodeGenerator._unravel_index(
546
+ offsets[source_dim][target_dim],
547
+ tuple(tensor.source.shape[dim] for dim in source_dim),
548
+ )
549
+
550
+ for offs, dim in zip(unraveled, source_dim):
551
+ offsets[dim][target_dim] = offs
552
+
553
+ for source_dim in range(tensor.source.ndim):
554
+ for target_dim in range(tensor.target.ndim):
555
+ offsets[source_dim][target_dim] = copy.deepcopy(
556
+ offsets[source_dim][target_dim]
557
+ )
558
+ offsets[source_dim][target_dim].find_and_replace(
559
+ Symbol(tensor.source.strides[source_dim]), Symbol(1)
560
+ )
561
+
562
+ return offsets
563
+
564
+ @staticmethod
565
+ def _find_invariant_target_dims(tensor):
566
+ invariant_target_dims = set()
567
+
568
+ curr = tensor.dtype
569
+
570
+ while isinstance(curr.dtype, Tensor):
571
+ for target_dim in range(curr.target.ndim):
572
+ if target_dim not in curr.target_dims:
573
+ invariant_target_dims.add(target_dim)
574
+
575
+ curr = curr.dtype
576
+
577
+ return invariant_target_dims
578
+
496
579
  @staticmethod
497
580
  def _name_for_pointers(tensor):
498
- return Symbol(f"{tensor.original.name}_pointers")
581
+ return Symbol(f"{tensor.source.name}_pointers")
582
+
583
+ @staticmethod
584
+ def _name_for_offsets(tensor, source_dim, target_dim):
585
+ return Symbol(f"{tensor.source.name}_offsets_{source_dim}_{target_dim}")
499
586
 
500
587
  @staticmethod
501
- def _name_for_offsets(tensor, dim):
502
- return Symbol(f"{tensor.original.name}_offsets_{dim}")
588
+ def _name_for_index(tensor, dim):
589
+ return Symbol(f"{tensor.source.name}_index_{dim}")
590
+
591
+ @staticmethod
592
+ def _unravel_index(index, shape):
593
+ indices = []
594
+
595
+ for stride in Tensor(shape=shape).strides:
596
+ indices.append(index // stride)
597
+ index %= stride
598
+
599
+ return tuple(indices)
503
600
 
504
601
 
505
602
  class Tritonizer(ast.NodeTransformer):
ninetoothed/naming.py CHANGED
@@ -1,6 +1,10 @@
1
1
  import re
2
2
 
3
3
 
4
+ def auto_generate(name):
5
+ return f"ninetoothed_{name}"
6
+
7
+
4
8
  def make_constexpr(name):
5
9
  return _add_prefix(name, _CONSTEXPR)
6
10
 
ninetoothed/tensor.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import itertools
2
+ import math
2
3
  import re
3
4
 
4
5
  import ninetoothed.naming as naming
@@ -16,18 +17,25 @@ class Tensor:
16
17
  dtype=None,
17
18
  strides=None,
18
19
  other=None,
20
+ constexpr_shape=None,
19
21
  name=None,
20
- original=None,
22
+ source=None,
23
+ source_dims=None,
24
+ target=None,
25
+ target_dims=None,
21
26
  ):
22
27
  self.dtype = dtype
23
28
 
24
29
  if name is not None:
25
30
  self.name = name
26
31
  else:
27
- self.name = f"_ninetoothed_tensor_{type(self).num_instances}"
32
+ self.name = naming.auto_generate(f"tensor_{type(self).num_instances}")
28
33
 
29
34
  if ndim is not None:
30
- self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
35
+ self.shape = (
36
+ Symbol(self.size_string(i), constexpr=constexpr_shape)
37
+ for i in range(ndim)
38
+ )
31
39
  self.strides = (Symbol(self.stride_string(i)) for i in range(ndim))
32
40
  else:
33
41
  self.shape = shape
@@ -39,10 +47,25 @@ class Tensor:
39
47
 
40
48
  self.other = other
41
49
 
42
- if original is not None:
43
- self.original = original
50
+ if source is not None:
51
+ self.source = source
52
+ else:
53
+ self.source = self
54
+
55
+ if source_dims is not None:
56
+ self.source_dims = source_dims
57
+ else:
58
+ self.source_dims = (dim for dim in range(self.source.ndim))
59
+
60
+ if target is not None:
61
+ self.target = target
62
+ else:
63
+ self.target = self
64
+
65
+ if target_dims is not None:
66
+ self.target_dims = target_dims
44
67
  else:
45
- self.original = self
68
+ self.target_dims = (dim for dim in range(self.target.ndim))
46
69
 
47
70
  type(self).num_instances += 1
48
71
 
@@ -87,10 +110,16 @@ class Tensor:
87
110
  shape=inner_shape,
88
111
  dtype=self.dtype,
89
112
  strides=inner_strides,
90
- original=self.original,
113
+ source=self.source,
114
+ source_dims=self.source_dims,
115
+ target=self.target,
116
+ target_dims=self.target_dims,
91
117
  ),
92
118
  strides=outer_strides,
93
- original=self.original,
119
+ source=self.source,
120
+ source_dims=self.source_dims,
121
+ target=self.target,
122
+ target_dims=self.target_dims,
94
123
  )
95
124
 
96
125
  def expand(self, shape):
@@ -105,7 +134,10 @@ class Tensor:
105
134
  stride if new_size == -1 else 0
106
135
  for new_size, stride in zip(shape, self.strides)
107
136
  ],
108
- original=self.original,
137
+ source=self.source,
138
+ source_dims=self.source_dims,
139
+ target=self.target,
140
+ target_dims=self.target_dims,
109
141
  )
110
142
 
111
143
  def squeeze(self, dim):
@@ -114,81 +146,109 @@ class Tensor:
114
146
  shape=[size for i, size in enumerate(self.shape) if dim != i],
115
147
  dtype=self.dtype,
116
148
  strides=[stride for i, stride in enumerate(self.strides) if dim != i],
117
- original=self.original,
149
+ source=self.source,
150
+ source_dims=[
151
+ source_dim for i, source_dim in enumerate(self.source_dims) if dim != i
152
+ ],
153
+ target=self.target,
154
+ target_dims=[
155
+ target_dim for i, target_dim in enumerate(self.target_dims) if dim != i
156
+ ],
118
157
  )
119
158
 
120
- def names(self):
121
- if self.ndim == 0:
122
- return {self.original.name}
123
-
124
- return (
125
- {self.original.pointer_string()}
126
- | {
127
- name
128
- for value in itertools.chain(self.shape, self.strides)
129
- if isinstance(value, Symbol)
130
- for name in value.names()
131
- }
132
- | (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
133
- | (self.original.names() if self.original is not self else set())
134
- )
159
+ def permute(self, dims):
160
+ # TODO: Add error handling.
161
+ new_shape = [None for _ in range(self.ndim)]
162
+ new_strides = [None for _ in range(self.ndim)]
163
+ new_source_dims = [None for _ in range(self.ndim)]
135
164
 
136
- def offsets(self, indices=None):
137
- if indices is None:
138
- indices = self.indices()
165
+ for original_dim, permuted_dim in enumerate(dims):
166
+ new_shape[original_dim] = self.shape[permuted_dim]
167
+ new_strides[original_dim] = self.strides[permuted_dim]
168
+ new_source_dims[original_dim] = self.source_dims[permuted_dim]
139
169
 
140
- offsets = [[] for _ in range(self.original.ndim)]
170
+ return type(self)(
171
+ shape=new_shape,
172
+ dtype=self.dtype,
173
+ strides=new_strides,
174
+ source=self.source,
175
+ source_dims=new_source_dims,
176
+ target=self.target,
177
+ target_dims=self.target_dims,
178
+ )
141
179
 
142
- curr = self
143
- start = 0
180
+ def flatten(self, start_dim=None, end_dim=None):
181
+ # TODO: Add error handling.
182
+ if start_dim is None:
183
+ start_dim = 0
184
+ if end_dim is None:
185
+ end_dim = self.ndim
144
186
 
145
- while isinstance(curr, type(self)):
146
- stop = start + curr.ndim
147
- curr_indices = indices[start:stop]
187
+ leading_sizes = self.shape[:start_dim]
188
+ flattening_sizes = self.shape[start_dim:end_dim]
189
+ trailing_sizes = self.shape[end_dim:]
148
190
 
149
- for index, stride in zip(curr_indices, curr.strides):
150
- for dim in self._dims_of(stride):
151
- offsets[dim].append(index * stride)
191
+ new_shape = leading_sizes + (math.prod(flattening_sizes),) + trailing_sizes
152
192
 
153
- start = stop
154
- curr = curr.dtype
193
+ leading_strides = self.strides[:start_dim]
194
+ flattening_strides = self.strides[start_dim:end_dim]
195
+ trailing_strides = self.strides[end_dim:]
155
196
 
156
- for dim in range(self.original.ndim):
157
- offsets[dim] = sum(offsets[dim])
158
- offsets[dim].find_and_replace(Symbol(self.original.strides[dim]), Symbol(1))
197
+ new_strides = leading_strides + (flattening_strides[-1],) + trailing_strides
159
198
 
160
- return offsets
199
+ leading_source_dims = self.source_dims[:start_dim]
200
+ flattening_source_dims = self.source_dims[start_dim:end_dim]
201
+ trailing_source_dims = self.source_dims[end_dim:]
161
202
 
162
- def indices(self, index=None):
163
- if index is None:
164
- index = call("program_id", 0)
203
+ new_source_dims = (
204
+ leading_source_dims + (flattening_source_dims,) + trailing_source_dims
205
+ )
165
206
 
166
- indices = []
207
+ return type(self)(
208
+ shape=new_shape,
209
+ dtype=self.dtype,
210
+ strides=new_strides,
211
+ source=self.source,
212
+ source_dims=new_source_dims,
213
+ target=self.target,
214
+ target_dims=self.target_dims,
215
+ )
167
216
 
168
- for stride in type(self)(shape=self.shape, original=self.original).strides:
169
- indices.append(index // stride)
170
- index %= stride
217
+ def ravel(self):
218
+ # TODO: Add error handling.
219
+ new_shape = []
220
+ new_strides = []
171
221
 
172
- curr = self.dtype
222
+ curr = self
173
223
 
174
- while isinstance(curr.dtype, type(self)):
175
- for _ in range(curr.ndim):
176
- indices.append(0)
224
+ while isinstance(curr, type(self)):
225
+ new_shape.extend(curr.shape)
226
+ new_strides.extend(curr.strides)
177
227
 
178
228
  curr = curr.dtype
179
229
 
180
- if isinstance(curr, type(self)):
181
- for dim in range(curr.ndim):
182
- size = curr.shape[dim]
183
-
184
- if Symbol.is_name(size):
185
- name = size.node.id
186
- if not naming.is_meta(name):
187
- size = naming.make_next_power_of_2(name)
230
+ return type(self)(
231
+ shape=new_shape,
232
+ strides=new_strides,
233
+ other=self.source.other,
234
+ name=self.source.name,
235
+ )
188
236
 
189
- indices.append(call("arange", 0, size))
237
+ def names(self):
238
+ if self.ndim == 0:
239
+ return {self.source.name}
190
240
 
191
- return tuple(indices)
241
+ return (
242
+ {self.source.pointer_string()}
243
+ | {
244
+ name
245
+ for value in itertools.chain(self.shape, self.strides)
246
+ if isinstance(value, Symbol)
247
+ for name in value.names()
248
+ }
249
+ | (self.dtype.names() if isinstance(self.dtype, type(self)) else set())
250
+ | (self.source.names() if self.source is not self else set())
251
+ )
192
252
 
193
253
  def inmost(self):
194
254
  if not isinstance(self.dtype, type(self)):
@@ -237,6 +297,22 @@ class Tensor:
237
297
  def ndim(self):
238
298
  return len(self.shape)
239
299
 
300
+ @property
301
+ def source_dims(self):
302
+ return self._source_dims
303
+
304
+ @source_dims.setter
305
+ def source_dims(self, value):
306
+ self._source_dims = tuple(value)
307
+
308
+ @property
309
+ def target_dims(self):
310
+ return self._target_dims
311
+
312
+ @target_dims.setter
313
+ def target_dims(self, value):
314
+ self._target_dims = tuple(value)
315
+
240
316
  @staticmethod
241
317
  def pointer_pattern():
242
318
  return re.compile(rf"({_identifier_pattern_raw_string()})_(pointer)")
@@ -249,16 +325,6 @@ class Tensor:
249
325
  def stride_pattern():
250
326
  return re.compile(rf"({_identifier_pattern_raw_string()})_(stride)_(.+)")
251
327
 
252
- def _dims_of(self, stride):
253
- dims = set()
254
- names = stride.names() if isinstance(stride, Symbol) else {stride}
255
-
256
- for dim, original_stride in enumerate(self.original.strides):
257
- if str(original_stride) in names:
258
- dims.add(dim)
259
-
260
- return dims
261
-
262
328
  @staticmethod
263
329
  def _calculate_default_strides(shape):
264
330
  strides = [1]
ninetoothed/torchifier.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import ast
2
2
 
3
+ import ninetoothed.naming as naming
3
4
  from ninetoothed.tensor import Tensor
4
5
 
5
6
 
@@ -9,6 +10,9 @@ class Torchifier(ast.NodeTransformer):
9
10
 
10
11
  source = node.id
11
12
 
13
+ if naming.is_constexpr(source):
14
+ return node
15
+
12
16
  def repl(match):
13
17
  return f"{match.group(1)}"
14
18
 
ninetoothed-0.7.0.dist-info/METADATA → ninetoothed-0.9.0.dist-info/METADATA CHANGED
@@ -1,10 +1,11 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: ninetoothed
3
- Version: 0.7.0
3
+ Version: 0.9.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
7
7
  Author-email: Jiacheng Huang <huangjiacheng0709@outlook.com>
8
+ License-File: LICENSE
8
9
  Classifier: License :: OSI Approved :: Apache Software License
9
10
  Classifier: Operating System :: OS Independent
10
11
  Classifier: Programming Language :: Python :: 3
ninetoothed-0.9.0.dist-info/RECORD ADDED
@@ -0,0 +1,11 @@
1
+ ninetoothed/__init__.py,sha256=dX34sk5GA3OgWf1Jc4gJMW3UwcGcJsuG3hs3rkiqq6g,161
2
+ ninetoothed/jit.py,sha256=B3q32ksKTRr7I4jLcoDXjwEx7A_Awz9DXGEmIkrtoBc,23393
3
+ ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
+ ninetoothed/naming.py,sha256=Fl0x4eDRStTpkXjJg6179ErEnY7bR5Qi0AT6RX9C3fU,951
5
+ ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
6
+ ninetoothed/tensor.py,sha256=OU6lVjzKU614mk3EN1AAgTDradbvJkyl22AXwdhxcfs,9577
7
+ ninetoothed/torchifier.py,sha256=aDijK5UOwK2oLXDHgDo8M959rJclEI0lcfaPr7GQTXY,1012
8
+ ninetoothed-0.9.0.dist-info/METADATA,sha256=LKtSbgc_mWKJ4L_8BYiRKWjoV-JKjr4hefhrkdPmrHs,7054
9
+ ninetoothed-0.9.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
10
+ ninetoothed-0.9.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
+ ninetoothed-0.9.0.dist-info/RECORD,,
ninetoothed-0.7.0.dist-info/WHEEL → ninetoothed-0.9.0.dist-info/WHEEL CHANGED
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.26.3
2
+ Generator: hatchling 1.27.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
ninetoothed-0.7.0.dist-info/RECORD REMOVED
@@ -1,11 +0,0 @@
1
- ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
- ninetoothed/jit.py,sha256=y0y3gfcBNeRLhozIzLuDHLLzBAGrL8tLk8Rcvhw_uec,20068
3
- ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
- ninetoothed/naming.py,sha256=3FBnC-S3dAZRcBcob9SrcVpVEYE5IXRacwkCiA3vIGU,891
5
- ninetoothed/symbol.py,sha256=rZ5nXtn-U1Nw0BBRJ-kfrwmX_zCbAi76un-Z2QFaoZc,4773
6
- ninetoothed/tensor.py,sha256=KL6Iw2nwaRnZsNTxOeghDeXUQUgphnVh_1fsmAObOtI,7391
7
- ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
8
- ninetoothed-0.7.0.dist-info/METADATA,sha256=gRLAzPSdxWYff8SqPuwJ3ji3Z_kPq_vtEupnSe7dIQ0,7032
9
- ninetoothed-0.7.0.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
10
- ninetoothed-0.7.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
11
- ninetoothed-0.7.0.dist-info/RECORD,,