pycograd 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pycograd/__init__.py +390 -0
- pycograd/_constraints.py +148 -0
- pycograd/_dims.py +453 -0
- pycograd/_typing.py +97 -0
- pycograd/_version.py +669 -0
- pycograd/ad_graph.py +372 -0
- pycograd/backends/__init__.py +302 -0
- pycograd/backends/abstract_backend.py +73 -0
- pycograd/backends/cupy_backend.py +47 -0
- pycograd/backends/jax_backend.py +207 -0
- pycograd/backends/mps_backend.py +47 -0
- pycograd/backends/numpy_backend.py +66 -0
- pycograd/backends/tf_backend.py +407 -0
- pycograd/backends/torch_backend.py +482 -0
- pycograd/batching.py +638 -0
- pycograd/capture.py +527 -0
- pycograd/checkpoint.py +420 -0
- pycograd/compile.py +199 -0
- pycograd/cost.py +548 -0
- pycograd/data.py +115 -0
- pycograd/dtypes.py +152 -0
- pycograd/examples/__init__.py +12 -0
- pycograd/examples/__main__.py +242 -0
- pycograd/examples/models.py +953 -0
- pycograd/export.py +121 -0
- pycograd/extension.py +137 -0
- pycograd/forward.py +683 -0
- pycograd/functional.py +808 -0
- pycograd/ops.py +1575 -0
- pycograd/optimizers.py +284 -0
- pycograd/params.py +882 -0
- pycograd/passes.py +580 -0
- pycograd/random.py +92 -0
- pycograd/remat.py +779 -0
- pycograd/shapes.py +1174 -0
- pycograd/tensor.py +650 -0
- pycograd/trace.py +420 -0
- pycograd/tracer.py +531 -0
- pycograd/training.py +136 -0
- pycograd/transforms.py +1078 -0
- pycograd/transpose.py +167 -0
- pycograd/tree.py +109 -0
- pycograd/version.py +18 -0
- pycograd-0.0.1.dist-info/METADATA +324 -0
- pycograd-0.0.1.dist-info/RECORD +48 -0
- pycograd-0.0.1.dist-info/WHEEL +5 -0
- pycograd-0.0.1.dist-info/licenses/docs/LICENSE.txt +11 -0
- pycograd-0.0.1.dist-info/top_level.txt +1 -0
pycograd/__init__.py
ADDED
|
@@ -0,0 +1,390 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""pycograd: a small, readable reverse-mode autograd built on numpy and pyccolo.
|
|
3
|
+
|
|
4
|
+
Write ordinary numeric Python -- including ``numpy`` calls like ``np.exp``,
|
|
5
|
+
``np.dot``, ``np.sum`` and operators like ``@`` -- and get correct gradients.
|
|
6
|
+
``Var`` is the reverse-mode tape node; ``value_and_grad`` / ``grad`` wrap a
|
|
7
|
+
function to return gradients with the same pytree structure as its arguments.
|
|
8
|
+
"""
|
|
9
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
10
|
+
|
|
11
|
+
from pycograd import random
|
|
12
|
+
from pycograd._typing import Operand, Tensor
|
|
13
|
+
from pycograd.ad_graph import grad_graph, jit
|
|
14
|
+
from pycograd.backends import activate, device, get_backend
|
|
15
|
+
from pycograd.capture import Graph, capture, eval_graph
|
|
16
|
+
from pycograd.checkpoint import checkpoint
|
|
17
|
+
from pycograd.compile import compile_to
|
|
18
|
+
from pycograd.cost import (
|
|
19
|
+
DEFAULT_COST_MODEL,
|
|
20
|
+
CostModel,
|
|
21
|
+
GraphCost,
|
|
22
|
+
NodeCost,
|
|
23
|
+
calibrate,
|
|
24
|
+
cost_report,
|
|
25
|
+
)
|
|
26
|
+
from pycograd.data import DataLoader, batches
|
|
27
|
+
from pycograd.dtypes import current_dtype, dtype, resolve_dtype
|
|
28
|
+
from pycograd.export import export_onnx, export_torchscript, to_torch_module
|
|
29
|
+
from pycograd.extension import load_ipython_extension, unload_ipython_extension
|
|
30
|
+
from pycograd.functional import (
|
|
31
|
+
avg_pool2d,
|
|
32
|
+
batch_norm,
|
|
33
|
+
batch_norm_init,
|
|
34
|
+
causal_conv1d,
|
|
35
|
+
conv1d,
|
|
36
|
+
conv2d,
|
|
37
|
+
conv_transpose1d,
|
|
38
|
+
conv_transpose2d,
|
|
39
|
+
cross_entropy,
|
|
40
|
+
dropout,
|
|
41
|
+
elu,
|
|
42
|
+
embedding,
|
|
43
|
+
gelu,
|
|
44
|
+
group_norm,
|
|
45
|
+
hardsigmoid,
|
|
46
|
+
hardswish,
|
|
47
|
+
instance_norm,
|
|
48
|
+
layer_norm,
|
|
49
|
+
leaky_relu,
|
|
50
|
+
linear,
|
|
51
|
+
log_softmax,
|
|
52
|
+
logsumexp,
|
|
53
|
+
max_pool2d,
|
|
54
|
+
mish,
|
|
55
|
+
multi_head_attention,
|
|
56
|
+
one_hot,
|
|
57
|
+
relu,
|
|
58
|
+
rms_norm,
|
|
59
|
+
scaled_dot_product_attention,
|
|
60
|
+
selu,
|
|
61
|
+
sigmoid,
|
|
62
|
+
silu,
|
|
63
|
+
softmax,
|
|
64
|
+
softplus,
|
|
65
|
+
softsign,
|
|
66
|
+
streaming_conv1d,
|
|
67
|
+
streaming_conv1d_init,
|
|
68
|
+
streaming_conv2d,
|
|
69
|
+
streaming_conv2d_init,
|
|
70
|
+
streaming_conv_transpose1d,
|
|
71
|
+
streaming_conv_transpose1d_init,
|
|
72
|
+
streaming_conv_transpose2d,
|
|
73
|
+
streaming_conv_transpose2d_init,
|
|
74
|
+
swish,
|
|
75
|
+
tanh,
|
|
76
|
+
upsample_nearest2d,
|
|
77
|
+
)
|
|
78
|
+
from pycograd.ops import (
|
|
79
|
+
AutodiffWarning,
|
|
80
|
+
d_abs,
|
|
81
|
+
d_arctan,
|
|
82
|
+
d_clip,
|
|
83
|
+
d_column_stack,
|
|
84
|
+
d_concatenate,
|
|
85
|
+
d_cos,
|
|
86
|
+
d_cosh,
|
|
87
|
+
d_cumsum,
|
|
88
|
+
d_dstack,
|
|
89
|
+
d_einsum,
|
|
90
|
+
d_exp,
|
|
91
|
+
d_expand_dims,
|
|
92
|
+
d_expm1,
|
|
93
|
+
d_gated_act,
|
|
94
|
+
d_hstack,
|
|
95
|
+
d_log,
|
|
96
|
+
d_log1p,
|
|
97
|
+
d_logsumexp,
|
|
98
|
+
d_max,
|
|
99
|
+
d_maximum,
|
|
100
|
+
d_mean,
|
|
101
|
+
d_min,
|
|
102
|
+
d_minimum,
|
|
103
|
+
d_reciprocal,
|
|
104
|
+
d_reshape,
|
|
105
|
+
d_sigmoid,
|
|
106
|
+
d_sin,
|
|
107
|
+
d_sinh,
|
|
108
|
+
d_softmax,
|
|
109
|
+
d_sqrt,
|
|
110
|
+
d_square,
|
|
111
|
+
d_stack,
|
|
112
|
+
d_std,
|
|
113
|
+
d_sum,
|
|
114
|
+
d_tanh,
|
|
115
|
+
d_transpose,
|
|
116
|
+
d_var,
|
|
117
|
+
d_vstack,
|
|
118
|
+
d_where,
|
|
119
|
+
)
|
|
120
|
+
from pycograd.optimizers import (
|
|
121
|
+
SGD,
|
|
122
|
+
Adam,
|
|
123
|
+
AdamW,
|
|
124
|
+
Optimizer,
|
|
125
|
+
clip_grad_norm,
|
|
126
|
+
constant_lr,
|
|
127
|
+
cosine_decay,
|
|
128
|
+
step_decay,
|
|
129
|
+
)
|
|
130
|
+
from pycograd.params import (
|
|
131
|
+
Param,
|
|
132
|
+
ParamDict,
|
|
133
|
+
Weight,
|
|
134
|
+
buffer,
|
|
135
|
+
frozen,
|
|
136
|
+
on_cpu,
|
|
137
|
+
on_device,
|
|
138
|
+
param_values,
|
|
139
|
+
params,
|
|
140
|
+
register_pipescript_params_macro,
|
|
141
|
+
tied,
|
|
142
|
+
)
|
|
143
|
+
from pycograd.passes import optimize
|
|
144
|
+
from pycograd.remat import (
|
|
145
|
+
Decision,
|
|
146
|
+
RematPlan,
|
|
147
|
+
apply_remat_plan,
|
|
148
|
+
eval_scheduled,
|
|
149
|
+
plan_remat,
|
|
150
|
+
)
|
|
151
|
+
from pycograd.shapes import (
|
|
152
|
+
Dim,
|
|
153
|
+
ShapedArray,
|
|
154
|
+
ShapeDtypeStruct,
|
|
155
|
+
ShapeError,
|
|
156
|
+
Summary,
|
|
157
|
+
bind,
|
|
158
|
+
eval_shape,
|
|
159
|
+
infer_shapes,
|
|
160
|
+
substitute,
|
|
161
|
+
summary,
|
|
162
|
+
)
|
|
163
|
+
from pycograd.tensor import Var, detach
|
|
164
|
+
from pycograd.tracer import AutodiffTracer, resolve_call
|
|
165
|
+
from pycograd.training import accuracy, fit, train
|
|
166
|
+
from pycograd.transforms import (
|
|
167
|
+
grad,
|
|
168
|
+
gradient_descent,
|
|
169
|
+
jacfwd,
|
|
170
|
+
jacrev,
|
|
171
|
+
jvp,
|
|
172
|
+
value_and_grad,
|
|
173
|
+
vmap,
|
|
174
|
+
)
|
|
175
|
+
from pycograd.tree import (
|
|
176
|
+
sgd_update,
|
|
177
|
+
tree_flatten,
|
|
178
|
+
tree_leaves,
|
|
179
|
+
tree_map,
|
|
180
|
+
tree_structure,
|
|
181
|
+
tree_unflatten,
|
|
182
|
+
)
|
|
183
|
+
|
|
184
|
+
# Short call-site spellings for the fused primitives. They are also reached through
# ``np.einsum`` / ``np.cumsum`` interception, but read better than the ``d_`` names.
# ``gated_act`` is tanh(f) * sigmoid(s), the WaveNet / GLU gate.
einsum, cumsum, gated_act = d_einsum, d_cumsum, d_gated_act
|
|
189
|
+
|
|
190
|
+
def _installed_version() -> str:
    """Return the installed ``pycograd`` distribution version, or a placeholder."""
    try:
        return version("pycograd")
    except PackageNotFoundError:  # not installed (e.g. running from a source checkout)
        return "0.0.0+unknown"


__version__ = _installed_version()
del _installed_version  # keep the module namespace identical to a plain assignment
|
|
194
|
+
|
|
195
|
+
# Public API, assembled category by category; the concatenation yields one flat list.
__all__ = (
    ["__version__"]
    # core
    + ["Var", "detach", "Tensor", "Operand"]
    # parameters
    + [
        "Param",
        "ParamDict",
        "Weight",
        "frozen",
        "buffer",
        "tied",
        "on_cpu",
        "on_device",
        "params",
        "param_values",
    ]
    # transforms / training
    + [
        "value_and_grad",
        "grad",
        "checkpoint",
        "vmap",
        "jvp",
        "jacfwd",
        "jacrev",
        "gradient_descent",
        "sgd_update",
        "train",
        "fit",
        "accuracy",
    ]
    # shape inference
    + [
        "eval_shape",
        "infer_shapes",
        "substitute",
        "bind",
        "summary",
        "Summary",
        "ShapeDtypeStruct",
        "ShapedArray",
        "ShapeError",
        "Dim",
    ]
    # compile to other frameworks (torch / tf / jax)
    + ["compile_to", "get_backend"]
    # graph-capture IR + optimization passes
    + ["capture", "eval_graph", "optimize", "grad_graph", "jit", "Graph"]
    # static cost model over the capture IR (CPU / memory / disk)
    + [
        "cost_report",
        "CostModel",
        "GraphCost",
        "NodeCost",
        "DEFAULT_COST_MODEL",
        "calibrate",
    ]
    # rematerialization / spill planning + memory-managed execution
    + ["plan_remat", "RematPlan", "Decision", "apply_remat_plan", "eval_scheduled"]
    # device / array backend seam (numpy default, cupy for GPU)
    + ["device", "activate"]
    # working-dtype seam (float64 default; float32 / float16 / bfloat16)
    + ["dtype", "current_dtype", "resolve_dtype"]
    # static export (standalone artifacts)
    + ["to_torch_module", "export_torchscript", "export_onnx"]
    # optimizers
    + [
        "Optimizer",
        "SGD",
        "Adam",
        "AdamW",
        "clip_grad_norm",
        "constant_lr",
        "step_decay",
        "cosine_decay",
    ]
    # neural-net ops (stable softmax family, cross-entropy, activations)
    + [
        "softmax",
        "log_softmax",
        "logsumexp",
        "cross_entropy",
        "relu",
        "sigmoid",
        "silu",
        "swish",
        "gelu",
        "tanh",
        "leaky_relu",
        "elu",
        "softplus",
        "mish",
        "hardswish",
        "hardsigmoid",
        "softsign",
        "selu",
        "conv1d",
        "conv2d",
        "causal_conv1d",
        "conv_transpose1d",
        "conv_transpose2d",
        "streaming_conv1d",
        "streaming_conv1d_init",
        "streaming_conv2d",
        "streaming_conv2d_init",
        "streaming_conv_transpose1d",
        "streaming_conv_transpose1d_init",
        "streaming_conv_transpose2d",
        "streaming_conv_transpose2d_init",
        "upsample_nearest2d",
        "max_pool2d",
        "avg_pool2d",
        "one_hot",
    ]
    # neural-net layers (normalization, attention, embedding, linear, dropout)
    + [
        "layer_norm",
        "rms_norm",
        "batch_norm",
        "batch_norm_init",
        "group_norm",
        "instance_norm",
        "scaled_dot_product_attention",
        "multi_head_attention",
        "embedding",
        "linear",
        "dropout",
    ]
    # splittable PRNG keys (pycograd.random: key / split / fold_in + samplers)
    + ["random"]
    # data / batching
    + ["batches", "DataLoader"]
    # pytrees
    + [
        "tree_flatten",
        "tree_unflatten",
        "tree_leaves",
        "tree_structure",
        "tree_map",
    ]
    # tracer / interception
    + [
        "AutodiffTracer",
        "resolve_call",
        "AutodiffWarning",
        "register_pipescript_params_macro",
    ]
    # ipython / jupyter extension
    + ["load_ipython_extension", "unload_ipython_extension"]
    # differentiable primitives
    + [
        "d_exp",
        "d_log",
        "d_sin",
        "d_cos",
        "d_tanh",
        "d_sqrt",
        "d_sigmoid",
        "d_abs",
        "d_square",
        "d_sinh",
        "d_cosh",
        "d_arctan",
        "d_log1p",
        "d_expm1",
        "d_reciprocal",
        "d_maximum",
        "d_minimum",
        "d_clip",
        "d_where",
        "d_sum",
        "d_mean",
        "d_var",
        "d_std",
        "d_max",
        "d_min",
        "d_softmax",
        "d_logsumexp",
        "d_concatenate",
        "d_transpose",
        "d_reshape",
        "d_expand_dims",
        "d_stack",
        "d_vstack",
        "d_hstack",
        "d_column_stack",
        "d_dstack",
        "d_einsum",
        "einsum",
        "d_cumsum",
        "cumsum",
        "d_gated_act",
        "gated_act",
    ]
)
|
pycograd/_constraints.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
"""Dimension-equality constraints for shape polymorphism.
|
|
3
|
+
|
|
4
|
+
When shape inference runs over *symbolic* input dims (e.g. a batch ``B`` declared via
|
|
5
|
+
``ShapeDtypeStruct(("B", 768))``), each contraction registers an equality: a matmul
|
|
6
|
+
asserts its inner dims equal, concatenate asserts its non-axis dims equal, broadcasting
|
|
7
|
+
asserts compatible dims equal. :class:`ConstraintEnv` is the union-find that records
|
|
8
|
+
those equalities, refines a symbol pinned to a concrete (``K`` forced to ``4``), and
|
|
9
|
+
reports a contradiction (two concretes forced equal) as a shape error.
|
|
10
|
+
|
|
11
|
+
Only *solvable* symbols -- caller-declared input dims, whose key is a ``str`` -- get
|
|
12
|
+
bound to concretes or merged. *Data-dependent* symbols (a mask count, a broadcast;
|
|
13
|
+
their key is a tuple) are runtime facts, not statically known, so they are left opaque,
|
|
14
|
+
preserving the optimistic "carry it forward" behavior of plain symbolic inference.
|
|
15
|
+
"""
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
import contextlib
|
|
19
|
+
import contextvars
|
|
20
|
+
from typing import Any, Hashable, Iterator, cast
|
|
21
|
+
|
|
22
|
+
from pycograd import _dims
|
|
23
|
+
from pycograd._dims import Dim
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _as_atom(x: int | Dim) -> tuple:
    """Classify a dim as ``("int", v)``, ``("sym", key, name)``, or ``("expr",)``.

    A ``Dim`` that reduces to a single symbol becomes ``("sym", key, name)``. Any
    other ``Dim`` becomes ``("expr",)``. A non-``Dim`` that ``_dims._is_int``
    accepts becomes ``("int", value)``. Anything else is an opaque ``("expr",)``.
    """
    if not isinstance(x, Dim):
        if _dims._is_int(x):
            return ("int", int(cast(Any, x)))
        return ("expr",)
    sym = x.as_symbol()
    if sym is None:
        return ("expr",)
    key, name = sym[0], sym[1]
    return ("sym", key, name)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class ConstraintEnv:
    """Union-find over symbol keys with at most one concrete value per class.

    Only *solvable* keys (``str`` keys, i.e. caller-declared input dims) ever enter
    the structure. Data-dependent symbols (tuple keys) are ignored by ``_bind`` and
    ``_union``, which report them as compatible.

    Invariant: ``value`` only holds entries for keys that are currently class roots.
    """

    def __init__(self) -> None:
        self.parent: dict[Hashable, Hashable] = {}  # key -> parent key (roots map to themselves)
        self.value: dict[Hashable, int] = {}  # root key -> concrete int
        self.name: dict[Hashable, str] = {}  # key -> rendered name

    def _add(self, key: Hashable, name: str) -> None:
        """Register ``key`` as a singleton class if unseen; the first name given wins."""
        if key not in self.parent:
            self.parent[key] = key
            self.name[key] = name

    def _find(self, key: Hashable) -> Hashable:
        """Return the root of ``key``'s class, compressing the walked path.

        ``key`` must already be registered (``_add``); otherwise ``KeyError``.
        """
        root = key
        while self.parent[root] != root:
            root = self.parent[root]
        while self.parent[key] != root:  # path compression
            # RHS is evaluated first, so ``key`` advances to its *old* parent.
            self.parent[key], key = root, self.parent[key]
        return root

    @staticmethod
    def _solvable(key: Hashable) -> bool:
        # Caller-declared input dims have string keys; data-dependent symbols (nonzero,
        # bcast, slice) have tuple keys and are never bound/merged.
        return isinstance(key, str)

    def assert_eq(self, a: int | Dim, b: int | Dim) -> bool:
        """Record ``a == b``; return ``False`` if that is a provable contradiction.

        int/int compares directly, int/sym pins the symbol's class, and sym/sym
        merges the two classes. Anything involving a compound expression is
        accepted optimistically.
        """
        ta, tb = _as_atom(a), _as_atom(b)
        if ta[0] == "int" and tb[0] == "int":
            return ta[1] == tb[1]
        if ta[0] == "int" and tb[0] == "sym":
            return self._bind(tb, ta[1])
        if ta[0] == "sym" and tb[0] == "int":
            return self._bind(ta, tb[1])
        if ta[0] == "sym" and tb[0] == "sym":
            return self._union(ta, tb)
        return True  # an expression is involved -- can't reason, stay optimistic

    def _bind(self, sym: tuple, val: int) -> bool:
        """Pin the class of ``sym`` (a ``("sym", key, name)`` atom) to ``val``.

        Returns ``False`` if the class is already pinned to a different value.
        """
        _, key, name = sym
        if not self._solvable(key):
            return True  # data-dependent: a runtime fact, never statically pinned
        self._add(key, name)
        root = self._find(key)
        cur = self.value.get(root)
        if cur is not None and cur != val:
            return False
        self.value[root] = val
        return True

    def _union(self, s1: tuple, s2: tuple) -> bool:
        """Merge the classes of two ``("sym", key, name)`` atoms.

        Returns ``False`` (leaving both classes untouched) if they are pinned to
        different concrete values. ``s1``'s root becomes the merged representative.
        """
        _, k1, n1 = s1
        _, k2, n2 = s2
        if not (self._solvable(k1) and self._solvable(k2)):
            return True  # leave data-dependent symbols opaque
        self._add(k1, n1)
        self._add(k2, n2)
        r1, r2 = self._find(k1), self._find(k2)
        if r1 == r2:
            return True
        v1, v2 = self.value.get(r1), self.value.get(r2)
        if v1 is not None and v2 is not None and v1 != v2:
            return False
        self.parent[r2] = r1
        # r2 is no longer a root: drop its entry so ``value`` stays keyed by roots
        # only, carrying its pin over to r1 if r1 had none.
        self.value.pop(r2, None)
        if v1 is None and v2 is not None:
            self.value[r1] = v2
        return True

    def mapping(self) -> dict[Hashable, int | Dim]:
        """A substitution mapping each known symbol key to its concrete value (if its
        class is pinned) or to its class representative symbol (if merged).

        Keys that are unpinned roots of their own class are omitted.
        """
        m: dict[Hashable, int | Dim] = {}
        for key in self.parent:
            root = self._find(key)
            v = self.value.get(root)
            if v is not None:
                m[key] = v
            elif root != key:
                m[key] = _dims.symbol(root, name=self.name[root])
        return m
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
# ---------------------------------------------------------------------------
# The constraint environment in effect for the current context. ``constraint_scope``
# installs one for the duration of an abstract inference run; ``None`` means none is.
# ---------------------------------------------------------------------------
_env: "contextvars.ContextVar[ConstraintEnv | None]"
_env = contextvars.ContextVar("dim_env", default=None)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@contextlib.contextmanager
def constraint_scope() -> Iterator[ConstraintEnv]:
    """Activate a fresh :class:`ConstraintEnv` for the body of a ``with`` block.

    Yields the new environment. On exit, normal or via an exception, the context
    variable is reset to whatever was active before entry.
    """
    fresh = ConstraintEnv()
    reset_token = _env.set(fresh)
    try:
        yield fresh
    finally:
        _env.reset(reset_token)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def active_env() -> "ConstraintEnv | None":
    """Return the environment set by the innermost active ``constraint_scope``, else ``None``."""
    current = _env.get()
    return current
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
def register_eq(a: int | Dim, b: int | Dim) -> bool:
    """Register ``a == b`` with the active env (if any).

    Returns ``False`` on a provable contradiction. Outside any ``constraint_scope``
    nothing is recorded, and the answer comes from ``_dims.provably_unequal``.
    """
    env = active_env()
    if env is not None:
        return env.assert_eq(a, b)
    return not _dims.provably_unequal(a, b)
|