json2vec 0.4.7__tar.gz → 0.4.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. {json2vec-0.4.7/src/json2vec.egg-info → json2vec-0.4.8}/PKG-INFO +8 -5
  2. {json2vec-0.4.7 → json2vec-0.4.8}/README.md +1 -1
  3. {json2vec-0.4.7 → json2vec-0.4.8}/pyproject.toml +12 -5
  4. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/__init__.py +4 -10
  5. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/contracts.py +5 -4
  6. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/mutations.py +54 -9
  7. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/root.py +46 -24
  8. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/datasets/base.py +8 -0
  9. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/datasets/custom.py +11 -2
  10. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/datasets/polars.py +11 -2
  11. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/datasets/streaming.py +11 -2
  12. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/iterables.py +57 -4
  13. json2vec-0.4.8/src/json2vec/helpers/hyperparameters.py +0 -0
  14. json2vec-0.4.8/src/json2vec/helpers/optimizers.py +78 -0
  15. json2vec-0.4.8/src/json2vec/helpers/trainer.py +0 -0
  16. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/inference/__init__.py +5 -11
  17. json2vec-0.4.8/src/json2vec/inference/deployment.py +691 -0
  18. json2vec-0.4.8/src/json2vec/structs/__init__.py +0 -0
  19. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/structs/enums.py +0 -1
  20. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/structs/experiment.py +40 -10
  21. json2vec-0.4.8/src/json2vec/structs/structure.py +110 -0
  22. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/structs/tree.py +147 -2
  23. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/base.py +19 -39
  24. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/category.py +50 -57
  25. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/number.py +42 -41
  26. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/set.py +30 -64
  27. json2vec-0.4.8/src/json2vec/tensorfields/shared/__init__.py +80 -0
  28. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/shared/counter.py +3 -1
  29. json2vec-0.4.8/src/json2vec/tensorfields/shared/vocabulary.py +436 -0
  30. {json2vec-0.4.7 → json2vec-0.4.8/src/json2vec.egg-info}/PKG-INFO +8 -5
  31. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec.egg-info/SOURCES.txt +5 -1
  32. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec.egg-info/requires.txt +6 -3
  33. {json2vec-0.4.7 → json2vec-0.4.8}/tests/test_callbacks.py +6 -30
  34. json2vec-0.4.8/tests/test_optimizers.py +78 -0
  35. {json2vec-0.4.7 → json2vec-0.4.8}/tests/test_public_api.py +1 -3
  36. json2vec-0.4.7/src/json2vec/architecture/plot.py +0 -562
  37. json2vec-0.4.7/src/json2vec/inference/deployment.py +0 -422
  38. json2vec-0.4.7/src/json2vec/structs/structure.py +0 -59
  39. json2vec-0.4.7/src/json2vec/tensorfields/shared/__init__.py +0 -12
  40. json2vec-0.4.7/src/json2vec/tensorfields/shared/vocabulary.py +0 -283
  41. {json2vec-0.4.7 → json2vec-0.4.8}/LICENSE +0 -0
  42. {json2vec-0.4.7 → json2vec-0.4.8}/NOTICE +0 -0
  43. {json2vec-0.4.7 → json2vec-0.4.8}/setup.cfg +0 -0
  44. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/__init__.py +0 -0
  45. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/attention.py +0 -0
  46. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/checkpoint.py +0 -0
  47. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/encoder.py +0 -0
  48. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/graph.py +0 -0
  49. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/node.py +0 -0
  50. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/pool.py +0 -0
  51. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/rotary.py +0 -0
  52. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/architecture/runtime.py +0 -0
  53. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/__init__.py +0 -0
  54. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/datasets/__init__.py +0 -0
  55. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/data/processing.py +0 -0
  56. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/distributed.py +0 -0
  57. {json2vec-0.4.7/src/json2vec/structs → json2vec-0.4.8/src/json2vec/helpers}/__init__.py +0 -0
  58. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/inference/callback.py +0 -0
  59. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/logging/__init__.py +0 -0
  60. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/logging/config.py +0 -0
  61. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/logging/epoch.py +0 -0
  62. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/logging/throughput.py +0 -0
  63. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/preprocessors/__init__.py +0 -0
  64. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/preprocessors/base.py +0 -0
  65. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/preprocessors/extensions/__init__.py +0 -0
  66. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/preprocessors/spec.py +0 -0
  67. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/structs/packages.py +0 -0
  68. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/structs/selectors.py +0 -0
  69. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/__init__.py +0 -0
  70. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/__init__.py +0 -0
  71. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/dateparts.py +0 -0
  72. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/entity.py +0 -0
  73. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/text.py +0 -0
  74. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/extensions/vector.py +0 -0
  75. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec/tensorfields/spec.py +0 -0
  76. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec.egg-info/dependency_links.txt +0 -0
  77. {json2vec-0.4.7 → json2vec-0.4.8}/src/json2vec.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: json2vec
3
- Version: 0.4.7
3
+ Version: 0.4.8
4
4
  Summary: Schema-first PyTorch models for hierarchical / nested / sequence data structures
5
5
  License-Expression: Apache-2.0
6
6
  Requires-Python: >=3.12
@@ -14,7 +14,6 @@ Requires-Dist: pydantic>=2.11.7
14
14
  Requires-Dist: jmespath>=1.0.1
15
15
  Requires-Dist: loguru>=0.7.3
16
16
  Requires-Dist: anytree>=2.13.0
17
- Requires-Dist: ordered-set>=4.1.0
18
17
  Requires-Dist: pyarrow>=21.0.0
19
18
  Requires-Dist: polars>=1.35.2
20
19
  Requires-Dist: numpy>=2.2.6
@@ -22,16 +21,20 @@ Requires-Dist: lightning>=2.6.4
22
21
  Requires-Dist: tensordict>=0.10.0
23
22
  Requires-Dist: torch>=2.7.1
24
23
  Provides-Extra: serving
25
- Requires-Dist: litserve>=0.2.13; extra == "serving"
24
+ Requires-Dist: fastapi>=0.124.0; extra == "serving"
25
+ Requires-Dist: orjson>=3.10.0; extra == "serving"
26
26
  Requires-Dist: pydantic-settings>=2.10.1; extra == "serving"
27
+ Requires-Dist: uvicorn>=0.38.0; extra == "serving"
27
28
  Provides-Extra: text
28
29
  Requires-Dist: transformers>=4.55.0; extra == "text"
29
30
  Provides-Extra: docs
30
- Requires-Dist: litserve>=0.2.13; extra == "docs"
31
+ Requires-Dist: fastapi>=0.124.0; extra == "docs"
31
32
  Requires-Dist: mkdocs-material>=9.6; extra == "docs"
32
33
  Requires-Dist: mkdocs-jupyter>=0.26.3; extra == "docs"
33
34
  Requires-Dist: mkdocstrings[python]>=0.27; extra == "docs"
35
+ Requires-Dist: orjson>=3.10.0; extra == "docs"
34
36
  Requires-Dist: pydantic-settings>=2.10.1; extra == "docs"
37
+ Requires-Dist: uvicorn>=0.38.0; extra == "docs"
35
38
  Dynamic: license-file
36
39
 
37
40
  <h1 align="center"><code>json2vec</code></h1>
@@ -314,7 +317,7 @@ uv sync --extra docs
314
317
  ```
315
318
 
316
319
  The `text` extra installs Hugging Face `transformers`. The `serving` extra
317
- installs LitServe-backed deployment dependencies. The `docs` extra installs the
320
+ installs FastAPI-backed deployment dependencies. The `docs` extra installs the
318
321
  MkDocs toolchain.
319
322
 
320
323
  ## Documentation Map
@@ -278,7 +278,7 @@ uv sync --extra docs
278
278
  ```
279
279
 
280
280
  The `text` extra installs Hugging Face `transformers`. The `serving` extra
281
- installs LitServe-backed deployment dependencies. The `docs` extra installs the
281
+ installs FastAPI-backed deployment dependencies. The `docs` extra installs the
282
282
  MkDocs toolchain.
283
283
 
284
284
  ## Documentation Map
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "json2vec"
3
- version = "0.4.7"
3
+ version = "0.4.8"
4
4
  description = "Schema-first PyTorch models for hierarchical / nested / sequence data structures"
5
5
  readme = "README.md"
6
6
  license = "Apache-2.0"
@@ -13,7 +13,6 @@ dependencies = [
13
13
  "jmespath>=1.0.1",
14
14
  "loguru>=0.7.3",
15
15
  "anytree>=2.13.0",
16
- "ordered-set>=4.1.0",
17
16
  "pyarrow>=21.0.0",
18
17
  "polars>=1.35.2",
19
18
  "numpy>=2.2.6",
@@ -24,30 +23,37 @@ dependencies = [
24
23
 
25
24
  [project.optional-dependencies]
26
25
  serving = [
27
- "litserve>=0.2.13",
26
+ "fastapi>=0.124.0",
27
+ "orjson>=3.10.0",
28
28
  "pydantic-settings>=2.10.1",
29
+ "uvicorn>=0.38.0",
29
30
  ]
30
31
  text = [
31
32
  "transformers>=4.55.0",
32
33
  ]
33
34
  docs = [
34
- "litserve>=0.2.13",
35
+ "fastapi>=0.124.0",
35
36
  "mkdocs-material>=9.6",
36
37
  "mkdocs-jupyter>=0.26.3",
37
38
  "mkdocstrings[python]>=0.27",
39
+ "orjson>=3.10.0",
38
40
  "pydantic-settings>=2.10.1",
41
+ "uvicorn>=0.38.0",
39
42
  ]
40
43
 
41
44
  [dependency-groups]
42
45
  dev = [
43
46
  "ruff>=0.12.12",
44
47
  "pytest>=8.4.1",
48
+ "pytest-xdist>=3.8.0",
45
49
  "ipython>=9.9.0",
46
50
  "ipykernel>=6.29.5",
47
51
  "nbclient>=0.10.2",
48
52
  "nbformat>=5.10.4",
49
- "litserve>=0.2.13",
53
+ "fastapi>=0.124.0",
54
+ "orjson>=3.10.0",
50
55
  "pydantic-settings>=2.10.1",
56
+ "uvicorn>=0.38.0",
51
57
  "ty>=0.0.1a20",
52
58
  "pre-commit>=4.3.0",
53
59
  ]
@@ -66,6 +72,7 @@ include = ["json2vec*"]
66
72
  [tool.pytest.ini_options]
67
73
  testpaths = ["tests"]
68
74
  python_files = ["test_*.py"]
75
+ addopts = ["-n", "auto"]
69
76
 
70
77
  [tool.ruff]
71
78
  line-length = 120
@@ -52,23 +52,19 @@ from json2vec.tensorfields.shared.vocabulary import VocabularySyncCallback
52
52
 
53
53
  if TYPE_CHECKING:
54
54
  from json2vec.inference.deployment import (
55
- API,
56
55
  Accelerator,
57
- BatchItem,
58
56
  Deployment,
59
- ErrorItem,
60
57
  Input,
58
+ JSONBackend,
61
59
  ModelSource,
62
60
  UpdateOperation,
63
61
  )
64
62
 
65
63
  _SERVING_EXPORTS = {
66
- "API",
67
64
  "Accelerator",
68
- "BatchItem",
69
65
  "Deployment",
70
- "ErrorItem",
71
66
  "Input",
67
+ "JSONBackend",
72
68
  "ModelSource",
73
69
  "UpdateOperation",
74
70
  }
@@ -81,7 +77,7 @@ def __getattr__(name: str) -> Any:
81
77
  try:
82
78
  from json2vec.inference import deployment
83
79
  except ModuleNotFoundError as error:
84
- if error.name in {"litserve", "pydantic_settings"}:
80
+ if error.name in {"fastapi", "orjson", "pydantic_settings", "uvicorn"}:
85
81
  raise ModuleNotFoundError(
86
82
  f"json2vec.{name} requires the serving extra; install with `pip install json2vec[serving]`."
87
83
  ) from error
@@ -98,11 +94,9 @@ def __dir__() -> list[str]:
98
94
 
99
95
  __all__ = [
100
96
  "Address",
101
- "API",
102
97
  "Accelerator",
103
98
  "Array",
104
99
  "AttentionMode",
105
- "BatchItem",
106
100
  "Category",
107
101
  "Component",
108
102
  "CustomDataModule",
@@ -111,9 +105,9 @@ __all__ = [
111
105
  "Deployment",
112
106
  "EmbedderBase",
113
107
  "Entity",
114
- "ErrorItem",
115
108
  "Hyperparameters",
116
109
  "Input",
110
+ "JSONBackend",
117
111
  "Leaf",
118
112
  "Metric",
119
113
  "Model",
@@ -3,9 +3,9 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  from collections.abc import Iterator, Mapping
6
- from dataclasses import dataclass, field
7
6
  from typing import TYPE_CHECKING, Any
8
7
 
8
+ import pydantic
9
9
  import torch
10
10
  from tensordict import TensorDict
11
11
 
@@ -34,12 +34,13 @@ ContractSignature = tuple[Any, ...]
34
34
  ContractScope = tuple[str, int, int, ContractSignature]
35
35
 
36
36
 
37
- @dataclass
38
- class ContractScheduler:
37
+ class ContractScheduler(pydantic.BaseModel):
39
38
  """Deterministic backoff scheduler for expensive forward contract checks."""
40
39
 
40
+ model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
41
+
41
42
  periodic_interval: int = 1024
42
- _counts: dict[ContractScope, int] = field(default_factory=dict)
43
+ _counts: dict[ContractScope, int] = pydantic.PrivateAttr(default_factory=dict)
43
44
 
44
45
  def reset(self) -> None:
45
46
  self._counts.clear()
@@ -4,9 +4,9 @@ from __future__ import annotations
4
4
 
5
5
  from collections.abc import Callable, Iterator
6
6
  from contextlib import contextmanager
7
- from dataclasses import dataclass
8
7
  from typing import TYPE_CHECKING, Any
9
8
 
9
+ import pydantic
10
10
  from loguru import logger
11
11
 
12
12
  from json2vec.architecture.graph import ModelGraph
@@ -20,12 +20,18 @@ if TYPE_CHECKING:
20
20
  _MISSING = object()
21
21
 
22
22
 
23
- @dataclass(frozen=True)
24
- class AttributeChange:
23
+ class AttributeChange(pydantic.BaseModel):
24
+ model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
25
+
25
26
  node: Node
26
27
  name: str
27
28
  original: Any
28
29
  definition_attribute: bool
30
+ address: str
31
+ node_name: str
32
+ node_type: str
33
+ changed: Any = _MISSING
34
+ changed_address: Any = _MISSING
29
35
 
30
36
 
31
37
  class SchemaEditor:
@@ -208,6 +214,9 @@ class SchemaEditor:
208
214
  name=name,
209
215
  original=getattr(node, name, _MISSING),
210
216
  definition_attribute=_is_definition_attribute(node, name),
217
+ address=str(node.address),
218
+ node_name=node.name,
219
+ node_type=node.type,
211
220
  )
212
221
  )
213
222
 
@@ -280,29 +289,55 @@ class SchemaEditor:
280
289
 
281
290
  def _log_attribute_changes(self, action: str, changes: list[AttributeChange], *, restored: bool = False) -> None:
282
291
  for change in changes:
292
+ current_address = str(change.node.address)
283
293
  value = change.original if restored else getattr(change.node, change.name, _MISSING)
294
+ if not restored:
295
+ change.changed = value
296
+ change.changed_address = current_address
297
+ previous_value = change.changed if restored else change.original
298
+ previous_address = change.changed_address if restored else change.address
299
+ if previous_address is _MISSING:
300
+ previous_address = change.address
301
+ address_context = (
302
+ current_address if previous_address == current_address else f"{previous_address} -> {current_address}"
303
+ )
304
+ value_text = _format_log_value(value)
305
+ previous_value_text = _format_log_value(previous_value)
284
306
  logger.bind(
285
307
  component="schema_mutation",
286
308
  action=action,
287
- address=str(change.node.address),
288
- node_type=change.node.type,
309
+ address=current_address,
310
+ previous_address=previous_address,
311
+ node_name=change.node.name,
312
+ previous_node_name=change.node_name,
313
+ node_type=change.node_type,
289
314
  attribute=change.name,
290
315
  definition_attribute=change.definition_attribute,
291
- value=_format_log_value(value),
292
- previous_value=_format_log_value(change.original),
293
- ).info("restored schema node attribute" if restored else "mutated schema node attribute")
316
+ value=value_text,
317
+ previous_value=previous_value_text,
318
+ change=f"{change.name}: {previous_value_text} -> {value_text}",
319
+ ).info(
320
+ "{} {}: {} {} -> {}",
321
+ "restored" if restored else "mutated",
322
+ address_context,
323
+ change.name,
324
+ previous_value_text,
325
+ value_text,
326
+ )
294
327
 
295
328
  def _log_node_mutation(self, *, action: str, message: str, node: Node, **kwargs: Any) -> None:
296
329
  extra = {key: str(value.address) if isinstance(value, Node) else value for key, value in kwargs.items()}
330
+ context = _format_node_log_context(node, extra)
297
331
  logger.bind(
298
332
  component="schema_mutation",
299
333
  action=action,
300
334
  address=str(node.address),
301
335
  node_type=node.type,
336
+ node_name=node.name,
302
337
  attribute=None,
303
338
  definition_attribute=None,
304
339
  **extra,
305
- ).info(message)
340
+ ).info("{} {}", message, context)
306
341
 
307
342
 
308
343
  def _has_node_attribute(node: Node, name: str) -> bool:
@@ -321,3 +356,13 @@ def _format_log_value(value: Any) -> str:
321
356
 
322
357
  text = repr(value)
323
358
  return text if len(text) <= 160 else f"{text[:157]}..."
359
+
360
+
361
+ def _format_node_log_context(node: Node, extra: dict[str, Any]) -> str:
362
+ parts = [str(node.address)]
363
+ if parent := extra.get("parent"):
364
+ parts.append(f"under {parent}")
365
+ if "descendants" in extra:
366
+ parts.append(f"descendants={extra['descendants']}")
367
+
368
+ return " ".join(parts)
@@ -13,15 +13,16 @@ from beartype import beartype
13
13
  from lightning.pytorch import Callback
14
14
  from lightning.pytorch.callbacks import ModelCheckpoint
15
15
  from loguru import logger
16
+ from rich.text import Text
16
17
  from tensordict import TensorDict
17
18
 
18
19
  from json2vec.architecture.checkpoint import CheckpointState
19
20
  from json2vec.architecture.contracts import ContractScheduler
20
21
  from json2vec.architecture.graph import ModelGraph
21
22
  from json2vec.architecture.mutations import SchemaEditor
22
- from json2vec.architecture.plot import PlotMode
23
23
  from json2vec.architecture.runtime import ModelRuntime, Postprocessor, Preprocessor, step
24
24
  from json2vec.data.datasets.base import EncodedBatch, EncodedInput
25
+ from json2vec.logging.throughput import ThroughputLogger
25
26
  from json2vec.structs.enums import AttentionMode, Strata
26
27
  from json2vec.structs.experiment import (
27
28
  Hyperparameters,
@@ -30,7 +31,7 @@ from json2vec.structs.experiment import (
30
31
  SchemaField,
31
32
  )
32
33
  from json2vec.structs.packages import Prediction
33
- from json2vec.structs.tree import Address, Node, Rate
34
+ from json2vec.structs.tree import Address, Node, Rate, Renderable
34
35
  from json2vec.tensorfields.base import TENSORFIELDS, Plugin, TensorFieldBase
35
36
 
36
37
  OptimizerConfig = torch.optim.Optimizer | Callable[["Model"], torch.optim.Optimizer]
@@ -137,12 +138,12 @@ class RollbackCheckpoint(ModelCheckpoint):
137
138
  ).info("rolled back Model to best checkpoint")
138
139
 
139
140
 
140
- class Model(lit.LightningModule):
141
+ class Model(lit.LightningModule, Renderable):
141
142
  """Neural model generated from a `json2vec` schema tree.
142
143
 
143
144
  `Model` owns the schema hyperparameters, tensorfield embedders, array
144
145
  encoders, decoders, and convenience methods for prediction, checkpointing,
145
- plotting, and schema mutation.
146
+ schema display and mutation.
146
147
 
147
148
  Example:
148
149
  ```python
@@ -265,6 +266,45 @@ class Model(lit.LightningModule):
265
266
  self._contract_generation += 1
266
267
  self._contract_scheduler.reset()
267
268
 
269
+ def __rich_console__(self, console, options):
270
+ parameters = sum(parameter.numel() for parameter in self.parameters())
271
+ heading = Text()
272
+ heading.append(type(self).__name__, style=self.RICH_NAME_STYLE)
273
+ heading.append(" ")
274
+ heading.append("[model]", style=self.RICH_TYPE_STYLE)
275
+ for name, value in (
276
+ ("batch_size", self.batch_size),
277
+ ("d_model", self.hyperparameters.d_model),
278
+ ("parameters", f"{parameters:,}"),
279
+ ("arrays", len(self.hyperparameters.arrays)),
280
+ ("fields", len(self.hyperparameters.active_requests)),
281
+ ("targets", len(self.hyperparameters.target)),
282
+ ("embeds", len(self.hyperparameters.embed)),
283
+ ):
284
+ heading.append(" ")
285
+ heading.append(f"{name}=", style="dim")
286
+ heading.append(str(value), style="cyan")
287
+ yield heading
288
+
289
+ lines = list(self.hyperparameters.fields.__rich_console__(console, options))
290
+ if not lines:
291
+ return
292
+ first = Text()
293
+ first.append("`-- ", style=self.RICH_TREE_STYLE)
294
+ if isinstance(lines[0], Text):
295
+ first.append_text(lines[0])
296
+ else:
297
+ first.append(str(lines[0]))
298
+ yield first
299
+ for line in lines[1:]:
300
+ nested = Text()
301
+ nested.append(" ", style=self.RICH_TREE_STYLE)
302
+ if isinstance(line, Text):
303
+ nested.append_text(line)
304
+ else:
305
+ nested.append(str(line))
306
+ yield nested
307
+
268
308
  def select(
269
309
  self,
270
310
  *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
@@ -382,6 +422,8 @@ class Model(lit.LightningModule):
382
422
  callbacks.append(RuntimePlacementCallback())
383
423
  if MutationLockCallback not in attached_callback_types:
384
424
  callbacks.append(MutationLockCallback())
425
+ if ThroughputLogger not in attached_callback_types:
426
+ callbacks.append(ThroughputLogger())
385
427
 
386
428
  for request in self.hyperparameters.active_requests.values():
387
429
  plugin: Plugin = TENSORFIELDS[request.type]
@@ -438,26 +480,6 @@ class Model(lit.LightningModule):
438
480
  if hasattr(node, "embedder") and hasattr(node.embedder, "interprocess_encoding_context")
439
481
  }
440
482
 
441
- def plot(
442
- self,
443
- address: Address | str | None = None,
444
- detail: bool = False,
445
- out: str | Path | None = None,
446
- mode: PlotMode = "schema",
447
- ) -> None:
448
- """Print a Rich model visualization.
449
-
450
- Args:
451
- address: Optional subtree address to render.
452
- detail: Include tensorfield-specific detail sections.
453
- out: Optional output path for the rendered console text.
454
- mode: Plot mode. Supported values are `schema`, `state`, `flow`,
455
- and `debug`.
456
- """
457
- from json2vec.architecture.plot import plot
458
-
459
- return plot(module=self, address=address, detail=detail, out=out, mode=mode)
460
-
461
483
  @beartype
462
484
  def save(self, pathname: str | Path) -> str | Path:
463
485
  """Save model weights and schema hyperparameters to a checkpoint."""
@@ -90,5 +90,13 @@ def _is_assigned_to_worker(shard_key: str, worker_id: int, num_workers: int) ->
90
90
  return owner == worker_id
91
91
 
92
92
 
93
+ def share_interprocess_encoding_context(context: InterprocessEncodingContext) -> None:
94
+ """Opt encoding context resources into multiprocessing-safe storage."""
95
+ for field_context in context.values():
96
+ share = getattr(field_context, "share", None)
97
+ if callable(share):
98
+ share()
99
+
100
+
93
101
  def identity(data: Any) -> Any:
94
102
  return data
@@ -22,6 +22,7 @@ from json2vec.data.datasets.base import (
22
22
  SampleRate,
23
23
  StrataMap,
24
24
  identity,
25
+ share_interprocess_encoding_context,
25
26
  )
26
27
  from json2vec.data.iterables import (
27
28
  JMESPathResolutionMonitor,
@@ -299,15 +300,23 @@ class CustomDataModule(lit.LightningDataModule):
299
300
  return None
300
301
  raise ValueError(f"no dataset configured for strata: {strata}")
301
302
 
303
+ workers = self.num_workers[strata]
304
+ if workers is None:
305
+ workers = os.cpu_count() or 0
306
+
307
+ interprocess_encoding_context = self.interprocess_encoding_context
308
+ if strata == Strata.train and workers > 0:
309
+ share_interprocess_encoding_context(interprocess_encoding_context)
310
+
302
311
  return custom_dataloader(
303
312
  hyperparameters=self.hyperparameters,
304
313
  dataset=self.datasets[strata],
305
314
  preprocessor=self.preprocessor,
306
315
  preprocessor_kwargs=self.preprocessor_kwargs,
307
- interprocess_encoding_context=self.interprocess_encoding_context,
316
+ interprocess_encoding_context=interprocess_encoding_context,
308
317
  batch_size=self.batch_size,
309
318
  strata=strata,
310
- num_workers=self.num_workers[strata],
319
+ num_workers=workers,
311
320
  persistent_workers=self.persistent_workers[strata],
312
321
  pin_memory=self.pin_memory[strata],
313
322
  observation_buffer_size=self.observation_buffer_size[strata],
@@ -26,6 +26,7 @@ from json2vec.data.datasets.base import (
26
26
  _is_assigned_to_worker,
27
27
  _worker_identity,
28
28
  identity,
29
+ share_interprocess_encoding_context,
29
30
  )
30
31
  from json2vec.data.iterables import (
31
32
  JMESPathResolutionMonitor,
@@ -355,15 +356,23 @@ class PolarsDataModule(lit.LightningDataModule):
355
356
  return None
356
357
  raise ValueError(f"no dataframe configured for strata: {strata}")
357
358
 
359
+ workers = self.num_workers[strata]
360
+ if workers is None:
361
+ workers = os.cpu_count() or 0
362
+
363
+ interprocess_encoding_context = self.interprocess_encoding_context
364
+ if strata == Strata.train and workers > 0:
365
+ share_interprocess_encoding_context(interprocess_encoding_context)
366
+
358
367
  return polars_dataloader(
359
368
  hyperparameters=self.hyperparameters,
360
369
  dataframe=self.dataframes[strata],
361
370
  preprocessor=self.preprocessor,
362
371
  preprocessor_kwargs=self.preprocessor_kwargs,
363
- interprocess_encoding_context=self.interprocess_encoding_context,
372
+ interprocess_encoding_context=interprocess_encoding_context,
364
373
  batch_size=self.batch_size,
365
374
  strata=strata,
366
- num_workers=self.num_workers[strata],
375
+ num_workers=workers,
367
376
  persistent_workers=self.persistent_workers[strata],
368
377
  pin_memory=self.pin_memory[strata],
369
378
  sharding=self.sharding[strata],
@@ -31,6 +31,7 @@ from json2vec.data.datasets.base import (
31
31
  _is_assigned_to_worker,
32
32
  _worker_identity,
33
33
  identity,
34
+ share_interprocess_encoding_context,
34
35
  )
35
36
  from json2vec.data.iterables import (
36
37
  JMESPathResolutionMonitor,
@@ -500,6 +501,14 @@ class StreamingDataModule(lit.LightningDataModule):
500
501
  global_rank = getattr(trainer, "global_rank", None)
501
502
  world_size = getattr(trainer, "world_size", None)
502
503
 
504
+ workers = self.num_workers[strata]
505
+ if workers is None:
506
+ workers = os.cpu_count() or 0
507
+
508
+ interprocess_encoding_context = self.interprocess_encoding_context
509
+ if strata == Strata.train and workers > 0:
510
+ share_interprocess_encoding_context(interprocess_encoding_context)
511
+
503
512
  return dataloader(
504
513
  hyperparameters=self.hyperparameters,
505
514
  root=self.root,
@@ -507,10 +516,10 @@ class StreamingDataModule(lit.LightningDataModule):
507
516
  pattern=pattern,
508
517
  preprocessor=self.preprocessor,
509
518
  preprocessor_kwargs=self.preprocessor_kwargs,
510
- interprocess_encoding_context=self.interprocess_encoding_context,
519
+ interprocess_encoding_context=interprocess_encoding_context,
511
520
  batch_size=self.batch_size,
512
521
  strata=strata,
513
- num_workers=self.num_workers[strata],
522
+ num_workers=workers,
514
523
  persistent_workers=self.persistent_workers[strata],
515
524
  pin_memory=self.pin_memory[strata],
516
525
  sharding=self.sharding[strata],
@@ -4,8 +4,9 @@ from __future__ import annotations
4
4
 
5
5
  import inspect
6
6
  import random
7
+ import re
7
8
  from collections import Counter
8
- from collections.abc import Iterable, Iterator
9
+ from collections.abc import Callable, Iterable, Iterator
9
10
  from functools import cache
10
11
  from typing import Annotated, Any, TypeVar, cast
11
12
 
@@ -138,6 +139,56 @@ def query(expression: str) -> jmespath.parser.ParsedResult:
138
139
  return jmespath.compile(expression=f"[*]{expression}")
139
140
 
140
141
 
142
+ @cache
143
+ def compile_query_extractor(expression: str) -> Callable[[Any], Any] | None:
144
+ """Compile a direct extractor for simple JMESPath projection queries."""
145
+ if not expression.startswith("[*]"):
146
+ return None
147
+
148
+ operations: list[tuple[str, str | None]] = [("projection", None), ("projection", None)]
149
+ index = 3
150
+ while index < len(expression):
151
+ if expression.startswith("[*]", index):
152
+ operations.append(("projection", None))
153
+ index += 3
154
+ continue
155
+
156
+ if not expression.startswith(".", index):
157
+ return None
158
+
159
+ match = re.match(r"\.([A-Za-z_][A-Za-z0-9_]*)", expression[index:])
160
+ if match is None:
161
+ return None
162
+
163
+ operations.append(("field", match.group(1)))
164
+ index += len(match.group(0))
165
+
166
+ def search(value: Any) -> Any:
167
+ def apply(item: Any, operation_index: int) -> Any:
168
+ if operation_index == len(operations):
169
+ return item
170
+
171
+ operation, key = operations[operation_index]
172
+ if operation == "field":
173
+ if not isinstance(item, dict) or key not in item:
174
+ return None
175
+ return apply(item[key], operation_index + 1)
176
+
177
+ if not isinstance(item, list):
178
+ return None
179
+
180
+ out = []
181
+ for child in item:
182
+ result = apply(child, operation_index + 1)
183
+ if result is not None:
184
+ out.append(result)
185
+ return out
186
+
187
+ return apply(value, 0)
188
+
189
+ return search
190
+
191
+
141
192
  class JMESPathResolutionMonitor(pydantic.BaseModel):
142
193
  every: Annotated[int, pydantic.Field(gt=0)] = 1000
143
194
 
@@ -192,9 +243,11 @@ def encode(
192
243
  if expression is None:
193
244
  raise ValueError(f"request '{address}' must define query")
194
245
 
195
- # `request.query` is relative to a processed observation. `query(...)`
196
- # adds the outer batch selector before JMESPath searches `batch`.
197
- result = query(expression).search(batch)
246
+ # `request.query` is relative to a processed observation. Direct
247
+ # extractors add the outer batch selector; `query(...)` does the same
248
+ # before JMESPath searches `batch`.
249
+ extractor = compile_query_extractor(expression)
250
+ result = extractor(batch) if extractor is not None else query(expression).search(batch)
198
251
  if jmespath_resolution_monitor is not None:
199
252
  jmespath_resolution_monitor.observe(address=address, expression=expression, result=result)
200
253
 
File without changes