tinygrad 0.7.0__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygrad/__init__.py +6 -0
 - tinygrad/codegen/kernel.py +572 -83
 - tinygrad/codegen/linearizer.py +415 -395
 - tinygrad/codegen/uops.py +415 -0
 - tinygrad/device.py +183 -0
 - tinygrad/dtype.py +113 -0
 - tinygrad/engine/__init__.py +0 -0
 - tinygrad/engine/graph.py +100 -0
 - tinygrad/engine/jit.py +195 -0
 - tinygrad/engine/realize.py +191 -0
 - tinygrad/engine/schedule.py +362 -0
 - tinygrad/engine/search.py +196 -0
 - tinygrad/{mlops.py → function.py} +76 -55
 - tinygrad/helpers.py +196 -89
 - tinygrad/lazy.py +210 -371
 - tinygrad/multi.py +169 -0
 - tinygrad/nn/__init__.py +202 -22
 - tinygrad/nn/datasets.py +7 -0
 - tinygrad/nn/optim.py +112 -32
 - tinygrad/nn/state.py +136 -39
 - tinygrad/ops.py +119 -202
 - tinygrad/renderer/__init__.py +61 -0
 - tinygrad/renderer/assembly.py +276 -0
 - tinygrad/renderer/cstyle.py +353 -166
 - tinygrad/renderer/llvmir.py +150 -138
 - tinygrad/runtime/autogen/amd_gpu.py +1900 -0
 - tinygrad/runtime/autogen/comgr.py +865 -0
 - tinygrad/runtime/autogen/cuda.py +5923 -0
 - tinygrad/runtime/autogen/hip.py +5909 -0
 - tinygrad/runtime/autogen/hsa.py +5761 -0
 - tinygrad/runtime/autogen/kfd.py +812 -0
 - tinygrad/runtime/autogen/nv_gpu.py +33328 -0
 - tinygrad/runtime/autogen/opencl.py +1795 -0
 - tinygrad/runtime/driver/hip_comgr.py +47 -0
 - tinygrad/runtime/driver/hsa.py +143 -0
 - tinygrad/runtime/graph/clang.py +38 -0
 - tinygrad/runtime/graph/cuda.py +81 -0
 - tinygrad/runtime/graph/hcq.py +143 -0
 - tinygrad/runtime/graph/hsa.py +171 -0
 - tinygrad/runtime/graph/metal.py +75 -0
 - tinygrad/runtime/ops_amd.py +564 -0
 - tinygrad/runtime/ops_clang.py +24 -77
 - tinygrad/runtime/ops_cuda.py +175 -89
 - tinygrad/runtime/ops_disk.py +56 -33
 - tinygrad/runtime/ops_gpu.py +92 -95
 - tinygrad/runtime/ops_hsa.py +278 -0
 - tinygrad/runtime/ops_llvm.py +39 -60
 - tinygrad/runtime/ops_metal.py +92 -74
 - tinygrad/runtime/ops_npy.py +9 -0
 - tinygrad/runtime/ops_nv.py +630 -0
 - tinygrad/runtime/ops_python.py +204 -0
 - tinygrad/shape/shapetracker.py +86 -254
 - tinygrad/shape/symbolic.py +166 -141
 - tinygrad/shape/view.py +296 -0
 - tinygrad/tensor.py +2619 -448
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/LICENSE +1 -1
 - tinygrad-0.9.0.dist-info/METADATA +227 -0
 - tinygrad-0.9.0.dist-info/RECORD +60 -0
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/WHEEL +1 -1
 - tinygrad/codegen/assembly.py +0 -190
 - tinygrad/codegen/optimizer.py +0 -379
 - tinygrad/codegen/search.py +0 -72
 - tinygrad/graph.py +0 -83
 - tinygrad/jit.py +0 -57
 - tinygrad/nn/image.py +0 -100
 - tinygrad/renderer/assembly_arm64.py +0 -169
 - tinygrad/renderer/assembly_ptx.py +0 -98
 - tinygrad/renderer/wgsl.py +0 -53
 - tinygrad/runtime/lib.py +0 -113
 - tinygrad/runtime/ops_cpu.py +0 -51
 - tinygrad/runtime/ops_hip.py +0 -82
 - tinygrad/runtime/ops_shm.py +0 -29
 - tinygrad/runtime/ops_torch.py +0 -30
 - tinygrad/runtime/ops_webgpu.py +0 -45
 - tinygrad-0.7.0.dist-info/METADATA +0 -212
 - tinygrad-0.7.0.dist-info/RECORD +0 -40
 - {tinygrad-0.7.0.dist-info → tinygrad-0.9.0.dist-info}/top_level.txt +0 -0
 
    
        tinygrad/shape/symbolic.py
    CHANGED
    
    | 
         @@ -1,80 +1,62 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            from __future__ import annotations
         
     | 
| 
       2 
     | 
    
         
            -
            from abc import abstractmethod
         
     | 
| 
       3 
2 
     | 
    
         
             
            import functools
         
     | 
| 
       4 
3 
     | 
    
         
             
            from math import gcd
         
     | 
| 
       5 
4 
     | 
    
         
             
            from tinygrad.helpers import partition
         
     | 
| 
       6 
     | 
    
         
            -
            from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any
         
     | 
| 
      
 5 
     | 
    
         
            +
            from typing import List, Dict, Callable, Tuple, Type, Union, Optional, Any, Set, Mapping
         
     | 
| 
       7 
6 
     | 
    
         | 
| 
       8 
7 
     | 
    
         
             
            # NOTE: Python has different behavior for negative mod and floor div than c
         
     | 
| 
       9 
8 
     | 
    
         
             
            # symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
         
     | 
| 
       10 
9 
     | 
    
         | 
| 
       11 
     | 
    
         
            -
            def is_sym_int(x: Any) -> bool: return isinstance(x, (int, Node))
         
     | 
| 
       12 
     | 
    
         
            -
             
     | 
| 
       13 
10 
     | 
    
         
             
            class Node:
         
     | 
| 
       14 
11 
     | 
    
         
             
              b: Union[Node, int]
         
     | 
| 
       15 
12 
     | 
    
         
             
              min: int
         
     | 
| 
       16 
     | 
    
         
            -
              max:  
     | 
| 
       17 
     | 
    
         
            -
              def render(self, ops=None, ctx=None 
     | 
| 
      
 13 
     | 
    
         
            +
              max: sint
         
     | 
| 
      
 14 
     | 
    
         
            +
              def render(self, ops=None, ctx=None) -> Any:
         
     | 
| 
       18 
15 
     | 
    
         
             
                if ops is None: ops = render_python
         
     | 
| 
       19 
16 
     | 
    
         
             
                assert self.__class__ in (Variable, NumNode) or self.min != self.max
         
     | 
| 
       20 
     | 
    
         
            -
                 
     | 
| 
       21 
     | 
    
         
            -
             
     | 
| 
       22 
     | 
    
         
            -
             
     | 
| 
       23 
     | 
    
         
            -
              def  
     | 
| 
      
 17 
     | 
    
         
            +
                return ops[type(self)](self, ops, ctx)
         
     | 
| 
      
 18 
     | 
    
         
            +
              def vars(self) -> Set[Variable]: return set()
         
     | 
| 
      
 19 
     | 
    
         
            +
              # substitute Variables with the values in var_vals
         
     | 
| 
      
 20 
     | 
    
         
            +
              def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: raise RuntimeError(self.__class__.__name__)
         
     | 
| 
      
 21 
     | 
    
         
            +
              def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
         
     | 
| 
      
 22 
     | 
    
         
            +
             
     | 
| 
       24 
23 
     | 
    
         
             
              @functools.cached_property
         
     | 
| 
       25 
24 
     | 
    
         
             
              def key(self) -> str: return self.render(ctx="DEBUG")
         
     | 
| 
       26 
25 
     | 
    
         
             
              @functools.cached_property
         
     | 
| 
       27 
26 
     | 
    
         
             
              def hash(self) -> int: return hash(self.key)
         
     | 
| 
       28 
     | 
    
         
            -
              def __repr__(self): return  
     | 
| 
      
 27 
     | 
    
         
            +
              def __repr__(self): return self.render(ctx="REPR")
         
     | 
| 
      
 28 
     | 
    
         
            +
              def __str__(self): return "<"+self.key+">"
         
     | 
| 
       29 
29 
     | 
    
         
             
              def __hash__(self): return self.hash
         
     | 
| 
       30 
30 
     | 
    
         
             
              def __bool__(self): return not (self.max == self.min == 0)
         
     | 
| 
       31 
31 
     | 
    
         
             
              def __eq__(self, other:object) -> bool:
         
     | 
| 
       32 
32 
     | 
    
         
             
                if not isinstance(other, Node): return NotImplemented
         
     | 
| 
       33 
33 
     | 
    
         
             
                return self.key == other.key
         
     | 
| 
       34 
34 
     | 
    
         
             
              def __neg__(self): return self*-1
         
     | 
| 
       35 
     | 
    
         
            -
              def __add__(self, b:Union[Node,int]): return  
     | 
| 
      
 35 
     | 
    
         
            +
              def __add__(self, b:Union[Node,int]): return Node.sum([self, NumNode(b) if isinstance(b, int) else b])
         
     | 
| 
       36 
36 
     | 
    
         
             
              def __radd__(self, b:int): return self+b
         
     | 
| 
       37 
37 
     | 
    
         
             
              def __sub__(self, b:Union[Node,int]): return self+-b
         
     | 
| 
       38 
38 
     | 
    
         
             
              def __rsub__(self, b:int): return -self+b
         
     | 
| 
       39 
39 
     | 
    
         
             
              def __le__(self, b:Union[Node,int]): return self < (b+1)
         
     | 
| 
       40 
40 
     | 
    
         
             
              def __gt__(self, b:Union[Node,int]): return (-self) < (-b)
         
     | 
| 
       41 
41 
     | 
    
         
             
              def __ge__(self, b:Union[Node,int]): return (-self) < (-b+1)
         
     | 
| 
       42 
     | 
    
         
            -
              def __lt__(self, b:Union[Node,int]):
         
     | 
| 
       43 
     | 
    
         
            -
                lhs = self
         
     | 
| 
       44 
     | 
    
         
            -
                if isinstance(lhs, SumNode) and isinstance(b, int):
         
     | 
| 
       45 
     | 
    
         
            -
                  muls, others = partition(lhs.nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
         
     | 
| 
       46 
     | 
    
         
            -
                  if muls:
         
     | 
| 
       47 
     | 
    
         
            -
                    # NOTE: gcd in python 3.8 takes exactly 2 args
         
     | 
| 
       48 
     | 
    
         
            -
                    mul_gcd = muls[0].b
         
     | 
| 
       49 
     | 
    
         
            -
                    for x in muls[1:]: mul_gcd = gcd(mul_gcd, x.b)
         
     | 
| 
       50 
     | 
    
         
            -
                    if b%mul_gcd == 0:
         
     | 
| 
       51 
     | 
    
         
            -
                      all_others = Variable.sum(others)
         
     | 
| 
       52 
     | 
    
         
            -
                      #print(mul_gcd, muls, all_others)
         
     | 
| 
       53 
     | 
    
         
            -
                      if all_others.min >= 0 and all_others.max < mul_gcd:
         
     | 
| 
       54 
     | 
    
         
            -
                        # TODO: should we divide both by mul_gcd here?
         
     | 
| 
       55 
     | 
    
         
            -
                        lhs = Variable.sum(muls)
         
     | 
| 
       56 
     | 
    
         
            -
                return create_node(LtNode(lhs, b))
         
     | 
| 
      
 42 
     | 
    
         
            +
              def __lt__(self, b:Union[Node,int]): return create_node(LtNode(self, b))
         
     | 
| 
       57 
43 
     | 
    
         
             
              def __mul__(self, b:Union[Node, int]):
         
     | 
| 
       58 
44 
     | 
    
         
             
                if b == 0: return NumNode(0)
         
     | 
| 
       59 
45 
     | 
    
         
             
                if b == 1: return self
         
     | 
| 
       60 
     | 
    
         
            -
                if self.__class__ is NumNode: return NumNode(self.b*b) if isinstance(b, int) else b*self.b
         
     | 
| 
       61 
46 
     | 
    
         
             
                return create_node(MulNode(self, b.b)) if isinstance(b, NumNode) else create_node(MulNode(self, b))
         
     | 
| 
       62 
47 
     | 
    
         
             
              def __rmul__(self, b:int): return self*b
         
     | 
| 
       63 
48 
     | 
    
         | 
| 
       64 
49 
     | 
    
         
             
              # *** complex ops ***
         
     | 
| 
       65 
50 
     | 
    
         | 
| 
       66 
     | 
    
         
            -
              def __rfloordiv__(self, b:int):
         
     | 
| 
       67 
     | 
    
         
            -
                if self.min > b >= 0: return NumNode(0)
         
     | 
| 
       68 
     | 
    
         
            -
                if isinstance(self, NumNode): return NumNode(b // self.b)
         
     | 
| 
       69 
     | 
    
         
            -
                raise RuntimeError(f"not supported: {b} // {self}")
         
     | 
| 
      
 51 
     | 
    
         
            +
              def __rfloordiv__(self, b:int): return NumNode(b) // self
         
     | 
| 
       70 
52 
     | 
    
         
             
              def __floordiv__(self, b:Union[Node,int], factoring_allowed=True):
         
     | 
| 
       71 
53 
     | 
    
         
             
                if isinstance(b, Node):
         
     | 
| 
       72 
     | 
    
         
            -
                  if b.__class__ is NumNode: return self 
     | 
| 
      
 54 
     | 
    
         
            +
                  if b.__class__ is NumNode: return self.__floordiv__(b.b, factoring_allowed)
         
     | 
| 
       73 
55 
     | 
    
         
             
                  if self == b: return NumNode(1)
         
     | 
| 
       74 
56 
     | 
    
         
             
                  if (b - self).min > 0 and self.min >= 0: return NumNode(0) # b - self simplifies the node
         
     | 
| 
       75 
57 
     | 
    
         
             
                  raise RuntimeError(f"not supported: {self} // {b}")
         
     | 
| 
       76 
58 
     | 
    
         
             
                assert b != 0
         
     | 
| 
       77 
     | 
    
         
            -
                if b < 0: return (self 
     | 
| 
      
 59 
     | 
    
         
            +
                if b < 0: return (self*-1).__floordiv__(-b, factoring_allowed)
         
     | 
| 
       78 
60 
     | 
    
         
             
                if b == 1: return self
         
     | 
| 
       79 
61 
     | 
    
         | 
| 
       80 
62 
     | 
    
         
             
                # the numerator of div is not allowed to be negative
         
     | 
| 
         @@ -84,10 +66,7 @@ class Node: 
     | 
|
| 
       84 
66 
     | 
    
         
             
                  return (self + -offset*b).__floordiv__(b, factoring_allowed=False) + offset
         
     | 
| 
       85 
67 
     | 
    
         
             
                return create_node(DivNode(self, b))
         
     | 
| 
       86 
68 
     | 
    
         | 
| 
       87 
     | 
    
         
            -
              def __rmod__(self, b:int):
         
     | 
| 
       88 
     | 
    
         
            -
                if self.min > b >= 0: return NumNode(b)
         
     | 
| 
       89 
     | 
    
         
            -
                if isinstance(self, NumNode): return NumNode(b % self.b)
         
     | 
| 
       90 
     | 
    
         
            -
                raise RuntimeError(f"not supported: {b} % {self}")
         
     | 
| 
      
 69 
     | 
    
         
            +
              def __rmod__(self, b:int): return NumNode(b) % self
         
     | 
| 
       91 
70 
     | 
    
         
             
              def __mod__(self, b:Union[Node,int]):
         
     | 
| 
       92 
71 
     | 
    
         
             
                if isinstance(b, Node):
         
     | 
| 
       93 
72 
     | 
    
         
             
                  if b.__class__ is NumNode: return self % b.b
         
     | 
| 
         @@ -96,37 +75,27 @@ class Node: 
     | 
|
| 
       96 
75 
     | 
    
         
             
                  raise RuntimeError(f"not supported: {self} % {b}")
         
     | 
| 
       97 
76 
     | 
    
         
             
                assert b > 0
         
     | 
| 
       98 
77 
     | 
    
         
             
                if b == 1: return NumNode(0)
         
     | 
| 
       99 
     | 
    
         
            -
                if self. 
     | 
| 
       100 
     | 
    
         
            -
             
     | 
| 
      
 78 
     | 
    
         
            +
                if isinstance(self.max, int) and isinstance(self.min, int):
         
     | 
| 
      
 79 
     | 
    
         
            +
                  if self.min >= 0 and self.max < b: return self
         
     | 
| 
      
 80 
     | 
    
         
            +
                  if (self.min//b) == (self.max//b): return self - (b*(self.min//b))
         
     | 
| 
      
 81 
     | 
    
         
            +
                  if self.min < 0: return (self - ((self.min//b)*b)) % b
         
     | 
| 
       101 
82 
     | 
    
         
             
                return create_node(ModNode(self, b))
         
     | 
| 
       102 
83 
     | 
    
         | 
| 
       103 
     | 
    
         
            -
              @staticmethod
         
     | 
| 
       104 
     | 
    
         
            -
              def num(num:int) -> NumNode: return NumNode(num)
         
     | 
| 
       105 
     | 
    
         
            -
             
     | 
| 
       106 
     | 
    
         
            -
              @staticmethod
         
     | 
| 
       107 
     | 
    
         
            -
              def factorize(nodes:List[Node]) -> List[Node]:
         
     | 
| 
       108 
     | 
    
         
            -
                mul_groups: Dict[Node, int] = {}
         
     | 
| 
       109 
     | 
    
         
            -
                for x in nodes:
         
     | 
| 
       110 
     | 
    
         
            -
                  a,b = (x.a,x.b) if isinstance(x, MulNode) else (x,1)
         
     | 
| 
       111 
     | 
    
         
            -
                  mul_groups[a] = mul_groups.get(a, 0) + b
         
     | 
| 
       112 
     | 
    
         
            -
                return [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
         
     | 
| 
       113 
     | 
    
         
            -
             
     | 
| 
       114 
84 
     | 
    
         
             
              @staticmethod
         
     | 
| 
       115 
85 
     | 
    
         
             
              def sum(nodes:List[Node]) -> Node:
         
     | 
| 
       116 
86 
     | 
    
         
             
                nodes = [x for x in nodes if x.max or x.min]
         
     | 
| 
       117 
87 
     | 
    
         
             
                if not nodes: return NumNode(0)
         
     | 
| 
       118 
88 
     | 
    
         
             
                if len(nodes) == 1: return nodes[0]
         
     | 
| 
       119 
89 
     | 
    
         | 
| 
       120 
     | 
    
         
            -
                 
     | 
| 
      
 90 
     | 
    
         
            +
                mul_groups: Dict[Node, int] = {}
         
     | 
| 
       121 
91 
     | 
    
         
             
                num_node_sum = 0
         
     | 
| 
       122 
92 
     | 
    
         
             
                for node in SumNode(nodes).flat_components:
         
     | 
| 
       123 
93 
     | 
    
         
             
                  if node.__class__ is NumNode: num_node_sum += node.b
         
     | 
| 
       124 
     | 
    
         
            -
                   
     | 
| 
       125 
     | 
    
         
            -
             
     | 
| 
       126 
     | 
    
         
            -
                 
     | 
| 
       127 
     | 
    
         
            -
                  new_nodes = Node.factorize(new_nodes)
         
     | 
| 
      
 94 
     | 
    
         
            +
                  elif node.__class__ is MulNode: mul_groups[node.a] = mul_groups.get(node.a, 0) + node.b
         
     | 
| 
      
 95 
     | 
    
         
            +
                  else: mul_groups[node] = mul_groups.get(node, 0) + 1
         
     | 
| 
      
 96 
     | 
    
         
            +
                new_nodes = [MulNode(a, b_sum) if b_sum != 1 else a for a, b_sum in mul_groups.items() if b_sum != 0]
         
     | 
| 
       128 
97 
     | 
    
         
             
                if num_node_sum: new_nodes.append(NumNode(num_node_sum))
         
     | 
| 
       129 
     | 
    
         
            -
                return  
     | 
| 
      
 98 
     | 
    
         
            +
                return create_node(SumNode(new_nodes)) if len(new_nodes) > 1 else new_nodes[0] if len(new_nodes) == 1 else NumNode(0)
         
     | 
| 
       130 
99 
     | 
    
         | 
| 
       131 
100 
     | 
    
         
             
              @staticmethod
         
     | 
| 
       132 
101 
     | 
    
         
             
              def ands(nodes:List[Node]) -> Node:
         
     | 
| 
         @@ -136,48 +105,96 @@ class Node: 
     | 
|
| 
       136 
105 
     | 
    
         | 
| 
       137 
106 
     | 
    
         
             
                # filter 1s
         
     | 
| 
       138 
107 
     | 
    
         
             
                nodes = [x for x in nodes if x.min != x.max]
         
     | 
| 
       139 
     | 
    
         
            -
                return  
     | 
| 
      
 108 
     | 
    
         
            +
                return create_node(AndNode(nodes)) if len(nodes) > 1 else (nodes[0] if len(nodes) == 1 else NumNode(1))
         
     | 
| 
       140 
109 
     | 
    
         | 
| 
       141 
110 
     | 
    
         
             
            # 4 basic node types
         
     | 
| 
       142 
111 
     | 
    
         | 
| 
       143 
112 
     | 
    
         
             
            class Variable(Node):
         
     | 
| 
       144 
     | 
    
         
            -
              def __new__(cls,  
     | 
| 
       145 
     | 
    
         
            -
                 
     | 
| 
      
 113 
     | 
    
         
            +
              def __new__(cls, *args):
         
     | 
| 
      
 114 
     | 
    
         
            +
                expr, nmin, nmax = args
         
     | 
| 
      
 115 
     | 
    
         
            +
                assert nmin >= 0 and nmin <= nmax, f"invalid Variable {expr=} {nmin=} {nmax=}"
         
     | 
| 
       146 
116 
     | 
    
         
             
                if nmin == nmax: return NumNode(nmin)
         
     | 
| 
       147 
117 
     | 
    
         
             
                return super().__new__(cls)
         
     | 
| 
       148 
118 
     | 
    
         | 
| 
       149 
     | 
    
         
            -
              def  
     | 
| 
      
 119 
     | 
    
         
            +
              def __getnewargs__(self): return (self.expr, self.min, self.max)  # args passed to __new__ when unpickling
         
     | 
| 
      
 120 
     | 
    
         
            +
             
     | 
| 
      
 121 
     | 
    
         
            +
              def __init__(self, expr:str, nmin:int, nmax:sint):
         
     | 
| 
       150 
122 
     | 
    
         
             
                self.expr, self.min, self.max = expr, nmin, nmax
         
     | 
| 
       151 
     | 
    
         
            -
             
     | 
| 
      
 123 
     | 
    
         
            +
                self._val: Optional[int] = None
         
     | 
| 
      
 124 
     | 
    
         
            +
              @property
         
     | 
| 
      
 125 
     | 
    
         
            +
              def val(self):
         
     | 
| 
      
 126 
     | 
    
         
            +
                assert self._val is not None, f"Variable isn't bound, can't access val of {self}"
         
     | 
| 
      
 127 
     | 
    
         
            +
                return self._val
         
     | 
| 
      
 128 
     | 
    
         
            +
              def bind(self, val):
         
     | 
| 
      
 129 
     | 
    
         
            +
                assert self._val is None and self.min<=val<=self.max, f"cannot bind {val} to {self}"
         
     | 
| 
      
 130 
     | 
    
         
            +
                self._val = val
         
     | 
| 
      
 131 
     | 
    
         
            +
                return self
         
     | 
| 
      
 132 
     | 
    
         
            +
              def unbind(self) -> Tuple[Variable, int]:
         
     | 
| 
      
 133 
     | 
    
         
            +
                assert self.val is not None, f"cannot unbind {self}"
         
     | 
| 
      
 134 
     | 
    
         
            +
                return Variable(self.expr, self.min, self.max), self.val
         
     | 
| 
      
 135 
     | 
    
         
            +
              def vars(self): return {self}
         
     | 
| 
      
 136 
     | 
    
         
            +
              def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return var_vals.get(self, self)
         
     | 
| 
       152 
137 
     | 
    
         | 
| 
       153 
138 
     | 
    
         
             
            class NumNode(Node):
         
     | 
| 
       154 
139 
     | 
    
         
             
              def __init__(self, num:int):
         
     | 
| 
      
 140 
     | 
    
         
            +
                assert isinstance(num, int), f"{num} is not an int"
         
     | 
| 
       155 
141 
     | 
    
         
             
                self.b:int = num
         
     | 
| 
       156 
142 
     | 
    
         
             
                self.min, self.max = num, num
         
     | 
| 
       157 
     | 
    
         
            -
              def  
     | 
| 
       158 
     | 
    
         
            -
             
     | 
| 
      
 143 
     | 
    
         
            +
              def bind(self, val):
         
     | 
| 
      
 144 
     | 
    
         
            +
                assert self.b == val, f"cannot bind {val} to {self}"
         
     | 
| 
      
 145 
     | 
    
         
            +
                return self
         
     | 
| 
      
 146 
     | 
    
         
            +
              def __mul__(self, b:Union[Node,int]): return NumNode(self.b*b) if isinstance(b, int) else b*self.b
         
     | 
| 
       159 
147 
     | 
    
         
             
              def __eq__(self, other): return self.b == other
         
     | 
| 
       160 
     | 
    
         
            -
              def __hash__(self): return self. 
     | 
| 
      
 148 
     | 
    
         
            +
              def __hash__(self): return hash(self.b)  # needed with __eq__ override
         
     | 
| 
      
 149 
     | 
    
         
            +
              def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self
         
     | 
| 
       161 
150 
     | 
    
         | 
| 
       162 
151 
     | 
    
         
             
            def create_node(ret:Node):
         
     | 
| 
       163 
152 
     | 
    
         
             
              assert ret.min <= ret.max, f"min greater than max! {ret.min} {ret.max} when creating {type(ret)} {ret}"
         
     | 
| 
       164 
153 
     | 
    
         
             
              if ret.min == ret.max: return NumNode(ret.min)
         
     | 
| 
       165 
154 
     | 
    
         
             
              return ret
         
     | 
| 
       166 
155 
     | 
    
         | 
| 
      
 156 
     | 
    
         
            +
            def create_lt_node(lhs:Node, b:Union[Node, int]):
         
     | 
| 
      
 157 
     | 
    
         
            +
              if isinstance(lhs, SumNode):
         
     | 
| 
      
 158 
     | 
    
         
            +
                if isinstance(b, int):
         
     | 
| 
      
 159 
     | 
    
         
            +
                  new_sum = []
         
     | 
| 
      
 160 
     | 
    
         
            +
                  for x in lhs.nodes:
         
     | 
| 
      
 161 
     | 
    
         
            +
                    # TODO: should we just force the last one to always be the number
         
     | 
| 
      
 162 
     | 
    
         
            +
                    if isinstance(x, NumNode): b -= x.b
         
     | 
| 
      
 163 
     | 
    
         
            +
                    else: new_sum.append(x)
         
     | 
| 
      
 164 
     | 
    
         
            +
                  lhs = Node.sum(new_sum)
         
     | 
| 
      
 165 
     | 
    
         
            +
                  nodes = lhs.nodes if isinstance(lhs, SumNode) else [lhs]
         
     | 
| 
      
 166 
     | 
    
         
            +
                  assert all(not isinstance(node, MulNode) or isinstance(node.b, int) for node in nodes), "not supported"
         
     | 
| 
      
 167 
     | 
    
         
            +
                  muls, others = partition(nodes, lambda x: isinstance(x, MulNode) and x.b > 0 and x.max >= b)
         
     | 
| 
      
 168 
     | 
    
         
            +
                  if muls:
         
     | 
| 
      
 169 
     | 
    
         
            +
                    # NOTE: gcd in python 3.8 takes exactly 2 args
         
     | 
| 
      
 170 
     | 
    
         
            +
                    mul_gcd = b
         
     | 
| 
      
 171 
     | 
    
         
            +
                    for x in muls: mul_gcd = gcd(mul_gcd, x.b)  # type: ignore  # mypy cannot tell that x.b is int here due to assert above
         
     | 
| 
      
 172 
     | 
    
         
            +
                    all_others = Node.sum(others)
         
     | 
| 
      
 173 
     | 
    
         
            +
                    if all_others.min >= 0 and all_others.max < mul_gcd:
         
     | 
| 
      
 174 
     | 
    
         
            +
                      lhs, b = Node.sum([mul//mul_gcd for mul in muls]), b//mul_gcd
         
     | 
| 
      
 175 
     | 
    
         
            +
                return create_node(LtNode(lhs, b)) if isinstance(lhs, SumNode) else create_lt_node(lhs, b)
         
     | 
| 
      
 176 
     | 
    
         
            +
              if isinstance(lhs, MulNode):
         
     | 
| 
      
 177 
     | 
    
         
            +
                if isinstance(b, Node) or isinstance(lhs.b, Node) or lhs.b == -1: return create_node(LtNode(lhs, b))
         
     | 
| 
      
 178 
     | 
    
         
            +
                sgn = 1 if lhs.b > 0 else -1
         
     | 
| 
      
 179 
     | 
    
         
            +
                return create_node(LtNode(lhs.a*sgn, (b + abs(lhs.b) - 1)//abs(lhs.b)))
         
     | 
| 
      
 180 
     | 
    
         
            +
              return create_node(LtNode(lhs, b))
         
     | 
| 
      
 181 
     | 
    
         
            +
             
     | 
| 
      
 182 
     | 
    
         
            +
            def create_ge_node(lhs:Node, b:Union[Node, int]): return create_lt_node(-lhs, -b+1)
         
     | 
| 
      
 183 
     | 
    
         
            +
             
     | 
| 
       167 
184 
     | 
    
         
             
            class OpNode(Node):
         
     | 
| 
       168 
185 
     | 
    
         
             
              def __init__(self, a:Node, b:Union[Node, int]):
         
     | 
| 
       169 
186 
     | 
    
         
             
                self.a, self.b = a, b
         
     | 
| 
       170 
187 
     | 
    
         
             
                self.min, self.max = self.get_bounds()
         
     | 
| 
       171 
     | 
    
         
            -
              def vars(self): return self.a.vars()  
     | 
| 
       172 
     | 
    
         
            -
               
     | 
| 
       173 
     | 
    
         
            -
              def get_bounds(self) -> Tuple[int, int]: pass
         
     | 
| 
      
 188 
     | 
    
         
            +
              def vars(self): return self.a.vars() | (self.b.vars() if isinstance(self.b, Node) else set())
         
     | 
| 
      
 189 
     | 
    
         
            +
              def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
         
     | 
| 
       174 
190 
     | 
    
         | 
| 
       175 
191 
     | 
    
         
             
            class LtNode(OpNode):
         
     | 
| 
       176 
     | 
    
         
            -
              def __mul__(self, b: Union[Node, int]): return (self.a*b) < (self.b*b)
         
     | 
| 
       177 
     | 
    
         
            -
              def __floordiv__(self, b: Union[Node, int], _=False): return (self.a//b) < (self.b//b)
         
     | 
| 
       178 
192 
     | 
    
         
             
              def get_bounds(self) -> Tuple[int, int]:
         
     | 
| 
       179 
     | 
    
         
            -
                if  
     | 
| 
       180 
     | 
    
         
            -
                return (1, 1) if self.a.max < self.b 
     | 
| 
      
 193 
     | 
    
         
            +
                if self.a == self.b: return (0, 0)
         
     | 
| 
      
 194 
     | 
    
         
            +
                if isinstance(self.b, int): return (1, 1) if self.a.max < self.b else (0, 0) if self.a.min >= self.b else (0, 1)
         
     | 
| 
      
 195 
     | 
    
         
            +
                return (1, 1) if self.a.max < self.b.min else (0, 0) if self.a.min >= self.b.max else (0, 1)
         
     | 
| 
      
 196 
     | 
    
         
            +
              def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
         
     | 
| 
      
 197 
     | 
    
         
            +
                return create_lt_node(self.a.substitute(var_vals), (self.b if isinstance(self.b, int) else self.b.substitute(var_vals)))
         
     | 
| 
       181 
198 
     | 
    
         | 
| 
       182 
199 
     | 
    
         
             
            class MulNode(OpNode):
         
     | 
| 
       183 
200 
     | 
    
         
             
              def __mul__(self, b: Union[Node, int]): return self.a*(self.b*b) # two muls in one mul
         
     | 
| 
         @@ -185,57 +202,72 @@ class MulNode(OpNode): 
     | 
|
| 
       185 
202 
     | 
    
         
             
                if self.b % b == 0: return self.a*(self.b//b)
         
     | 
| 
       186 
203 
     | 
    
         
             
                if b % self.b == 0 and self.b > 0: return self.a//(b//self.b)
         
     | 
| 
       187 
204 
     | 
    
         
             
                return Node.__floordiv__(self, b, factoring_allowed)
         
     | 
| 
       188 
     | 
    
         
            -
              def __mod__(self, b: Union[Node, int]):
         
     | 
| 
       189 
     | 
    
         
            -
             
     | 
| 
       190 
     | 
    
         
            -
                 
     | 
| 
       191 
     | 
    
         
            -
             
     | 
| 
       192 
     | 
    
         
            -
                return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
         
     | 
| 
      
 205 
     | 
    
         
            +
              def __mod__(self, b: Union[Node, int]): return Node.__mod__(self.a * (self.b%b), b)
         
     | 
| 
      
 206 
     | 
    
         
            +
              def get_bounds(self) -> Tuple[int, sint]:
         
     | 
| 
      
 207 
     | 
    
         
            +
                assert self.a.min >= 0
         
     | 
| 
      
 208 
     | 
    
         
            +
                if isinstance(self.b, int): return (self.a.min*self.b, self.a.max*self.b) if self.b >= 0 else (self.a.max*self.b, self.a.min*self.b)
         
     | 
| 
      
 209 
     | 
    
         
            +
                return (self.a.min*self.b.min, self.a.max*self.b.max) if self.b.min >= 0 else (self.a.max*self.b.min, self.a.min*self.b.max)
         
     | 
| 
      
 210 
     | 
    
         
            +
              def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
         
     | 
| 
      
 211 
     | 
    
         
            +
                return self.a.substitute(var_vals) * (self.b if isinstance(self.b, int) else self.b.substitute(var_vals))
         
     | 
| 
       193 
212 
     | 
    
         | 
| 
       194 
213 
     | 
    
         
             
            class DivNode(OpNode):
         
     | 
| 
       195 
214 
     | 
    
         
             
              def __floordiv__(self, b: Union[Node, int], _=False): return self.a//(self.b*b) # two divs is one div
         
     | 
| 
       196 
     | 
    
         
            -
              def get_bounds(self) -> Tuple[int,  
     | 
| 
      
 215 
     | 
    
         
            +
              def get_bounds(self) -> Tuple[int, sint]:
         
     | 
| 
       197 
216 
     | 
    
         
             
                assert self.a.min >= 0 and isinstance(self.b, int)
         
     | 
| 
       198 
217 
     | 
    
         
             
                return self.a.min//self.b, self.a.max//self.b
         
     | 
| 
      
 218 
     | 
    
         
            +
              def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) // self.b
         
     | 
| 
       199 
219 
     | 
    
         | 
| 
       200 
220 
     | 
    
         
             
            class ModNode(OpNode):
         
     | 
| 
      
 221 
     | 
    
         
            +
              def __mod__(self, b: Union[Node, int]):
         
     | 
| 
      
 222 
     | 
    
         
            +
                if isinstance(b, int) and isinstance(self.b, int) and self.b % b == 0: return self.a % b
         
     | 
| 
      
 223 
     | 
    
         
            +
                return Node.__mod__(self, b)
         
     | 
| 
       201 
224 
     | 
    
         
             
              def __floordiv__(self, b: Union[Node, int], factoring_allowed=True):
         
     | 
| 
       202 
     | 
    
         
            -
                 
     | 
| 
       203 
     | 
    
         
            -
             
     | 
| 
       204 
     | 
    
         
            -
              def get_bounds(self) -> Tuple[int, int]:
         
     | 
| 
      
 225 
     | 
    
         
            +
                return (self.a//b) % (self.b//b) if self.b % b == 0 else Node.__floordiv__(self, b, factoring_allowed)
         
     | 
| 
      
 226 
     | 
    
         
            +
              def get_bounds(self) -> Tuple[int, sint]:
         
     | 
| 
       205 
227 
     | 
    
         
             
                assert self.a.min >= 0 and isinstance(self.b, int)
         
     | 
| 
       206 
     | 
    
         
            -
                 
     | 
| 
      
 228 
     | 
    
         
            +
                if self.a.max - self.a.min >= self.b or (self.a.min != self.a.max and self.a.min%self.b >= self.a.max%self.b): return (0, self.b-1)
         
     | 
| 
      
 229 
     | 
    
         
            +
                return (self.a.min%self.b, self.a.max%self.b)
         
     | 
| 
      
 230 
     | 
    
         
            +
              def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node: return self.a.substitute(var_vals) % self.b
         
     | 
| 
       207 
231 
     | 
    
         | 
| 
       208 
232 
     | 
    
         
             
            class RedNode(Node):
         
     | 
| 
       209 
     | 
    
         
            -
              def __init__(self, nodes:List[Node]): 
     | 
| 
       210 
     | 
    
         
            -
             
     | 
| 
      
 233 
     | 
    
         
            +
              def __init__(self, nodes:List[Node]):
         
     | 
| 
      
 234 
     | 
    
         
            +
                self.nodes = nodes
         
     | 
| 
      
 235 
     | 
    
         
            +
                self.min, self.max = self.get_bounds()
         
     | 
| 
      
 236 
     | 
    
         
            +
              def vars(self) -> Set[Variable]: return set.union(*[x.vars() for x in self.nodes], set())
         
     | 
| 
      
 237 
     | 
    
         
            +
              def get_bounds(self) -> Tuple[int, sint]: raise NotImplementedError("must be implemented")
         
     | 
| 
       211 
238 
     | 
    
         | 
| 
       212 
239 
     | 
    
         
             
            class SumNode(RedNode):
         
     | 
| 
      
 240 
     | 
    
         
            +
              def get_bounds(self) -> Tuple[int, sint]: return sum([x.min for x in self.nodes]), sum([x.max for x in self.nodes])
         
     | 
| 
      
 241 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
         
     | 
| 
       213 
242 
     | 
    
         
             
              def __mul__(self, b: Union[Node, int]): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
         
     | 
| 
       214 
     | 
    
         
            -
               
     | 
| 
      
 243 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
         
     | 
| 
      
 244 
     | 
    
         
            +
              def __floordiv__(self, b: Union[Node, sint], factoring_allowed=True):
         
     | 
| 
      
 245 
     | 
    
         
            +
                if self == b: return NumNode(1)
         
     | 
| 
       215 
246 
     | 
    
         
             
                fully_divided: List[Node] = []
         
     | 
| 
       216 
247 
     | 
    
         
             
                rest: List[Node] = []
         
     | 
| 
       217 
     | 
    
         
            -
                if isinstance(b, SumNode):
         
     | 
| 
       218 
     | 
    
         
            -
                  nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
         
     | 
| 
       219 
     | 
    
         
            -
                  de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
         
     | 
| 
       220 
     | 
    
         
            -
                  if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return NumNode(d) + (self-b*d) // b
         
     | 
| 
       221 
248 
     | 
    
         
             
                if isinstance(b, Node):
         
     | 
| 
       222 
249 
     | 
    
         
             
                  for x in self.flat_components:
         
     | 
| 
       223 
250 
     | 
    
         
             
                    if x % b == 0: fully_divided.append(x // b)
         
     | 
| 
       224 
251 
     | 
    
         
             
                    else: rest.append(x)
         
     | 
| 
       225 
     | 
    
         
            -
                  if (sum_fully_divided:= 
     | 
| 
      
 252 
     | 
    
         
            +
                  if (sum_fully_divided:=create_node(SumNode(fully_divided))) != 0: return sum_fully_divided + create_node(SumNode(rest)) // b
         
     | 
| 
       226 
253 
     | 
    
         
             
                  return Node.__floordiv__(self, b, False)
         
     | 
| 
       227 
254 
     | 
    
         
             
                if b == 1: return self
         
     | 
| 
       228 
255 
     | 
    
         
             
                if not factoring_allowed: return Node.__floordiv__(self, b, factoring_allowed)
         
     | 
| 
       229 
     | 
    
         
            -
                fully_divided, rest = [], []
         
     | 
| 
       230 
256 
     | 
    
         
             
                _gcd = b
         
     | 
| 
       231 
257 
     | 
    
         
             
                divisor = 1
         
     | 
| 
       232 
258 
     | 
    
         
             
                for x in self.flat_components:
         
     | 
| 
       233 
259 
     | 
    
         
             
                  if x.__class__ in (NumNode, MulNode):
         
     | 
| 
       234 
     | 
    
         
            -
                    if x.b%b == 0: fully_divided.append(x//b)
         
     | 
| 
      
 260 
     | 
    
         
            +
                    if x.b % b == 0: fully_divided.append(x // b)
         
     | 
| 
       235 
261 
     | 
    
         
             
                    else:
         
     | 
| 
      
 262 
     | 
    
         
            +
                      if x.__class__ is NumNode and (div := x.b // b):
         
     | 
| 
      
 263 
     | 
    
         
            +
                        fully_divided.append(NumNode(div))
         
     | 
| 
      
 264 
     | 
    
         
            +
                        x = NumNode(x.b - b * div)
         
     | 
| 
       236 
265 
     | 
    
         
             
                      rest.append(x)
         
     | 
| 
       237 
     | 
    
         
            -
                       
     | 
| 
       238 
     | 
    
         
            -
             
     | 
| 
      
 266 
     | 
    
         
            +
                      if isinstance(x.b, int):
         
     | 
| 
      
 267 
     | 
    
         
            +
                        _gcd = gcd(_gcd, x.b)
         
     | 
| 
      
 268 
     | 
    
         
            +
                        if x.__class__ == MulNode and divisor == 1 and b % x.b == 0: divisor = x.b
         
     | 
| 
      
 269 
     | 
    
         
            +
                      else:
         
     | 
| 
      
 270 
     | 
    
         
            +
                        _gcd = 1
         
     | 
| 
       239 
271 
     | 
    
         
             
                  else:
         
     | 
| 
       240 
272 
     | 
    
         
             
                    rest.append(x)
         
     | 
| 
       241 
273 
     | 
    
         
             
                    _gcd = 1
         
     | 
| 
         @@ -243,62 +275,55 @@ class SumNode(RedNode): 
     | 
|
| 
       243 
275 
     | 
    
         
             
                if divisor > 1: return Node.sum(fully_divided) + Node.sum(rest).__floordiv__(divisor) // (b//divisor)
         
     | 
| 
       244 
276 
     | 
    
         
             
                return Node.sum(fully_divided) + Node.__floordiv__(Node.sum(rest), b)
         
     | 
| 
       245 
277 
     | 
    
         | 
| 
      
 278 
     | 
    
         
            +
              @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
         
     | 
| 
       246 
279 
     | 
    
         
             
              def __mod__(self, b: Union[Node, int]):
         
     | 
| 
       247 
     | 
    
         
            -
                if  
     | 
| 
       248 
     | 
    
         
            -
                  nu_num = sum(node.b for node in self.flat_components if node.__class__ is NumNode)
         
     | 
| 
       249 
     | 
    
         
            -
                  de_num = sum(node.b for node in b.flat_components if node.__class__ is NumNode)
         
     | 
| 
       250 
     | 
    
         
            -
                  if nu_num > 0 and de_num and (d:=nu_num//de_num) > 0: return (self-b*d) % b
         
     | 
| 
      
 280 
     | 
    
         
            +
                if self == b: return NumNode(0)
         
     | 
| 
       251 
281 
     | 
    
         
             
                if isinstance(b, Node) and (b - self).min > 0: return self # b - self simplifies the node
         
     | 
| 
       252 
     | 
    
         
            -
                 
     | 
| 
       253 
     | 
    
         
            -
                 
     | 
| 
       254 
     | 
    
         
            -
             
     | 
| 
       255 
     | 
    
         
            -
             
     | 
| 
       256 
     | 
    
         
            -
             
     | 
| 
       257 
     | 
    
         
            -
                return Node.__mod__(Node.sum(new_nodes), b)
         
     | 
| 
       258 
     | 
    
         
            -
             
     | 
| 
       259 
     | 
    
         
            -
              def __lt__(self, b:Union[Node,int]):
         
     | 
| 
       260 
     | 
    
         
            -
                if isinstance(b, int):
         
     | 
| 
       261 
     | 
    
         
            -
                  new_sum = []
         
     | 
| 
       262 
     | 
    
         
            -
                  for x in self.nodes:
         
     | 
| 
       263 
     | 
    
         
            -
                    # TODO: should we just force the last one to always be the number
         
     | 
| 
       264 
     | 
    
         
            -
                    if isinstance(x, NumNode): b -= x.b
         
     | 
| 
       265 
     | 
    
         
            -
                    else: new_sum.append(x)
         
     | 
| 
       266 
     | 
    
         
            -
                  return Node.__lt__(Node.sum(new_sum), b)
         
     | 
| 
       267 
     | 
    
         
            -
                return Node.__lt__(self, b)
         
     | 
| 
      
 282 
     | 
    
         
            +
                new_sum = Node.sum([node%b if node.__class__ in (NumNode, MulNode) else node for node in self.nodes])
         
     | 
| 
      
 283 
     | 
    
         
            +
                return Node.__mod__(new_sum, b)
         
     | 
| 
      
 284 
     | 
    
         
            +
             
     | 
| 
      
 285 
     | 
    
         
            +
              def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
         
     | 
| 
      
 286 
     | 
    
         
            +
                return Node.sum([node.substitute(var_vals) for node in self.nodes])
         
     | 
| 
       268 
287 
     | 
    
         | 
| 
      
 288 
     | 
    
         
            +
              # recursively expand sumnode components
         
     | 
| 
      
 289 
     | 
    
         
            +
              # TODO: can remove this if there's no SumNode inside SumNode
         
     | 
| 
       269 
290 
     | 
    
         
             
              @property
         
     | 
| 
       270 
     | 
    
         
            -
              def flat_components(self):  
     | 
| 
       271 
     | 
    
         
            -
                new_nodes = []
         
     | 
| 
       272 
     | 
    
         
            -
                for x in self.nodes: new_nodes += (x.flat_components if isinstance(x, SumNode) else [x])
         
     | 
| 
       273 
     | 
    
         
            -
                return new_nodes
         
     | 
| 
      
 291 
     | 
    
         
            +
              def flat_components(self): return [y for x in self.nodes for y in (x.flat_components if isinstance(x, SumNode) else [x])]
         
     | 
| 
       274 
292 
     | 
    
         | 
| 
       275 
293 
     | 
    
         
             
            class AndNode(RedNode):
         
     | 
| 
       276 
     | 
    
         
            -
              def  
     | 
| 
       277 
     | 
    
         
            -
              def  
     | 
| 
       278 
     | 
    
         
            -
             
     | 
| 
       279 
     | 
    
         
            -
             
     | 
| 
       280 
     | 
    
         
            -
             
     | 
| 
       281 
     | 
    
         
            -
             
     | 
| 
       282 
     | 
    
         
            -
             
     | 
| 
       283 
     | 
    
         
            -
              return create_node(ret)
         
     | 
| 
       284 
     | 
    
         
            -
             
     | 
| 
       285 
     | 
    
         
            -
            def sym_infer(n:Union[Node,int], var_vals: Dict[Variable, int]) -> int:
         
     | 
| 
       286 
     | 
    
         
            -
              if isinstance(n, (int, NumNode)): return int(n)
         
     | 
| 
       287 
     | 
    
         
            -
              if isinstance(n, Variable): return var_vals[n]
         
     | 
| 
       288 
     | 
    
         
            -
              if isinstance(n, MulNode): return sym_infer(n.a, var_vals) * sym_infer(n.b, var_vals)
         
     | 
| 
       289 
     | 
    
         
            -
              if isinstance(n, SumNode): return sum(sym_infer(s, var_vals) for s in n.nodes)
         
     | 
| 
       290 
     | 
    
         
            -
              raise NotImplementedError(n)
         
     | 
| 
       291 
     | 
    
         
            -
            @functools.lru_cache(maxsize=None)
         
     | 
| 
       292 
     | 
    
         
            -
            def sym_rename(s) -> str: return f"s{sym_rename.cache_info().currsize}"
         
     | 
| 
       293 
     | 
    
         
            -
            def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
         
     | 
| 
      
 294 
     | 
    
         
            +
              def get_bounds(self) -> Tuple[int, sint]: return min([x.min for x in self.nodes]), max([x.max for x in self.nodes])
         
     | 
| 
      
 295 
     | 
    
         
            +
              def substitute(self, var_vals: Mapping[Variable, Union[NumNode, Variable]]) -> Node:
         
     | 
| 
      
 296 
     | 
    
         
            +
                subed = []
         
     | 
| 
      
 297 
     | 
    
         
            +
                for node in self.nodes:
         
     | 
| 
      
 298 
     | 
    
         
            +
                  if not (sub:=node.substitute(var_vals)): return NumNode(0)
         
     | 
| 
      
 299 
     | 
    
         
            +
                  subed.append(sub)
         
     | 
| 
      
 300 
     | 
    
         
            +
                return Node.ands(subed)
         
     | 
| 
       294 
301 
     | 
    
         | 
| 
       295 
     | 
    
         
            -
             
     | 
| 
       296 
     | 
    
         
            -
             
     | 
| 
       297 
     | 
    
         
            -
               
     | 
| 
       298 
     | 
    
         
            -
               
     | 
| 
      
 302 
     | 
    
         
            +
            def sym_render(a: Union[Node, int], ops=None, ctx=None) -> str: return str(a) if isinstance(a, int) else a.render(ops, ctx)
         
     | 
| 
      
 303 
     | 
    
         
            +
            def sym_infer(a: Union[Node, int], var_vals: Optional[Dict[Variable, int]]) -> int:
         
     | 
| 
      
 304 
     | 
    
         
            +
              if isinstance(a, (int, float)): return a
         
     | 
| 
      
 305 
     | 
    
         
            +
              ret = a.substitute({k:NumNode(v) for k, v in var_vals.items()}) if var_vals is not None else a
         
     | 
| 
      
 306 
     | 
    
         
            +
              assert isinstance(ret, NumNode), f"sym_infer didn't produce NumNode from {a} with {var_vals}"
         
     | 
| 
      
 307 
     | 
    
         
            +
              return ret.b
         
     | 
| 
      
 308 
     | 
    
         
            +
             
     | 
| 
      
 309 
     | 
    
         
            +
            # symbolic int, these are allowed in a Tensor shape
         
     | 
| 
      
 310 
     | 
    
         
            +
            sint = Union[int, Variable, MulNode, SumNode]
         
     | 
| 
      
 311 
     | 
    
         
            +
             
     | 
| 
      
 312 
     | 
    
         
            +
            def render_mulnode(node:MulNode, ops, ctx):
         
     | 
| 
      
 313 
     | 
    
         
            +
              # TODO: add ProdNode and remove this case
         
     | 
| 
      
 314 
     | 
    
         
            +
              if isinstance(node.a,Variable) and isinstance(node.b,Variable) and node.a.expr and node.b.expr and node.b.expr < node.a.expr:
         
     | 
| 
      
 315 
     | 
    
         
            +
                return f"({sym_render(node.b,ops,ctx)}*{node.a.render(ops,ctx)})"
         
     | 
| 
      
 316 
     | 
    
         
            +
              return f"({node.a.render(ops,ctx)}*{sym_render(node.b,ops,ctx)})"
         
     | 
| 
      
 317 
     | 
    
         
            +
             
     | 
| 
      
 318 
     | 
    
         
            +
            render_python: Dict[Type, Callable[..., str]] = {
         
     | 
| 
      
 319 
     | 
    
         
            +
              Variable: lambda self,ops,ctx: f"{self.expr}[{self.min}-{self.max}{'='+str(self.val) if self._val is not None else ''}]" if ctx == "DEBUG" \
         
     | 
| 
      
 320 
     | 
    
         
            +
                else (f"Variable('{self.expr}', {self.min}, {self.max})"+(f".bind({self.val})" if self._val is not None else '') if ctx == "REPR" \
         
     | 
| 
      
 321 
     | 
    
         
            +
                else f"{self.expr}"),
         
     | 
| 
      
 322 
     | 
    
         
            +
              NumNode: lambda self,ops,ctx: f"NumNode({self.b})" if ctx == "REPR" else f"{self.b}",
         
     | 
| 
      
 323 
     | 
    
         
            +
              MulNode: render_mulnode,
         
     | 
| 
       299 
324 
     | 
    
         
             
              DivNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}//{self.b})",
         
     | 
| 
       300 
325 
     | 
    
         
             
              ModNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}%{self.b})",
         
     | 
| 
       301 
326 
     | 
    
         
             
              LtNode: lambda self,ops,ctx: f"({self.a.render(ops,ctx)}<{sym_render(self.b,ops,ctx)})",
         
     | 
| 
       302 
327 
     | 
    
         
             
              SumNode: lambda self,ops,ctx: f"({'+'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
         
     | 
| 
       303 
     | 
    
         
            -
              AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
         
     | 
| 
      
 328 
     | 
    
         
            +
              AndNode: lambda self,ops,ctx: f"({' and '.join(sorted([x.render(ops,ctx) for x in self.nodes]))})",
         
     | 
| 
       304 
329 
     | 
    
         
             
            }
         
     |