json2vec 0.4.7__tar.gz → 0.4.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {json2vec-0.4.7/src/json2vec.egg-info → json2vec-0.4.8}/PKG-INFO +8 -5
- {json2vec-0.4.7 → json2vec-0.4.8}/README.md +1 -1
- {json2vec-0.4.7 → json2vec-0.4.8}/pyproject.toml +12 -5
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/__init__.py +4 -10
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/contracts.py +5 -4
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/mutations.py +54 -9
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/root.py +46 -24
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/datasets/base.py +8 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/datasets/custom.py +11 -2
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/datasets/polars.py +11 -2
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/datasets/streaming.py +11 -2
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/iterables.py +57 -4
- json2vec-0.4.8/src/json2vec/helpers/hyperparameters.py +0 -0
- json2vec-0.4.8/src/json2vec/helpers/optimizers.py +78 -0
- json2vec-0.4.8/src/json2vec/helpers/trainer.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/inference/__init__.py +5 -11
- json2vec-0.4.8/src/json2vec/inference/deployment.py +691 -0
- json2vec-0.4.8/src/json2vec/structs/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/structs/enums.py +0 -1
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/structs/experiment.py +40 -10
- json2vec-0.4.8/src/json2vec/structs/structure.py +110 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/structs/tree.py +147 -2
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/base.py +19 -39
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/category.py +50 -57
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/number.py +42 -41
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/set.py +30 -64
- json2vec-0.4.8/src/json2vec/tensorfields/shared/__init__.py +80 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/shared/counter.py +3 -1
- json2vec-0.4.8/src/json2vec/tensorfields/shared/vocabulary.py +436 -0
- {json2vec-0.4.7 → json2vec-0.4.8/src/json2vec.egg-info}/PKG-INFO +8 -5
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec.egg-info/SOURCES.txt +5 -1
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec.egg-info/requires.txt +6 -3
- {json2vec-0.4.7 → json2vec-0.4.8}/tests/test_callbacks.py +6 -30
- json2vec-0.4.8/tests/test_optimizers.py +78 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/tests/test_public_api.py +1 -3
- json2vec-0.4.7/src/json2vec/architecture/plot.py +0 -562
- json2vec-0.4.7/src/json2vec/inference/deployment.py +0 -422
- json2vec-0.4.7/src/json2vec/structs/structure.py +0 -59
- json2vec-0.4.7/src/json2vec/tensorfields/shared/__init__.py +0 -12
- json2vec-0.4.7/src/json2vec/tensorfields/shared/vocabulary.py +0 -283
- {json2vec-0.4.7 → json2vec-0.4.8}/LICENSE +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/NOTICE +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/setup.cfg +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/attention.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/checkpoint.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/encoder.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/graph.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/node.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/pool.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/rotary.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/runtime.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/datasets/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/processing.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/distributed.py +0 -0
- {json2vec-0.4.7/src/json2vec/structs → json2vec-0.4.8/src/json2vec/helpers}/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/inference/callback.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/logging/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/logging/config.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/logging/epoch.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/logging/throughput.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/preprocessors/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/preprocessors/base.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/preprocessors/extensions/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/preprocessors/spec.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/structs/packages.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/structs/selectors.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/dateparts.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/entity.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/text.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/vector.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/spec.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec.egg-info/dependency_links.txt +0 -0
- {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: json2vec
|
|
3
|
-
Version: 0.4.
|
|
3
|
+
Version: 0.4.8
|
|
4
4
|
Summary: Schema-first PyTorch models for hierarchical / nested / sequence data structures
|
|
5
5
|
License-Expression: Apache-2.0
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -14,7 +14,6 @@ Requires-Dist: pydantic>=2.11.7
|
|
|
14
14
|
Requires-Dist: jmespath>=1.0.1
|
|
15
15
|
Requires-Dist: loguru>=0.7.3
|
|
16
16
|
Requires-Dist: anytree>=2.13.0
|
|
17
|
-
Requires-Dist: ordered-set>=4.1.0
|
|
18
17
|
Requires-Dist: pyarrow>=21.0.0
|
|
19
18
|
Requires-Dist: polars>=1.35.2
|
|
20
19
|
Requires-Dist: numpy>=2.2.6
|
|
@@ -22,16 +21,20 @@ Requires-Dist: lightning>=2.6.4
|
|
|
22
21
|
Requires-Dist: tensordict>=0.10.0
|
|
23
22
|
Requires-Dist: torch>=2.7.1
|
|
24
23
|
Provides-Extra: serving
|
|
25
|
-
Requires-Dist:
|
|
24
|
+
Requires-Dist: fastapi>=0.124.0; extra == "serving"
|
|
25
|
+
Requires-Dist: orjson>=3.10.0; extra == "serving"
|
|
26
26
|
Requires-Dist: pydantic-settings>=2.10.1; extra == "serving"
|
|
27
|
+
Requires-Dist: uvicorn>=0.38.0; extra == "serving"
|
|
27
28
|
Provides-Extra: text
|
|
28
29
|
Requires-Dist: transformers>=4.55.0; extra == "text"
|
|
29
30
|
Provides-Extra: docs
|
|
30
|
-
Requires-Dist:
|
|
31
|
+
Requires-Dist: fastapi>=0.124.0; extra == "docs"
|
|
31
32
|
Requires-Dist: mkdocs-material>=9.6; extra == "docs"
|
|
32
33
|
Requires-Dist: mkdocs-jupyter>=0.26.3; extra == "docs"
|
|
33
34
|
Requires-Dist: mkdocstrings[python]>=0.27; extra == "docs"
|
|
35
|
+
Requires-Dist: orjson>=3.10.0; extra == "docs"
|
|
34
36
|
Requires-Dist: pydantic-settings>=2.10.1; extra == "docs"
|
|
37
|
+
Requires-Dist: uvicorn>=0.38.0; extra == "docs"
|
|
35
38
|
Dynamic: license-file
|
|
36
39
|
|
|
37
40
|
<h1 align="center"><code>json2vec</code></h1>
|
|
@@ -314,7 +317,7 @@ uv sync --extra docs
|
|
|
314
317
|
```
|
|
315
318
|
|
|
316
319
|
The `text` extra installs Hugging Face `transformers`. The `serving` extra
|
|
317
|
-
installs
|
|
320
|
+
installs FastAPI-backed deployment dependencies. The `docs` extra installs the
|
|
318
321
|
MkDocs toolchain.
|
|
319
322
|
|
|
320
323
|
## Documentation Map
|
|
@@ -278,7 +278,7 @@ uv sync --extra docs
|
|
|
278
278
|
```
|
|
279
279
|
|
|
280
280
|
The `text` extra installs Hugging Face `transformers`. The `serving` extra
|
|
281
|
-
installs
|
|
281
|
+
installs FastAPI-backed deployment dependencies. The `docs` extra installs the
|
|
282
282
|
MkDocs toolchain.
|
|
283
283
|
|
|
284
284
|
## Documentation Map
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "json2vec"
|
|
3
|
-
version = "0.4.
|
|
3
|
+
version = "0.4.8"
|
|
4
4
|
description = "Schema-first PyTorch models for hierarchical / nested / sequence data structures"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
license = "Apache-2.0"
|
|
@@ -13,7 +13,6 @@ dependencies = [
|
|
|
13
13
|
"jmespath>=1.0.1",
|
|
14
14
|
"loguru>=0.7.3",
|
|
15
15
|
"anytree>=2.13.0",
|
|
16
|
-
"ordered-set>=4.1.0",
|
|
17
16
|
"pyarrow>=21.0.0",
|
|
18
17
|
"polars>=1.35.2",
|
|
19
18
|
"numpy>=2.2.6",
|
|
@@ -24,30 +23,37 @@ dependencies = [
|
|
|
24
23
|
|
|
25
24
|
[project.optional-dependencies]
|
|
26
25
|
serving = [
|
|
27
|
-
"
|
|
26
|
+
"fastapi>=0.124.0",
|
|
27
|
+
"orjson>=3.10.0",
|
|
28
28
|
"pydantic-settings>=2.10.1",
|
|
29
|
+
"uvicorn>=0.38.0",
|
|
29
30
|
]
|
|
30
31
|
text = [
|
|
31
32
|
"transformers>=4.55.0",
|
|
32
33
|
]
|
|
33
34
|
docs = [
|
|
34
|
-
"
|
|
35
|
+
"fastapi>=0.124.0",
|
|
35
36
|
"mkdocs-material>=9.6",
|
|
36
37
|
"mkdocs-jupyter>=0.26.3",
|
|
37
38
|
"mkdocstrings[python]>=0.27",
|
|
39
|
+
"orjson>=3.10.0",
|
|
38
40
|
"pydantic-settings>=2.10.1",
|
|
41
|
+
"uvicorn>=0.38.0",
|
|
39
42
|
]
|
|
40
43
|
|
|
41
44
|
[dependency-groups]
|
|
42
45
|
dev = [
|
|
43
46
|
"ruff>=0.12.12",
|
|
44
47
|
"pytest>=8.4.1",
|
|
48
|
+
"pytest-xdist>=3.8.0",
|
|
45
49
|
"ipython>=9.9.0",
|
|
46
50
|
"ipykernel>=6.29.5",
|
|
47
51
|
"nbclient>=0.10.2",
|
|
48
52
|
"nbformat>=5.10.4",
|
|
49
|
-
"
|
|
53
|
+
"fastapi>=0.124.0",
|
|
54
|
+
"orjson>=3.10.0",
|
|
50
55
|
"pydantic-settings>=2.10.1",
|
|
56
|
+
"uvicorn>=0.38.0",
|
|
51
57
|
"ty>=0.0.1a20",
|
|
52
58
|
"pre-commit>=4.3.0",
|
|
53
59
|
]
|
|
@@ -66,6 +72,7 @@ include = ["json2vec*"]
|
|
|
66
72
|
[tool.pytest.ini_options]
|
|
67
73
|
testpaths = ["tests"]
|
|
68
74
|
python_files = ["test_*.py"]
|
|
75
|
+
addopts = ["-n", "auto"]
|
|
69
76
|
|
|
70
77
|
[tool.ruff]
|
|
71
78
|
line-length = 120
|
|
@@ -52,23 +52,19 @@ from json2vec.tensorfields.shared.vocabulary import VocabularySyncCallback
|
|
|
52
52
|
|
|
53
53
|
if TYPE_CHECKING:
|
|
54
54
|
from json2vec.inference.deployment import (
|
|
55
|
-
API,
|
|
56
55
|
Accelerator,
|
|
57
|
-
BatchItem,
|
|
58
56
|
Deployment,
|
|
59
|
-
ErrorItem,
|
|
60
57
|
Input,
|
|
58
|
+
JSONBackend,
|
|
61
59
|
ModelSource,
|
|
62
60
|
UpdateOperation,
|
|
63
61
|
)
|
|
64
62
|
|
|
65
63
|
_SERVING_EXPORTS = {
|
|
66
|
-
"API",
|
|
67
64
|
"Accelerator",
|
|
68
|
-
"BatchItem",
|
|
69
65
|
"Deployment",
|
|
70
|
-
"ErrorItem",
|
|
71
66
|
"Input",
|
|
67
|
+
"JSONBackend",
|
|
72
68
|
"ModelSource",
|
|
73
69
|
"UpdateOperation",
|
|
74
70
|
}
|
|
@@ -81,7 +77,7 @@ def __getattr__(name: str) -> Any:
|
|
|
81
77
|
try:
|
|
82
78
|
from json2vec.inference import deployment
|
|
83
79
|
except ModuleNotFoundError as error:
|
|
84
|
-
if error.name in {"
|
|
80
|
+
if error.name in {"fastapi", "orjson", "pydantic_settings", "uvicorn"}:
|
|
85
81
|
raise ModuleNotFoundError(
|
|
86
82
|
f"json2vec.{name} requires the serving extra; install with `pip install json2vec[serving]`."
|
|
87
83
|
) from error
|
|
@@ -98,11 +94,9 @@ def __dir__() -> list[str]:
|
|
|
98
94
|
|
|
99
95
|
__all__ = [
|
|
100
96
|
"Address",
|
|
101
|
-
"API",
|
|
102
97
|
"Accelerator",
|
|
103
98
|
"Array",
|
|
104
99
|
"AttentionMode",
|
|
105
|
-
"BatchItem",
|
|
106
100
|
"Category",
|
|
107
101
|
"Component",
|
|
108
102
|
"CustomDataModule",
|
|
@@ -111,9 +105,9 @@ __all__ = [
|
|
|
111
105
|
"Deployment",
|
|
112
106
|
"EmbedderBase",
|
|
113
107
|
"Entity",
|
|
114
|
-
"ErrorItem",
|
|
115
108
|
"Hyperparameters",
|
|
116
109
|
"Input",
|
|
110
|
+
"JSONBackend",
|
|
117
111
|
"Leaf",
|
|
118
112
|
"Metric",
|
|
119
113
|
"Model",
|
|
@@ -3,9 +3,9 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from collections.abc import Iterator, Mapping
|
|
6
|
-
from dataclasses import dataclass, field
|
|
7
6
|
from typing import TYPE_CHECKING, Any
|
|
8
7
|
|
|
8
|
+
import pydantic
|
|
9
9
|
import torch
|
|
10
10
|
from tensordict import TensorDict
|
|
11
11
|
|
|
@@ -34,12 +34,13 @@ ContractSignature = tuple[Any, ...]
|
|
|
34
34
|
ContractScope = tuple[str, int, int, ContractSignature]
|
|
35
35
|
|
|
36
36
|
|
|
37
|
-
|
|
38
|
-
class ContractScheduler:
|
|
37
|
+
class ContractScheduler(pydantic.BaseModel):
|
|
39
38
|
"""Deterministic backoff scheduler for expensive forward contract checks."""
|
|
40
39
|
|
|
40
|
+
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
|
|
41
|
+
|
|
41
42
|
periodic_interval: int = 1024
|
|
42
|
-
_counts: dict[ContractScope, int] =
|
|
43
|
+
_counts: dict[ContractScope, int] = pydantic.PrivateAttr(default_factory=dict)
|
|
43
44
|
|
|
44
45
|
def reset(self) -> None:
|
|
45
46
|
self._counts.clear()
|
|
@@ -4,9 +4,9 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
from collections.abc import Callable, Iterator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
|
-
from dataclasses import dataclass
|
|
8
7
|
from typing import TYPE_CHECKING, Any
|
|
9
8
|
|
|
9
|
+
import pydantic
|
|
10
10
|
from loguru import logger
|
|
11
11
|
|
|
12
12
|
from json2vec.architecture.graph import ModelGraph
|
|
@@ -20,12 +20,18 @@ if TYPE_CHECKING:
|
|
|
20
20
|
_MISSING = object()
|
|
21
21
|
|
|
22
22
|
|
|
23
|
-
|
|
24
|
-
|
|
23
|
+
class AttributeChange(pydantic.BaseModel):
|
|
24
|
+
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
|
|
25
|
+
|
|
25
26
|
node: Node
|
|
26
27
|
name: str
|
|
27
28
|
original: Any
|
|
28
29
|
definition_attribute: bool
|
|
30
|
+
address: str
|
|
31
|
+
node_name: str
|
|
32
|
+
node_type: str
|
|
33
|
+
changed: Any = _MISSING
|
|
34
|
+
changed_address: Any = _MISSING
|
|
29
35
|
|
|
30
36
|
|
|
31
37
|
class SchemaEditor:
|
|
@@ -208,6 +214,9 @@ class SchemaEditor:
|
|
|
208
214
|
name=name,
|
|
209
215
|
original=getattr(node, name, _MISSING),
|
|
210
216
|
definition_attribute=_is_definition_attribute(node, name),
|
|
217
|
+
address=str(node.address),
|
|
218
|
+
node_name=node.name,
|
|
219
|
+
node_type=node.type,
|
|
211
220
|
)
|
|
212
221
|
)
|
|
213
222
|
|
|
@@ -280,29 +289,55 @@ class SchemaEditor:
|
|
|
280
289
|
|
|
281
290
|
def _log_attribute_changes(self, action: str, changes: list[AttributeChange], *, restored: bool = False) -> None:
|
|
282
291
|
for change in changes:
|
|
292
|
+
current_address = str(change.node.address)
|
|
283
293
|
value = change.original if restored else getattr(change.node, change.name, _MISSING)
|
|
294
|
+
if not restored:
|
|
295
|
+
change.changed = value
|
|
296
|
+
change.changed_address = current_address
|
|
297
|
+
previous_value = change.changed if restored else change.original
|
|
298
|
+
previous_address = change.changed_address if restored else change.address
|
|
299
|
+
if previous_address is _MISSING:
|
|
300
|
+
previous_address = change.address
|
|
301
|
+
address_context = (
|
|
302
|
+
current_address if previous_address == current_address else f"{previous_address} -> {current_address}"
|
|
303
|
+
)
|
|
304
|
+
value_text = _format_log_value(value)
|
|
305
|
+
previous_value_text = _format_log_value(previous_value)
|
|
284
306
|
logger.bind(
|
|
285
307
|
component="schema_mutation",
|
|
286
308
|
action=action,
|
|
287
|
-
address=
|
|
288
|
-
|
|
309
|
+
address=current_address,
|
|
310
|
+
previous_address=previous_address,
|
|
311
|
+
node_name=change.node.name,
|
|
312
|
+
previous_node_name=change.node_name,
|
|
313
|
+
node_type=change.node_type,
|
|
289
314
|
attribute=change.name,
|
|
290
315
|
definition_attribute=change.definition_attribute,
|
|
291
|
-
value=
|
|
292
|
-
previous_value=
|
|
293
|
-
|
|
316
|
+
value=value_text,
|
|
317
|
+
previous_value=previous_value_text,
|
|
318
|
+
change=f"{change.name}: {previous_value_text} -> {value_text}",
|
|
319
|
+
).info(
|
|
320
|
+
"{} {}: {} {} -> {}",
|
|
321
|
+
"restored" if restored else "mutated",
|
|
322
|
+
address_context,
|
|
323
|
+
change.name,
|
|
324
|
+
previous_value_text,
|
|
325
|
+
value_text,
|
|
326
|
+
)
|
|
294
327
|
|
|
295
328
|
def _log_node_mutation(self, *, action: str, message: str, node: Node, **kwargs: Any) -> None:
|
|
296
329
|
extra = {key: str(value.address) if isinstance(value, Node) else value for key, value in kwargs.items()}
|
|
330
|
+
context = _format_node_log_context(node, extra)
|
|
297
331
|
logger.bind(
|
|
298
332
|
component="schema_mutation",
|
|
299
333
|
action=action,
|
|
300
334
|
address=str(node.address),
|
|
301
335
|
node_type=node.type,
|
|
336
|
+
node_name=node.name,
|
|
302
337
|
attribute=None,
|
|
303
338
|
definition_attribute=None,
|
|
304
339
|
**extra,
|
|
305
|
-
).info(message)
|
|
340
|
+
).info("{} {}", message, context)
|
|
306
341
|
|
|
307
342
|
|
|
308
343
|
def _has_node_attribute(node: Node, name: str) -> bool:
|
|
@@ -321,3 +356,13 @@ def _format_log_value(value: Any) -> str:
|
|
|
321
356
|
|
|
322
357
|
text = repr(value)
|
|
323
358
|
return text if len(text) <= 160 else f"{text[:157]}..."
|
|
359
|
+
|
|
360
|
+
|
|
361
|
+
def _format_node_log_context(node: Node, extra: dict[str, Any]) -> str:
|
|
362
|
+
parts = [str(node.address)]
|
|
363
|
+
if parent := extra.get("parent"):
|
|
364
|
+
parts.append(f"under {parent}")
|
|
365
|
+
if "descendants" in extra:
|
|
366
|
+
parts.append(f"descendants={extra['descendants']}")
|
|
367
|
+
|
|
368
|
+
return " ".join(parts)
|
|
@@ -13,15 +13,16 @@ from beartype import beartype
|
|
|
13
13
|
from lightning.pytorch import Callback
|
|
14
14
|
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
15
15
|
from loguru import logger
|
|
16
|
+
from rich.text import Text
|
|
16
17
|
from tensordict import TensorDict
|
|
17
18
|
|
|
18
19
|
from json2vec.architecture.checkpoint import CheckpointState
|
|
19
20
|
from json2vec.architecture.contracts import ContractScheduler
|
|
20
21
|
from json2vec.architecture.graph import ModelGraph
|
|
21
22
|
from json2vec.architecture.mutations import SchemaEditor
|
|
22
|
-
from json2vec.architecture.plot import PlotMode
|
|
23
23
|
from json2vec.architecture.runtime import ModelRuntime, Postprocessor, Preprocessor, step
|
|
24
24
|
from json2vec.data.datasets.base import EncodedBatch, EncodedInput
|
|
25
|
+
from json2vec.logging.throughput import ThroughputLogger
|
|
25
26
|
from json2vec.structs.enums import AttentionMode, Strata
|
|
26
27
|
from json2vec.structs.experiment import (
|
|
27
28
|
Hyperparameters,
|
|
@@ -30,7 +31,7 @@ from json2vec.structs.experiment import (
|
|
|
30
31
|
SchemaField,
|
|
31
32
|
)
|
|
32
33
|
from json2vec.structs.packages import Prediction
|
|
33
|
-
from json2vec.structs.tree import Address, Node, Rate
|
|
34
|
+
from json2vec.structs.tree import Address, Node, Rate, Renderable
|
|
34
35
|
from json2vec.tensorfields.base import TENSORFIELDS, Plugin, TensorFieldBase
|
|
35
36
|
|
|
36
37
|
OptimizerConfig = torch.optim.Optimizer | Callable[["Model"], torch.optim.Optimizer]
|
|
@@ -137,12 +138,12 @@ class RollbackCheckpoint(ModelCheckpoint):
|
|
|
137
138
|
).info("rolled back Model to best checkpoint")
|
|
138
139
|
|
|
139
140
|
|
|
140
|
-
class Model(lit.LightningModule):
|
|
141
|
+
class Model(lit.LightningModule, Renderable):
|
|
141
142
|
"""Neural model generated from a `json2vec` schema tree.
|
|
142
143
|
|
|
143
144
|
`Model` owns the schema hyperparameters, tensorfield embedders, array
|
|
144
145
|
encoders, decoders, and convenience methods for prediction, checkpointing,
|
|
145
|
-
|
|
146
|
+
schema display and mutation.
|
|
146
147
|
|
|
147
148
|
Example:
|
|
148
149
|
```python
|
|
@@ -265,6 +266,45 @@ class Model(lit.LightningModule):
|
|
|
265
266
|
self._contract_generation += 1
|
|
266
267
|
self._contract_scheduler.reset()
|
|
267
268
|
|
|
269
|
+
def __rich_console__(self, console, options):
|
|
270
|
+
parameters = sum(parameter.numel() for parameter in self.parameters())
|
|
271
|
+
heading = Text()
|
|
272
|
+
heading.append(type(self).__name__, style=self.RICH_NAME_STYLE)
|
|
273
|
+
heading.append(" ")
|
|
274
|
+
heading.append("[model]", style=self.RICH_TYPE_STYLE)
|
|
275
|
+
for name, value in (
|
|
276
|
+
("batch_size", self.batch_size),
|
|
277
|
+
("d_model", self.hyperparameters.d_model),
|
|
278
|
+
("parameters", f"{parameters:,}"),
|
|
279
|
+
("arrays", len(self.hyperparameters.arrays)),
|
|
280
|
+
("fields", len(self.hyperparameters.active_requests)),
|
|
281
|
+
("targets", len(self.hyperparameters.target)),
|
|
282
|
+
("embeds", len(self.hyperparameters.embed)),
|
|
283
|
+
):
|
|
284
|
+
heading.append(" ")
|
|
285
|
+
heading.append(f"{name}=", style="dim")
|
|
286
|
+
heading.append(str(value), style="cyan")
|
|
287
|
+
yield heading
|
|
288
|
+
|
|
289
|
+
lines = list(self.hyperparameters.fields.__rich_console__(console, options))
|
|
290
|
+
if not lines:
|
|
291
|
+
return
|
|
292
|
+
first = Text()
|
|
293
|
+
first.append("`-- ", style=self.RICH_TREE_STYLE)
|
|
294
|
+
if isinstance(lines[0], Text):
|
|
295
|
+
first.append_text(lines[0])
|
|
296
|
+
else:
|
|
297
|
+
first.append(str(lines[0]))
|
|
298
|
+
yield first
|
|
299
|
+
for line in lines[1:]:
|
|
300
|
+
nested = Text()
|
|
301
|
+
nested.append(" ", style=self.RICH_TREE_STYLE)
|
|
302
|
+
if isinstance(line, Text):
|
|
303
|
+
nested.append_text(line)
|
|
304
|
+
else:
|
|
305
|
+
nested.append(str(line))
|
|
306
|
+
yield nested
|
|
307
|
+
|
|
268
308
|
def select(
|
|
269
309
|
self,
|
|
270
310
|
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
@@ -382,6 +422,8 @@ class Model(lit.LightningModule):
|
|
|
382
422
|
callbacks.append(RuntimePlacementCallback())
|
|
383
423
|
if MutationLockCallback not in attached_callback_types:
|
|
384
424
|
callbacks.append(MutationLockCallback())
|
|
425
|
+
if ThroughputLogger not in attached_callback_types:
|
|
426
|
+
callbacks.append(ThroughputLogger())
|
|
385
427
|
|
|
386
428
|
for request in self.hyperparameters.active_requests.values():
|
|
387
429
|
plugin: Plugin = TENSORFIELDS[request.type]
|
|
@@ -438,26 +480,6 @@ class Model(lit.LightningModule):
|
|
|
438
480
|
if hasattr(node, "embedder") and hasattr(node.embedder, "interprocess_encoding_context")
|
|
439
481
|
}
|
|
440
482
|
|
|
441
|
-
def plot(
|
|
442
|
-
self,
|
|
443
|
-
address: Address | str | None = None,
|
|
444
|
-
detail: bool = False,
|
|
445
|
-
out: str | Path | None = None,
|
|
446
|
-
mode: PlotMode = "schema",
|
|
447
|
-
) -> None:
|
|
448
|
-
"""Print a Rich model visualization.
|
|
449
|
-
|
|
450
|
-
Args:
|
|
451
|
-
address: Optional subtree address to render.
|
|
452
|
-
detail: Include tensorfield-specific detail sections.
|
|
453
|
-
out: Optional output path for the rendered console text.
|
|
454
|
-
mode: Plot mode. Supported values are `schema`, `state`, `flow`,
|
|
455
|
-
and `debug`.
|
|
456
|
-
"""
|
|
457
|
-
from json2vec.architecture.plot import plot
|
|
458
|
-
|
|
459
|
-
return plot(module=self, address=address, detail=detail, out=out, mode=mode)
|
|
460
|
-
|
|
461
483
|
@beartype
|
|
462
484
|
def save(self, pathname: str | Path) -> str | Path:
|
|
463
485
|
"""Save model weights and schema hyperparameters to a checkpoint."""
|
|
@@ -90,5 +90,13 @@ def _is_assigned_to_worker(shard_key: str, worker_id: int, num_workers: int) ->
|
|
|
90
90
|
return owner == worker_id
|
|
91
91
|
|
|
92
92
|
|
|
93
|
+
def share_interprocess_encoding_context(context: InterprocessEncodingContext) -> None:
|
|
94
|
+
"""Opt encoding context resources into multiprocessing-safe storage."""
|
|
95
|
+
for field_context in context.values():
|
|
96
|
+
share = getattr(field_context, "share", None)
|
|
97
|
+
if callable(share):
|
|
98
|
+
share()
|
|
99
|
+
|
|
100
|
+
|
|
93
101
|
def identity(data: Any) -> Any:
|
|
94
102
|
return data
|
|
@@ -22,6 +22,7 @@ from json2vec.data.datasets.base import (
|
|
|
22
22
|
SampleRate,
|
|
23
23
|
StrataMap,
|
|
24
24
|
identity,
|
|
25
|
+
share_interprocess_encoding_context,
|
|
25
26
|
)
|
|
26
27
|
from json2vec.data.iterables import (
|
|
27
28
|
JMESPathResolutionMonitor,
|
|
@@ -299,15 +300,23 @@ class CustomDataModule(lit.LightningDataModule):
|
|
|
299
300
|
return None
|
|
300
301
|
raise ValueError(f"no dataset configured for strata: {strata}")
|
|
301
302
|
|
|
303
|
+
workers = self.num_workers[strata]
|
|
304
|
+
if workers is None:
|
|
305
|
+
workers = os.cpu_count() or 0
|
|
306
|
+
|
|
307
|
+
interprocess_encoding_context = self.interprocess_encoding_context
|
|
308
|
+
if strata == Strata.train and workers > 0:
|
|
309
|
+
share_interprocess_encoding_context(interprocess_encoding_context)
|
|
310
|
+
|
|
302
311
|
return custom_dataloader(
|
|
303
312
|
hyperparameters=self.hyperparameters,
|
|
304
313
|
dataset=self.datasets[strata],
|
|
305
314
|
preprocessor=self.preprocessor,
|
|
306
315
|
preprocessor_kwargs=self.preprocessor_kwargs,
|
|
307
|
-
interprocess_encoding_context=
|
|
316
|
+
interprocess_encoding_context=interprocess_encoding_context,
|
|
308
317
|
batch_size=self.batch_size,
|
|
309
318
|
strata=strata,
|
|
310
|
-
num_workers=
|
|
319
|
+
num_workers=workers,
|
|
311
320
|
persistent_workers=self.persistent_workers[strata],
|
|
312
321
|
pin_memory=self.pin_memory[strata],
|
|
313
322
|
observation_buffer_size=self.observation_buffer_size[strata],
|
|
@@ -26,6 +26,7 @@ from json2vec.data.datasets.base import (
|
|
|
26
26
|
_is_assigned_to_worker,
|
|
27
27
|
_worker_identity,
|
|
28
28
|
identity,
|
|
29
|
+
share_interprocess_encoding_context,
|
|
29
30
|
)
|
|
30
31
|
from json2vec.data.iterables import (
|
|
31
32
|
JMESPathResolutionMonitor,
|
|
@@ -355,15 +356,23 @@ class PolarsDataModule(lit.LightningDataModule):
|
|
|
355
356
|
return None
|
|
356
357
|
raise ValueError(f"no dataframe configured for strata: {strata}")
|
|
357
358
|
|
|
359
|
+
workers = self.num_workers[strata]
|
|
360
|
+
if workers is None:
|
|
361
|
+
workers = os.cpu_count() or 0
|
|
362
|
+
|
|
363
|
+
interprocess_encoding_context = self.interprocess_encoding_context
|
|
364
|
+
if strata == Strata.train and workers > 0:
|
|
365
|
+
share_interprocess_encoding_context(interprocess_encoding_context)
|
|
366
|
+
|
|
358
367
|
return polars_dataloader(
|
|
359
368
|
hyperparameters=self.hyperparameters,
|
|
360
369
|
dataframe=self.dataframes[strata],
|
|
361
370
|
preprocessor=self.preprocessor,
|
|
362
371
|
preprocessor_kwargs=self.preprocessor_kwargs,
|
|
363
|
-
interprocess_encoding_context=
|
|
372
|
+
interprocess_encoding_context=interprocess_encoding_context,
|
|
364
373
|
batch_size=self.batch_size,
|
|
365
374
|
strata=strata,
|
|
366
|
-
num_workers=
|
|
375
|
+
num_workers=workers,
|
|
367
376
|
persistent_workers=self.persistent_workers[strata],
|
|
368
377
|
pin_memory=self.pin_memory[strata],
|
|
369
378
|
sharding=self.sharding[strata],
|
|
@@ -31,6 +31,7 @@ from json2vec.data.datasets.base import (
|
|
|
31
31
|
_is_assigned_to_worker,
|
|
32
32
|
_worker_identity,
|
|
33
33
|
identity,
|
|
34
|
+
share_interprocess_encoding_context,
|
|
34
35
|
)
|
|
35
36
|
from json2vec.data.iterables import (
|
|
36
37
|
JMESPathResolutionMonitor,
|
|
@@ -500,6 +501,14 @@ class StreamingDataModule(lit.LightningDataModule):
|
|
|
500
501
|
global_rank = getattr(trainer, "global_rank", None)
|
|
501
502
|
world_size = getattr(trainer, "world_size", None)
|
|
502
503
|
|
|
504
|
+
workers = self.num_workers[strata]
|
|
505
|
+
if workers is None:
|
|
506
|
+
workers = os.cpu_count() or 0
|
|
507
|
+
|
|
508
|
+
interprocess_encoding_context = self.interprocess_encoding_context
|
|
509
|
+
if strata == Strata.train and workers > 0:
|
|
510
|
+
share_interprocess_encoding_context(interprocess_encoding_context)
|
|
511
|
+
|
|
503
512
|
return dataloader(
|
|
504
513
|
hyperparameters=self.hyperparameters,
|
|
505
514
|
root=self.root,
|
|
@@ -507,10 +516,10 @@ class StreamingDataModule(lit.LightningDataModule):
|
|
|
507
516
|
pattern=pattern,
|
|
508
517
|
preprocessor=self.preprocessor,
|
|
509
518
|
preprocessor_kwargs=self.preprocessor_kwargs,
|
|
510
|
-
interprocess_encoding_context=
|
|
519
|
+
interprocess_encoding_context=interprocess_encoding_context,
|
|
511
520
|
batch_size=self.batch_size,
|
|
512
521
|
strata=strata,
|
|
513
|
-
num_workers=
|
|
522
|
+
num_workers=workers,
|
|
514
523
|
persistent_workers=self.persistent_workers[strata],
|
|
515
524
|
pin_memory=self.pin_memory[strata],
|
|
516
525
|
sharding=self.sharding[strata],
|
|
@@ -4,8 +4,9 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
import inspect
|
|
6
6
|
import random
|
|
7
|
+
import re
|
|
7
8
|
from collections import Counter
|
|
8
|
-
from collections.abc import Iterable, Iterator
|
|
9
|
+
from collections.abc import Callable, Iterable, Iterator
|
|
9
10
|
from functools import cache
|
|
10
11
|
from typing import Annotated, Any, TypeVar, cast
|
|
11
12
|
|
|
@@ -138,6 +139,56 @@ def query(expression: str) -> jmespath.parser.ParsedResult:
|
|
|
138
139
|
return jmespath.compile(expression=f"[*]{expression}")
|
|
139
140
|
|
|
140
141
|
|
|
142
|
+
@cache
|
|
143
|
+
def compile_query_extractor(expression: str) -> Callable[[Any], Any] | None:
|
|
144
|
+
"""Compile a direct extractor for simple JMESPath projection queries."""
|
|
145
|
+
if not expression.startswith("[*]"):
|
|
146
|
+
return None
|
|
147
|
+
|
|
148
|
+
operations: list[tuple[str, str | None]] = [("projection", None), ("projection", None)]
|
|
149
|
+
index = 3
|
|
150
|
+
while index < len(expression):
|
|
151
|
+
if expression.startswith("[*]", index):
|
|
152
|
+
operations.append(("projection", None))
|
|
153
|
+
index += 3
|
|
154
|
+
continue
|
|
155
|
+
|
|
156
|
+
if not expression.startswith(".", index):
|
|
157
|
+
return None
|
|
158
|
+
|
|
159
|
+
match = re.match(r"\.([A-Za-z_][A-Za-z0-9_]*)", expression[index:])
|
|
160
|
+
if match is None:
|
|
161
|
+
return None
|
|
162
|
+
|
|
163
|
+
operations.append(("field", match.group(1)))
|
|
164
|
+
index += len(match.group(0))
|
|
165
|
+
|
|
166
|
+
def search(value: Any) -> Any:
|
|
167
|
+
def apply(item: Any, operation_index: int) -> Any:
|
|
168
|
+
if operation_index == len(operations):
|
|
169
|
+
return item
|
|
170
|
+
|
|
171
|
+
operation, key = operations[operation_index]
|
|
172
|
+
if operation == "field":
|
|
173
|
+
if not isinstance(item, dict) or key not in item:
|
|
174
|
+
return None
|
|
175
|
+
return apply(item[key], operation_index + 1)
|
|
176
|
+
|
|
177
|
+
if not isinstance(item, list):
|
|
178
|
+
return None
|
|
179
|
+
|
|
180
|
+
out = []
|
|
181
|
+
for child in item:
|
|
182
|
+
result = apply(child, operation_index + 1)
|
|
183
|
+
if result is not None:
|
|
184
|
+
out.append(result)
|
|
185
|
+
return out
|
|
186
|
+
|
|
187
|
+
return apply(value, 0)
|
|
188
|
+
|
|
189
|
+
return search
|
|
190
|
+
|
|
191
|
+
|
|
141
192
|
class JMESPathResolutionMonitor(pydantic.BaseModel):
|
|
142
193
|
every: Annotated[int, pydantic.Field(gt=0)] = 1000
|
|
143
194
|
|
|
@@ -192,9 +243,11 @@ def encode(
|
|
|
192
243
|
if expression is None:
|
|
193
244
|
raise ValueError(f"request '{address}' must define query")
|
|
194
245
|
|
|
195
|
-
# `request.query` is relative to a processed observation.
|
|
196
|
-
#
|
|
197
|
-
|
|
246
|
+
# `request.query` is relative to a processed observation. Direct
|
|
247
|
+
# extractors add the outer batch selector; `query(...)` does the same
|
|
248
|
+
# before JMESPath searches `batch`.
|
|
249
|
+
extractor = compile_query_extractor(expression)
|
|
250
|
+
result = extractor(batch) if extractor is not None else query(expression).search(batch)
|
|
198
251
|
if jmespath_resolution_monitor is not None:
|
|
199
252
|
jmespath_resolution_monitor.observe(address=address, expression=expression, result=result)
|
|
200
253
|
|
|
File without changes
|