pycograd 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48) hide show
  1. pycograd/__init__.py +390 -0
  2. pycograd/_constraints.py +148 -0
  3. pycograd/_dims.py +453 -0
  4. pycograd/_typing.py +97 -0
  5. pycograd/_version.py +669 -0
  6. pycograd/ad_graph.py +372 -0
  7. pycograd/backends/__init__.py +302 -0
  8. pycograd/backends/abstract_backend.py +73 -0
  9. pycograd/backends/cupy_backend.py +47 -0
  10. pycograd/backends/jax_backend.py +207 -0
  11. pycograd/backends/mps_backend.py +47 -0
  12. pycograd/backends/numpy_backend.py +66 -0
  13. pycograd/backends/tf_backend.py +407 -0
  14. pycograd/backends/torch_backend.py +482 -0
  15. pycograd/batching.py +638 -0
  16. pycograd/capture.py +527 -0
  17. pycograd/checkpoint.py +420 -0
  18. pycograd/compile.py +199 -0
  19. pycograd/cost.py +548 -0
  20. pycograd/data.py +115 -0
  21. pycograd/dtypes.py +152 -0
  22. pycograd/examples/__init__.py +12 -0
  23. pycograd/examples/__main__.py +242 -0
  24. pycograd/examples/models.py +953 -0
  25. pycograd/export.py +121 -0
  26. pycograd/extension.py +137 -0
  27. pycograd/forward.py +683 -0
  28. pycograd/functional.py +808 -0
  29. pycograd/ops.py +1575 -0
  30. pycograd/optimizers.py +284 -0
  31. pycograd/params.py +882 -0
  32. pycograd/passes.py +580 -0
  33. pycograd/random.py +92 -0
  34. pycograd/remat.py +779 -0
  35. pycograd/shapes.py +1174 -0
  36. pycograd/tensor.py +650 -0
  37. pycograd/trace.py +420 -0
  38. pycograd/tracer.py +531 -0
  39. pycograd/training.py +136 -0
  40. pycograd/transforms.py +1078 -0
  41. pycograd/transpose.py +167 -0
  42. pycograd/tree.py +109 -0
  43. pycograd/version.py +18 -0
  44. pycograd-0.0.1.dist-info/METADATA +324 -0
  45. pycograd-0.0.1.dist-info/RECORD +48 -0
  46. pycograd-0.0.1.dist-info/WHEEL +5 -0
  47. pycograd-0.0.1.dist-info/licenses/docs/LICENSE.txt +11 -0
  48. pycograd-0.0.1.dist-info/top_level.txt +1 -0
pycograd/__init__.py ADDED
@@ -0,0 +1,390 @@
1
+ # -*- coding: utf-8 -*-
2
+ """pycograd: a small, readable reverse-mode autograd built on numpy and pyccolo.
3
+
4
+ Write ordinary numeric Python -- including ``numpy`` calls like ``np.exp``,
5
+ ``np.dot``, ``np.sum`` and operators like ``@`` -- and get correct gradients.
6
+ ``Var`` is the reverse-mode tape node; ``value_and_grad`` / ``grad`` wrap a
7
+ function to return gradients with the same pytree structure as its arguments.
8
+ """
9
+ from importlib.metadata import PackageNotFoundError, version
10
+
11
+ from pycograd import random
12
+ from pycograd._typing import Operand, Tensor
13
+ from pycograd.ad_graph import grad_graph, jit
14
+ from pycograd.backends import activate, device, get_backend
15
+ from pycograd.capture import Graph, capture, eval_graph
16
+ from pycograd.checkpoint import checkpoint
17
+ from pycograd.compile import compile_to
18
+ from pycograd.cost import (
19
+ DEFAULT_COST_MODEL,
20
+ CostModel,
21
+ GraphCost,
22
+ NodeCost,
23
+ calibrate,
24
+ cost_report,
25
+ )
26
+ from pycograd.data import DataLoader, batches
27
+ from pycograd.dtypes import current_dtype, dtype, resolve_dtype
28
+ from pycograd.export import export_onnx, export_torchscript, to_torch_module
29
+ from pycograd.extension import load_ipython_extension, unload_ipython_extension
30
+ from pycograd.functional import (
31
+ avg_pool2d,
32
+ batch_norm,
33
+ batch_norm_init,
34
+ causal_conv1d,
35
+ conv1d,
36
+ conv2d,
37
+ conv_transpose1d,
38
+ conv_transpose2d,
39
+ cross_entropy,
40
+ dropout,
41
+ elu,
42
+ embedding,
43
+ gelu,
44
+ group_norm,
45
+ hardsigmoid,
46
+ hardswish,
47
+ instance_norm,
48
+ layer_norm,
49
+ leaky_relu,
50
+ linear,
51
+ log_softmax,
52
+ logsumexp,
53
+ max_pool2d,
54
+ mish,
55
+ multi_head_attention,
56
+ one_hot,
57
+ relu,
58
+ rms_norm,
59
+ scaled_dot_product_attention,
60
+ selu,
61
+ sigmoid,
62
+ silu,
63
+ softmax,
64
+ softplus,
65
+ softsign,
66
+ streaming_conv1d,
67
+ streaming_conv1d_init,
68
+ streaming_conv2d,
69
+ streaming_conv2d_init,
70
+ streaming_conv_transpose1d,
71
+ streaming_conv_transpose1d_init,
72
+ streaming_conv_transpose2d,
73
+ streaming_conv_transpose2d_init,
74
+ swish,
75
+ tanh,
76
+ upsample_nearest2d,
77
+ )
78
+ from pycograd.ops import (
79
+ AutodiffWarning,
80
+ d_abs,
81
+ d_arctan,
82
+ d_clip,
83
+ d_column_stack,
84
+ d_concatenate,
85
+ d_cos,
86
+ d_cosh,
87
+ d_cumsum,
88
+ d_dstack,
89
+ d_einsum,
90
+ d_exp,
91
+ d_expand_dims,
92
+ d_expm1,
93
+ d_gated_act,
94
+ d_hstack,
95
+ d_log,
96
+ d_log1p,
97
+ d_logsumexp,
98
+ d_max,
99
+ d_maximum,
100
+ d_mean,
101
+ d_min,
102
+ d_minimum,
103
+ d_reciprocal,
104
+ d_reshape,
105
+ d_sigmoid,
106
+ d_sin,
107
+ d_sinh,
108
+ d_softmax,
109
+ d_sqrt,
110
+ d_square,
111
+ d_stack,
112
+ d_std,
113
+ d_sum,
114
+ d_tanh,
115
+ d_transpose,
116
+ d_var,
117
+ d_vstack,
118
+ d_where,
119
+ )
120
+ from pycograd.optimizers import (
121
+ SGD,
122
+ Adam,
123
+ AdamW,
124
+ Optimizer,
125
+ clip_grad_norm,
126
+ constant_lr,
127
+ cosine_decay,
128
+ step_decay,
129
+ )
130
+ from pycograd.params import (
131
+ Param,
132
+ ParamDict,
133
+ Weight,
134
+ buffer,
135
+ frozen,
136
+ on_cpu,
137
+ on_device,
138
+ param_values,
139
+ params,
140
+ register_pipescript_params_macro,
141
+ tied,
142
+ )
143
+ from pycograd.passes import optimize
144
+ from pycograd.remat import (
145
+ Decision,
146
+ RematPlan,
147
+ apply_remat_plan,
148
+ eval_scheduled,
149
+ plan_remat,
150
+ )
151
+ from pycograd.shapes import (
152
+ Dim,
153
+ ShapedArray,
154
+ ShapeDtypeStruct,
155
+ ShapeError,
156
+ Summary,
157
+ bind,
158
+ eval_shape,
159
+ infer_shapes,
160
+ substitute,
161
+ summary,
162
+ )
163
+ from pycograd.tensor import Var, detach
164
+ from pycograd.tracer import AutodiffTracer, resolve_call
165
+ from pycograd.training import accuracy, fit, train
166
+ from pycograd.transforms import (
167
+ grad,
168
+ gradient_descent,
169
+ jacfwd,
170
+ jacrev,
171
+ jvp,
172
+ value_and_grad,
173
+ vmap,
174
+ )
175
+ from pycograd.tree import (
176
+ sgd_update,
177
+ tree_flatten,
178
+ tree_leaves,
179
+ tree_map,
180
+ tree_structure,
181
+ tree_unflatten,
182
+ )
183
+
184
# Friendlier public spellings of the fused primitives. The same primitives are also
# reached via ``np.einsum`` / ``np.cumsum`` interception; these names read more
# naturally than the ``d_`` names at a call site.
einsum, cumsum = d_einsum, d_cumsum
gated_act = d_gated_act  # tanh(f) * sigmoid(s), the WaveNet / GLU gate

# Report the version recorded in the installed distribution's metadata, falling back
# to a placeholder when the package is not installed (e.g. running from a source
# checkout).
try:
    __version__ = version("pycograd")
except PackageNotFoundError:
    __version__ = "0.0.0+unknown"
194
+
195
+ __all__ = [
196
+ "__version__",
197
+ # core
198
+ "Var",
199
+ "detach",
200
+ "Tensor",
201
+ "Operand",
202
+ # parameters
203
+ "Param",
204
+ "ParamDict",
205
+ "Weight",
206
+ "frozen",
207
+ "buffer",
208
+ "tied",
209
+ "on_cpu",
210
+ "on_device",
211
+ "params",
212
+ "param_values",
213
+ # transforms / training
214
+ "value_and_grad",
215
+ "grad",
216
+ "checkpoint",
217
+ "vmap",
218
+ "jvp",
219
+ "jacfwd",
220
+ "jacrev",
221
+ "gradient_descent",
222
+ "sgd_update",
223
+ "train",
224
+ "fit",
225
+ "accuracy",
226
+ # shape inference
227
+ "eval_shape",
228
+ "infer_shapes",
229
+ "substitute",
230
+ "bind",
231
+ "summary",
232
+ "Summary",
233
+ "ShapeDtypeStruct",
234
+ "ShapedArray",
235
+ "ShapeError",
236
+ "Dim",
237
+ # compile to other frameworks (torch / tf / jax)
238
+ "compile_to",
239
+ "get_backend",
240
+ # graph-capture IR + optimization passes
241
+ "capture",
242
+ "eval_graph",
243
+ "optimize",
244
+ "grad_graph",
245
+ "jit",
246
+ "Graph",
247
+ # static cost model over the capture IR (CPU / memory / disk)
248
+ "cost_report",
249
+ "CostModel",
250
+ "GraphCost",
251
+ "NodeCost",
252
+ "DEFAULT_COST_MODEL",
253
+ "calibrate",
254
+ # rematerialization / spill planning + memory-managed execution
255
+ "plan_remat",
256
+ "RematPlan",
257
+ "Decision",
258
+ "apply_remat_plan",
259
+ "eval_scheduled",
260
+ # device / array backend seam (numpy default, cupy for GPU)
261
+ "device",
262
+ "activate",
263
+ # working-dtype seam (float64 default; float32 / float16 / bfloat16)
264
+ "dtype",
265
+ "current_dtype",
266
+ "resolve_dtype",
267
+ # static export (standalone artifacts)
268
+ "to_torch_module",
269
+ "export_torchscript",
270
+ "export_onnx",
271
+ # optimizers
272
+ "Optimizer",
273
+ "SGD",
274
+ "Adam",
275
+ "AdamW",
276
+ "clip_grad_norm",
277
+ "constant_lr",
278
+ "step_decay",
279
+ "cosine_decay",
280
+ # neural-net ops (stable softmax family, cross-entropy, activations)
281
+ "softmax",
282
+ "log_softmax",
283
+ "logsumexp",
284
+ "cross_entropy",
285
+ "relu",
286
+ "sigmoid",
287
+ "silu",
288
+ "swish",
289
+ "gelu",
290
+ "tanh",
291
+ "leaky_relu",
292
+ "elu",
293
+ "softplus",
294
+ "mish",
295
+ "hardswish",
296
+ "hardsigmoid",
297
+ "softsign",
298
+ "selu",
299
+ "conv1d",
300
+ "conv2d",
301
+ "causal_conv1d",
302
+ "conv_transpose1d",
303
+ "conv_transpose2d",
304
+ "streaming_conv1d",
305
+ "streaming_conv1d_init",
306
+ "streaming_conv2d",
307
+ "streaming_conv2d_init",
308
+ "streaming_conv_transpose1d",
309
+ "streaming_conv_transpose1d_init",
310
+ "streaming_conv_transpose2d",
311
+ "streaming_conv_transpose2d_init",
312
+ "upsample_nearest2d",
313
+ "max_pool2d",
314
+ "avg_pool2d",
315
+ "one_hot",
316
+ # neural-net layers (normalization, attention, embedding, linear, dropout)
317
+ "layer_norm",
318
+ "rms_norm",
319
+ "batch_norm",
320
+ "batch_norm_init",
321
+ "group_norm",
322
+ "instance_norm",
323
+ "scaled_dot_product_attention",
324
+ "multi_head_attention",
325
+ "embedding",
326
+ "linear",
327
+ "dropout",
328
+ # splittable PRNG keys (pycograd.random: key / split / fold_in + samplers)
329
+ "random",
330
+ # data / batching
331
+ "batches",
332
+ "DataLoader",
333
+ # pytrees
334
+ "tree_flatten",
335
+ "tree_unflatten",
336
+ "tree_leaves",
337
+ "tree_structure",
338
+ "tree_map",
339
+ # tracer / interception
340
+ "AutodiffTracer",
341
+ "resolve_call",
342
+ "AutodiffWarning",
343
+ "register_pipescript_params_macro",
344
+ # ipython / jupyter extension
345
+ "load_ipython_extension",
346
+ "unload_ipython_extension",
347
+ # differentiable primitives
348
+ "d_exp",
349
+ "d_log",
350
+ "d_sin",
351
+ "d_cos",
352
+ "d_tanh",
353
+ "d_sqrt",
354
+ "d_sigmoid",
355
+ "d_abs",
356
+ "d_square",
357
+ "d_sinh",
358
+ "d_cosh",
359
+ "d_arctan",
360
+ "d_log1p",
361
+ "d_expm1",
362
+ "d_reciprocal",
363
+ "d_maximum",
364
+ "d_minimum",
365
+ "d_clip",
366
+ "d_where",
367
+ "d_sum",
368
+ "d_mean",
369
+ "d_var",
370
+ "d_std",
371
+ "d_max",
372
+ "d_min",
373
+ "d_softmax",
374
+ "d_logsumexp",
375
+ "d_concatenate",
376
+ "d_transpose",
377
+ "d_reshape",
378
+ "d_expand_dims",
379
+ "d_stack",
380
+ "d_vstack",
381
+ "d_hstack",
382
+ "d_column_stack",
383
+ "d_dstack",
384
+ "d_einsum",
385
+ "einsum",
386
+ "d_cumsum",
387
+ "cumsum",
388
+ "d_gated_act",
389
+ "gated_act",
390
+ ]
pycograd/_constraints.py ADDED
@@ -0,0 +1,148 @@
1
+ # -*- coding: utf-8 -*-
2
+ """Dimension-equality constraints for shape polymorphism.
3
+
4
+ When shape inference runs over *symbolic* input dims (e.g. a batch ``B`` declared via
5
+ ``ShapeDtypeStruct(("B", 768))``), each contraction registers an equality: a matmul
6
+ asserts its inner dims equal, concatenate asserts its non-axis dims equal, broadcasting
7
+ asserts compatible dims equal. :class:`ConstraintEnv` is the union-find that records
8
+ those equalities, refines a symbol pinned to a concrete (``K`` forced to ``4``), and
9
+ reports a contradiction (two concretes forced equal) as a shape error.
10
+
11
+ Only *solvable* symbols -- caller-declared input dims, whose key is a ``str`` -- get
12
+ bound to concretes or merged. *Data-dependent* symbols (a mask count, a broadcast;
13
+ their key is a tuple) are runtime facts, not statically known, so they are left opaque,
14
+ preserving the optimistic "carry it forward" behavior of plain symbolic inference.
15
+ """
16
+ from __future__ import annotations
17
+
18
+ import contextlib
19
+ import contextvars
20
+ from typing import Any, Hashable, Iterator, cast
21
+
22
+ from pycograd import _dims
23
+ from pycograd._dims import Dim
24
+
25
+
26
def _as_atom(x: int | Dim) -> tuple:
    """Tag a dimension with the kind of reasoning it supports.

    Returns one of:

    * ``("int", v)`` -- ``x`` passes ``_dims._is_int`` and ``v`` is ``int(x)``;
    * ``("sym", key, name)`` -- ``x`` is a ``Dim`` whose ``as_symbol()`` yields a
      ``(key, name)`` pair;
    * ``("expr",)`` -- anything else (a ``Dim`` with no symbol form, or a non-int).
    """
    if not isinstance(x, Dim):
        # Plain values: either a concrete integer or something we cannot classify.
        if _dims._is_int(x):
            return ("int", int(cast(Any, x)))
        return ("expr",)
    symbol = x.as_symbol()
    if symbol is None:
        return ("expr",)
    return ("sym", symbol[0], symbol[1])
34
+
35
+
36
class ConstraintEnv:
    """Disjoint-set forest over symbol keys; each class holds at most one concrete int.

    Equalities between dims are recorded with :meth:`assert_eq`. Only keys accepted by
    :meth:`_solvable` (``str`` keys) are ever inserted, pinned, or merged; every other
    symbol is accepted optimistically and left untouched.
    """

    def __init__(self) -> None:
        # key -> parent key; a root is its own parent
        self.parent: dict[Hashable, Hashable] = {}
        # root key -> concrete int the whole class is pinned to
        self.value: dict[Hashable, int] = {}
        # key -> rendered name
        self.name: dict[Hashable, str] = {}

    def _add(self, key: Hashable, name: str) -> None:
        """Insert ``key`` as its own singleton class; a known key is left untouched."""
        if key in self.parent:
            return
        self.parent[key] = key
        self.name[key] = name

    def _find(self, key: Hashable) -> Hashable:
        """Return the root of ``key``'s class, re-pointing the walked path at the root.

        ``key`` must already be present in ``self.parent``.
        """
        links = self.parent
        # Pass 1: climb to the root.
        top = key
        while True:
            above = links[top]
            if above == top:
                break
            top = above
        # Pass 2: path compression -- hang every node on the path directly off the root.
        node = key
        while links[node] != top:
            following = links[node]
            links[node] = top
            node = following
        return top

    @staticmethod
    def _solvable(key: Hashable) -> bool:
        """Whether ``key`` may be bound or merged, i.e. whether it is a ``str``.

        Caller-declared input dims have string keys; data-dependent symbols (nonzero,
        bcast, slice) have tuple keys and are never bound/merged.
        """
        return isinstance(key, str)

    def assert_eq(self, a: int | Dim, b: int | Dim) -> bool:
        """Record ``a == b``; return ``False`` only when that is a provable contradiction."""
        left, right = _as_atom(a), _as_atom(b)
        kinds = (left[0], right[0])
        if kinds == ("int", "int"):
            return left[1] == right[1]
        if kinds == ("int", "sym"):
            return self._bind(right, left[1])
        if kinds == ("sym", "int"):
            return self._bind(left, right[1])
        if kinds == ("sym", "sym"):
            return self._union(left, right)
        # An expression is involved -- nothing can be proven, so stay optimistic.
        return True

    def _bind(self, sym: tuple, val: int) -> bool:
        """Pin the class of ``sym`` (a ``("sym", key, name)`` atom) to ``val``.

        Returns ``False`` when the class is already pinned to a different value.
        """
        _kind, key, label = sym
        if not self._solvable(key):
            # Data-dependent: a runtime fact, never statically pinned.
            return True
        self._add(key, label)
        root = self._find(key)
        pinned = self.value.get(root)
        if pinned is None or pinned == val:
            self.value[root] = val
            return True
        return False

    def _union(self, s1: tuple, s2: tuple) -> bool:
        """Merge the classes of two ``("sym", key, name)`` atoms.

        Returns ``False`` when both classes are pinned to different values; in that
        case the classes are left unmerged.
        """
        _kind_a, key_a, label_a = s1
        _kind_b, key_b, label_b = s2
        if not self._solvable(key_a) or not self._solvable(key_b):
            # Leave data-dependent symbols opaque.
            return True
        self._add(key_a, label_a)
        self._add(key_b, label_b)
        root_a = self._find(key_a)
        root_b = self._find(key_b)
        if root_a == root_b:
            return True
        val_a = self.value.get(root_a)
        val_b = self.value.get(root_b)
        both_pinned = val_a is not None and val_b is not None
        if both_pinned and val_a != val_b:
            return False
        # The second class is always attached beneath the first one's root.
        self.parent[root_b] = root_a
        if val_a is None and val_b is not None:
            # Carry the absorbed class's pinned value over to the surviving root.
            self.value[root_a] = val_b
        return True

    def mapping(self) -> dict[Hashable, int | Dim]:
        """Build a substitution for every known symbol key.

        A key whose class is pinned maps to the concrete value; a key merged into an
        unpinned class maps to the class representative's symbol. Unpinned roots are
        omitted.
        """
        out: dict[Hashable, int | Dim] = {}
        for key in self.parent:
            root = self._find(key)
            pinned = self.value.get(root)
            if pinned is not None:
                out[key] = pinned
                continue
            if root != key:
                out[key] = _dims.symbol(root, name=self.name[root])
        return out
118
+
119
+
120
+ # ---------------------------------------------------------------------------
121
+ # Active environment (entered for the duration of an abstract inference run).
122
+ # ---------------------------------------------------------------------------
123
# Context-local slot for the environment of the inference run in progress; it holds
# ``None`` whenever no scope has been entered.
_env: "contextvars.ContextVar[ConstraintEnv | None]" = contextvars.ContextVar(
    "dim_env",
    default=None,
)


@contextlib.contextmanager
def constraint_scope() -> Iterator[ConstraintEnv]:
    """Install a fresh :class:`ConstraintEnv` as the active one for a ``with`` block.

    Yields the new environment. On exit -- normal or via an exception -- the value
    that was active before entry is restored.
    """
    fresh = ConstraintEnv()
    restore_token = _env.set(fresh)
    try:
        yield fresh
    finally:
        _env.reset(restore_token)
136
+
137
+
138
def active_env() -> "ConstraintEnv | None":
    """Return the currently active :class:`ConstraintEnv`, or ``None`` when none is set."""
    current = _env.get()
    return current
140
+
141
+
142
def register_eq(a: int | Dim, b: int | Dim) -> bool:
    """Register ``a == b`` with the active environment, if there is one.

    Returns ``False`` on a provable contradiction and ``True`` otherwise. When no
    environment is active, nothing is recorded and the answer comes from the
    concrete-only check ``_dims.provably_unequal``.
    """
    current = _env.get()
    if current is not None:
        return current.assert_eq(a, b)
    return not _dims.provably_unequal(a, b)