json2vec 0.4.2__tar.gz → 0.4.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {json2vec-0.4.2/src/json2vec.egg-info → json2vec-0.4.3}/PKG-INFO +1 -1
- {json2vec-0.4.2 → json2vec-0.4.3}/pyproject.toml +1 -1
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/architecture/encoder.py +1 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/architecture/node.py +2 -2
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/architecture/plot.py +14 -11
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/architecture/root.py +25 -8
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/architecture/rotary.py +1 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/data/datasets/polars.py +42 -14
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/data/datasets/streaming.py +39 -12
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/data/processing.py +1 -1
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/inference/callback.py +2 -4
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/inference/deployment.py +6 -7
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/logging/config.py +1 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/preprocessors/base.py +4 -6
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/structs/enums.py +8 -8
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/structs/experiment.py +12 -8
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/structs/packages.py +7 -12
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/structs/structure.py +8 -5
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/structs/tree.py +0 -2
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/base.py +61 -13
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/category.py +17 -23
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/dateparts.py +3 -7
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/entity.py +11 -11
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/number.py +12 -23
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/set.py +7 -12
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/text.py +4 -3
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/vector.py +1 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/shared/counter.py +17 -8
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/shared/vocabulary.py +15 -18
- {json2vec-0.4.2 → json2vec-0.4.3/src/json2vec.egg-info}/PKG-INFO +1 -1
- {json2vec-0.4.2 → json2vec-0.4.3}/LICENSE +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/NOTICE +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/README.md +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/setup.cfg +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/__init__.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/architecture/__init__.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/architecture/attention.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/architecture/pool.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/data/__init__.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/data/datasets/__init__.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/data/datasets/base.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/data/iterables.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/distributed.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/inference/__init__.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/logging/__init__.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/logging/epoch.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/logging/throughput.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/preprocessors/__init__.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/preprocessors/extensions/__init__.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/preprocessors/spec.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/structs/__init__.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/__init__.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/__init__.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/shared/__init__.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec/tensorfields/spec.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec.egg-info/SOURCES.txt +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec.egg-info/dependency_links.txt +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec.egg-info/requires.txt +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/src/json2vec.egg-info/top_level.txt +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/tests/test_callbacks.py +0 -0
- {json2vec-0.4.2 → json2vec-0.4.3}/tests/test_public_api.py +0 -0
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
3
|
import inspect
|
|
4
|
-
from typing import TYPE_CHECKING
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
5
|
|
|
6
6
|
import torch
|
|
7
7
|
|
|
@@ -25,7 +25,7 @@ class NodeModule(torch.nn.Module):
|
|
|
25
25
|
if address in hyperparameters.requests:
|
|
26
26
|
request: Node = hyperparameters.requests[address]
|
|
27
27
|
plugin: Plugin = TENSORFIELDS[request.type]
|
|
28
|
-
embedder_kwargs = dict(hyperparameters=hyperparameters, address=address)
|
|
28
|
+
embedder_kwargs: dict[str, Any] = dict(hyperparameters=hyperparameters, address=address)
|
|
29
29
|
if "batch_size" in inspect.signature(plugin.Embedder.__init__).parameters:
|
|
30
30
|
embedder_kwargs["batch_size"] = batch_size
|
|
31
31
|
|
|
@@ -49,7 +49,7 @@ def plot(
|
|
|
49
49
|
) -> None:
|
|
50
50
|
"""Print a Rich model visualization and optionally write it as text."""
|
|
51
51
|
renderable = build_plot(module=module, address=address, detail=detail, mode=mode)
|
|
52
|
-
Console(width=PLOT_WIDTH
|
|
52
|
+
Console(width=PLOT_WIDTH).print(renderable)
|
|
53
53
|
|
|
54
54
|
if out is None:
|
|
55
55
|
return
|
|
@@ -112,7 +112,9 @@ def render_flow_plot(module: "Model", address: Address | str | None) -> Renderab
|
|
|
112
112
|
table.add_column("Count", justify="right")
|
|
113
113
|
table.add_row("JSON", "Raw nested records enter with the shape described by the schema.", "")
|
|
114
114
|
table.add_row("Tensorfields", "Typed requests read values with JMESPath queries.", str(len(fields)))
|
|
115
|
-
table.add_row(
|
|
115
|
+
table.add_row(
|
|
116
|
+
"Encoders", "Array nodes pool child embeddings into parent contexts.", str(len(hyperparameters.arrays))
|
|
117
|
+
)
|
|
116
118
|
table.add_row("Targets", "Target fields produce supervised predictions.", str(target_count))
|
|
117
119
|
table.add_row("Embeddings", "Selected nodes expose reusable embeddings.", str(embed_count))
|
|
118
120
|
|
|
@@ -272,11 +274,7 @@ def node_metadata_keys(node: Node, values: dict[str, Any], state_focus: bool) ->
|
|
|
272
274
|
|
|
273
275
|
|
|
274
276
|
def should_hide_metadata(key: str, value: Any) -> bool:
|
|
275
|
-
return (
|
|
276
|
-
(key == "active" and value is True)
|
|
277
|
-
or (key == "embed" and value is False)
|
|
278
|
-
or key == "description"
|
|
279
|
-
)
|
|
277
|
+
return (key == "active" and value is True) or (key == "embed" and value is False) or key == "description"
|
|
280
278
|
|
|
281
279
|
|
|
282
280
|
def format_metadata_value(value: Any) -> str:
|
|
@@ -317,7 +315,9 @@ def format_detail_inline(value: Any) -> str:
|
|
|
317
315
|
return truncate(value, width=100)
|
|
318
316
|
|
|
319
317
|
if isinstance(value, list):
|
|
320
|
-
return truncate(
|
|
318
|
+
return truncate(
|
|
319
|
+
format_inline_sequence(value) or pformat(value, compact=True, sort_dicts=False, width=88), width=100
|
|
320
|
+
)
|
|
321
321
|
|
|
322
322
|
return truncate(pformat(value, compact=True, sort_dicts=False, width=88), width=100)
|
|
323
323
|
|
|
@@ -332,7 +332,9 @@ def summarize_value(value: Any, max_items: int = 8) -> Any:
|
|
|
332
332
|
if isinstance(value, list):
|
|
333
333
|
if len(value) <= max_items:
|
|
334
334
|
return [summarize_value(item, max_items=max_items) for item in value]
|
|
335
|
-
return [summarize_value(item, max_items=max_items) for item in value[:max_items]] + [
|
|
335
|
+
return [summarize_value(item, max_items=max_items) for item in value[:max_items]] + [
|
|
336
|
+
f"... {len(value) - max_items} more"
|
|
337
|
+
]
|
|
336
338
|
|
|
337
339
|
return value
|
|
338
340
|
|
|
@@ -410,8 +412,9 @@ def format_compact_number(value: Any) -> str:
|
|
|
410
412
|
|
|
411
413
|
def resolve_node(hyperparameters: "Hyperparameters", address: Address | str) -> Node:
|
|
412
414
|
key = Address(str(address))
|
|
413
|
-
leaves = {node.address: node for node in hyperparameters.descendants if isinstance(node, Leaf)}
|
|
414
|
-
nodes: dict[Address, Node] = hyperparameters.arrays
|
|
415
|
+
leaves: dict[Address, Node] = {node.address: node for node in hyperparameters.descendants if isinstance(node, Leaf)}
|
|
416
|
+
nodes: dict[Address, Node] = dict(hyperparameters.arrays)
|
|
417
|
+
nodes.update(leaves)
|
|
415
418
|
|
|
416
419
|
if key not in nodes:
|
|
417
420
|
raise ValueError(f"address '{address}' was not found in the hyperparameters")
|
|
@@ -18,6 +18,7 @@ from tensordict import TensorDict
|
|
|
18
18
|
|
|
19
19
|
from json2vec.architecture.encoder import ArrayEncoder
|
|
20
20
|
from json2vec.architecture.node import NodeModule
|
|
21
|
+
from json2vec.architecture.plot import PlotMode
|
|
21
22
|
from json2vec.data.datasets.base import EncodedBatch
|
|
22
23
|
from json2vec.data.iterables import encode, mock
|
|
23
24
|
from json2vec.structs.enums import AttentionMode, Metric, Strata, TensorKey
|
|
@@ -296,7 +297,6 @@ class Model(lit.LightningModule):
|
|
|
296
297
|
optimizer: OptimizerConfig | None = None,
|
|
297
298
|
scheduler: SchedulerConfig | None = None,
|
|
298
299
|
):
|
|
299
|
-
|
|
300
300
|
super().__init__()
|
|
301
301
|
if batch_size <= 0:
|
|
302
302
|
raise ValueError("batch_size must be > 0")
|
|
@@ -330,11 +330,16 @@ class Model(lit.LightningModule):
|
|
|
330
330
|
self.example_input_array = mock(hyperparameters=self.hyperparameters, batch_size=self.batch_size)
|
|
331
331
|
|
|
332
332
|
def _rebuild(self) -> None:
|
|
333
|
+
self.hyperparameters._clear_tree_caches()
|
|
334
|
+
was_training = self.training
|
|
335
|
+
device = self.device
|
|
333
336
|
previous = {
|
|
334
337
|
name: value.detach().clone() if isinstance(value, torch.Tensor) else deepcopy(value)
|
|
335
338
|
for name, value in self.state_dict().items()
|
|
336
339
|
}
|
|
337
340
|
self._build()
|
|
341
|
+
if isinstance(device, torch.device):
|
|
342
|
+
self.to(device=device)
|
|
338
343
|
current = self.state_dict()
|
|
339
344
|
compatible = {}
|
|
340
345
|
for name, value in previous.items():
|
|
@@ -351,6 +356,7 @@ class Model(lit.LightningModule):
|
|
|
351
356
|
compatible[name] = value
|
|
352
357
|
|
|
353
358
|
self.load_state_dict(compatible, strict=False)
|
|
359
|
+
self.train(was_training)
|
|
354
360
|
|
|
355
361
|
def select(
|
|
356
362
|
self,
|
|
@@ -372,6 +378,7 @@ class Model(lit.LightningModule):
|
|
|
372
378
|
allow_extra: bool = False,
|
|
373
379
|
include_root: bool = True,
|
|
374
380
|
validate: bool = True,
|
|
381
|
+
use_cache: bool = False,
|
|
375
382
|
**values: Any,
|
|
376
383
|
) -> None:
|
|
377
384
|
"""Mutate selected schema nodes and rebuild compatible modules.
|
|
@@ -386,6 +393,8 @@ class Model(lit.LightningModule):
|
|
|
386
393
|
allow unknown fields.
|
|
387
394
|
include_root: Include the root node in predicate matching.
|
|
388
395
|
validate: Validate each node after applying candidate values.
|
|
396
|
+
use_cache: Permit cached selector results. Mutations default this to
|
|
397
|
+
`False` so updates always evaluate against current schema state.
|
|
389
398
|
**values: Schema attributes to update.
|
|
390
399
|
"""
|
|
391
400
|
self._assert_mutation_allowed("update")
|
|
@@ -395,6 +404,7 @@ class Model(lit.LightningModule):
|
|
|
395
404
|
allow_extra=allow_extra,
|
|
396
405
|
include_root=include_root,
|
|
397
406
|
validate=validate,
|
|
407
|
+
use_cache=use_cache,
|
|
398
408
|
**values,
|
|
399
409
|
)
|
|
400
410
|
self._rebuild()
|
|
@@ -468,6 +478,7 @@ class Model(lit.LightningModule):
|
|
|
468
478
|
allow_extra: bool = False,
|
|
469
479
|
include_root: bool = True,
|
|
470
480
|
validate: bool = True,
|
|
481
|
+
use_cache: bool = False,
|
|
471
482
|
**values: Any,
|
|
472
483
|
) -> Iterator[None]:
|
|
473
484
|
"""Temporarily mutate selected schema nodes and keep runtime modules synchronized."""
|
|
@@ -479,6 +490,7 @@ class Model(lit.LightningModule):
|
|
|
479
490
|
allow_extra=allow_extra,
|
|
480
491
|
include_root=include_root,
|
|
481
492
|
validate=validate,
|
|
493
|
+
use_cache=use_cache,
|
|
482
494
|
**values,
|
|
483
495
|
):
|
|
484
496
|
self._rebuild()
|
|
@@ -563,7 +575,7 @@ class Model(lit.LightningModule):
|
|
|
563
575
|
address: Address | str | None = None,
|
|
564
576
|
detail: bool = False,
|
|
565
577
|
out: str | Path | None = None,
|
|
566
|
-
mode:
|
|
578
|
+
mode: PlotMode = "schema",
|
|
567
579
|
) -> None:
|
|
568
580
|
"""Print a Rich model visualization.
|
|
569
581
|
|
|
@@ -602,7 +614,8 @@ class Model(lit.LightningModule):
|
|
|
602
614
|
|
|
603
615
|
embedder: EmbedderBase = self.nodes[address].embedder
|
|
604
616
|
embedding: Parcel = embedder(tensorfield)
|
|
605
|
-
|
|
617
|
+
if embedding.destination is not None:
|
|
618
|
+
processed[embedding.destination].append(embedding)
|
|
606
619
|
outgoing[embedding.origin] = embedding
|
|
607
620
|
|
|
608
621
|
if address in self.hyperparameters.embed:
|
|
@@ -617,7 +630,8 @@ class Model(lit.LightningModule):
|
|
|
617
630
|
|
|
618
631
|
encoder: ArrayEncoder = self.nodes[address].encoder
|
|
619
632
|
encoding: Parcel = encoder(processed[address])
|
|
620
|
-
|
|
633
|
+
if encoding.destination is not None:
|
|
634
|
+
processed[encoding.destination].append(encoding)
|
|
621
635
|
outgoing[encoding.origin] = encoding
|
|
622
636
|
|
|
623
637
|
if address in self.hyperparameters.embed:
|
|
@@ -696,7 +710,6 @@ class Model(lit.LightningModule):
|
|
|
696
710
|
def write(
|
|
697
711
|
self, predictions: list[Prediction]
|
|
698
712
|
) -> tuple[dict[Address, dict[str, Any]], dict[Address, dict[str, Any]]]:
|
|
699
|
-
|
|
700
713
|
supervised: dict[Address, dict[str, Any]] = {}
|
|
701
714
|
embeddings: dict[Address, dict[str, Any]] = {}
|
|
702
715
|
|
|
@@ -738,17 +751,21 @@ class Model(lit.LightningModule):
|
|
|
738
751
|
|
|
739
752
|
if preprocess is not None:
|
|
740
753
|
observations: EncodedBatch = []
|
|
741
|
-
for request in batch:
|
|
754
|
+
for request in cast(list[dict[str, Any]], batch):
|
|
742
755
|
observation = preprocess(request)
|
|
743
756
|
if not isinstance(observation, dict):
|
|
744
757
|
raise TypeError(f"preprocessor must return a dict object, got {type(observation).__name__}")
|
|
745
758
|
|
|
746
759
|
observations.append([observation])
|
|
747
760
|
|
|
748
|
-
|
|
761
|
+
encoded_batch = observations
|
|
762
|
+
elif batch and isinstance(batch[0], dict):
|
|
763
|
+
encoded_batch = [[request] for request in cast(list[dict[str, Any]], batch)]
|
|
764
|
+
else:
|
|
765
|
+
encoded_batch = cast(EncodedBatch, batch)
|
|
749
766
|
|
|
750
767
|
inputs = encode(
|
|
751
|
-
batch=
|
|
768
|
+
batch=encoded_batch,
|
|
752
769
|
hyperparameters=self.hyperparameters,
|
|
753
770
|
strata=Strata.predict,
|
|
754
771
|
interprocess_encoding_context=self.interprocess_encoding_context,
|
|
@@ -13,6 +13,7 @@ class RotaryEmbedding(torch.nn.Module):
|
|
|
13
13
|
self.base = base
|
|
14
14
|
|
|
15
15
|
index = torch.arange(0, self.rotary_dim, 2, dtype=torch.float32)
|
|
16
|
+
self.inv_freq: torch.Tensor
|
|
16
17
|
self.register_buffer("inv_freq", base ** (-index / self.rotary_dim), persistent=False)
|
|
17
18
|
|
|
18
19
|
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
|
|
@@ -57,7 +57,7 @@ def _dataframes_by_strata(dataframe: pl.DataFrame | DataFrameMap) -> dict[Strata
|
|
|
57
57
|
return {strata: dataframe for strata in Strata}
|
|
58
58
|
|
|
59
59
|
normalized: dict[Strata, pl.DataFrame] = {}
|
|
60
|
-
for key, frame in dataframe.items():
|
|
60
|
+
for key, frame in cast(DataFrameMap, dataframe).items():
|
|
61
61
|
if not isinstance(frame, pl.DataFrame):
|
|
62
62
|
raise TypeError(f"dataframe for strata '{key}' must be a polars DataFrame")
|
|
63
63
|
normalized[Strata.normalize(key)] = frame
|
|
@@ -79,7 +79,7 @@ def observe_polars(
|
|
|
79
79
|
world_size: int | None = None,
|
|
80
80
|
) -> Iterator[RawObservation]:
|
|
81
81
|
if replacement:
|
|
82
|
-
rows =
|
|
82
|
+
rows = dataframe.to_dicts()
|
|
83
83
|
if not rows:
|
|
84
84
|
raise ValueError("no dataframe rows available for replacement sampling")
|
|
85
85
|
|
|
@@ -103,7 +103,7 @@ def observe_polars(
|
|
|
103
103
|
worker_id=worker_id,
|
|
104
104
|
num_workers=num_workers,
|
|
105
105
|
):
|
|
106
|
-
yield
|
|
106
|
+
yield row
|
|
107
107
|
return
|
|
108
108
|
|
|
109
109
|
for chunk_index, offset in enumerate(range(0, dataframe.height, chunk_batch_size)):
|
|
@@ -115,7 +115,7 @@ def observe_polars(
|
|
|
115
115
|
):
|
|
116
116
|
continue
|
|
117
117
|
|
|
118
|
-
yield from
|
|
118
|
+
yield from dataframe.slice(offset, chunk_batch_size).to_dicts()
|
|
119
119
|
|
|
120
120
|
|
|
121
121
|
class PolarsBatchDataset(IterableDataset):
|
|
@@ -281,7 +281,6 @@ class PolarsDataModule(lit.LightningDataModule):
|
|
|
281
281
|
else:
|
|
282
282
|
dataframes = _dataframes_by_strata(dataframe)
|
|
283
283
|
|
|
284
|
-
self.hyperparameters = model.hyperparameters
|
|
285
284
|
self.dataframes = dataframes
|
|
286
285
|
self.preprocessor = PreprocessorConfig.normalize(preprocessor)
|
|
287
286
|
self.preprocessor_kwargs = dict(kwargs)
|
|
@@ -289,26 +288,55 @@ class PolarsDataModule(lit.LightningDataModule):
|
|
|
289
288
|
self._model_ref = weakref.ref(model)
|
|
290
289
|
except TypeError:
|
|
291
290
|
self._model_ref = None
|
|
291
|
+
self._hyperparameters = model.hyperparameters
|
|
292
292
|
self._interprocess_encoding_context = model.interprocess_encoding_context
|
|
293
|
-
self.
|
|
293
|
+
self._batch_size = model.batch_size
|
|
294
294
|
self.num_workers = Strata.expand(num_workers, default=None)
|
|
295
295
|
self.persistent_workers = Strata.expand(persistent_workers, default=True)
|
|
296
296
|
self.pin_memory = Strata.expand(pin_memory, default=True)
|
|
297
297
|
self.sharding = ShardingStrategy.expand(sharding, default=ShardingStrategy.chunk)
|
|
298
298
|
self.chunk_batch_size = Strata.expand(chunk_batch_size, default=4096)
|
|
299
299
|
self.observation_buffer_size = Strata.expand(observation_buffer_size, default=1)
|
|
300
|
-
self.sample_rate = {
|
|
301
|
-
strata: float(rate)
|
|
302
|
-
for strata, rate in Strata.expand(sample_rate, default=1.0).items()
|
|
303
|
-
}
|
|
300
|
+
self.sample_rate = {strata: float(rate) for strata, rate in Strata.expand(sample_rate, default=1.0).items()}
|
|
304
301
|
self.replacement = Strata.expand(replacement, default=False)
|
|
305
302
|
|
|
303
|
+
def _model(self) -> Model | None:
|
|
304
|
+
if self._model_ref is None:
|
|
305
|
+
return None
|
|
306
|
+
|
|
307
|
+
return self._model_ref()
|
|
308
|
+
|
|
309
|
+
@property
|
|
310
|
+
def hyperparameters(self) -> Hyperparameters:
|
|
311
|
+
model = self._model()
|
|
312
|
+
if model is not None:
|
|
313
|
+
return model.hyperparameters
|
|
314
|
+
|
|
315
|
+
return self._hyperparameters
|
|
316
|
+
|
|
317
|
+
@hyperparameters.setter
|
|
318
|
+
def hyperparameters(self, hyperparameters: Hyperparameters) -> None:
|
|
319
|
+
self._model_ref = None
|
|
320
|
+
self._hyperparameters = hyperparameters
|
|
321
|
+
|
|
322
|
+
@property
|
|
323
|
+
def batch_size(self) -> int:
|
|
324
|
+
model = self._model()
|
|
325
|
+
if model is not None:
|
|
326
|
+
return model.batch_size
|
|
327
|
+
|
|
328
|
+
return self._batch_size
|
|
329
|
+
|
|
330
|
+
@batch_size.setter
|
|
331
|
+
def batch_size(self, batch_size: int) -> None:
|
|
332
|
+
self._model_ref = None
|
|
333
|
+
self._batch_size = batch_size
|
|
334
|
+
|
|
306
335
|
@property
|
|
307
336
|
def interprocess_encoding_context(self) -> InterprocessEncodingContext:
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
return model.interprocess_encoding_context
|
|
337
|
+
model = self._model()
|
|
338
|
+
if model is not None:
|
|
339
|
+
return model.interprocess_encoding_context
|
|
312
340
|
|
|
313
341
|
return self._interprocess_encoding_context
|
|
314
342
|
|
|
@@ -118,8 +118,7 @@ def observe(
|
|
|
118
118
|
sampled_paths = list(paths)
|
|
119
119
|
if not sampled_paths:
|
|
120
120
|
raise ValueError(
|
|
121
|
-
"no matching files available for replacement sampling; "
|
|
122
|
-
"check the streaming root and split pattern"
|
|
121
|
+
"no matching files available for replacement sampling; check the streaming root and split pattern"
|
|
123
122
|
)
|
|
124
123
|
|
|
125
124
|
def choices() -> Iterator[str]:
|
|
@@ -405,7 +404,6 @@ class StreamingDataModule(lit.LightningDataModule):
|
|
|
405
404
|
):
|
|
406
405
|
super().__init__()
|
|
407
406
|
|
|
408
|
-
self.hyperparameters = model.hyperparameters
|
|
409
407
|
self.root = root
|
|
410
408
|
self.suffix = Suffix(suffix)
|
|
411
409
|
self.train = train
|
|
@@ -418,8 +416,9 @@ class StreamingDataModule(lit.LightningDataModule):
|
|
|
418
416
|
self._model_ref = weakref.ref(model)
|
|
419
417
|
except TypeError:
|
|
420
418
|
self._model_ref = None
|
|
419
|
+
self._hyperparameters = model.hyperparameters
|
|
421
420
|
self._interprocess_encoding_context = model.interprocess_encoding_context
|
|
422
|
-
self.
|
|
421
|
+
self._batch_size = model.batch_size
|
|
423
422
|
self.num_workers = Strata.expand(num_workers, default=None)
|
|
424
423
|
self.persistent_workers = Strata.expand(persistent_workers, default=True)
|
|
425
424
|
self.pin_memory = Strata.expand(pin_memory, default=True)
|
|
@@ -427,22 +426,50 @@ class StreamingDataModule(lit.LightningDataModule):
|
|
|
427
426
|
self.chunk_batch_size = Strata.expand(chunk_batch_size, default=4096)
|
|
428
427
|
self.file_buffer_size = Strata.expand(file_buffer_size, default=1)
|
|
429
428
|
self.observation_buffer_size = Strata.expand(observation_buffer_size, default=1)
|
|
430
|
-
self.sample_rate = {
|
|
431
|
-
strata: float(rate)
|
|
432
|
-
for strata, rate in Strata.expand(sample_rate, default=1.0).items()
|
|
433
|
-
}
|
|
429
|
+
self.sample_rate = {strata: float(rate) for strata, rate in Strata.expand(sample_rate, default=1.0).items()}
|
|
434
430
|
self.replacement = (
|
|
435
431
|
{strata: strata == Strata.train for strata in Strata}
|
|
436
432
|
if replacement is None
|
|
437
433
|
else Strata.expand(replacement, default=False)
|
|
438
434
|
)
|
|
439
435
|
|
|
436
|
+
def _model(self) -> Model | None:
|
|
437
|
+
if self._model_ref is None:
|
|
438
|
+
return None
|
|
439
|
+
|
|
440
|
+
return self._model_ref()
|
|
441
|
+
|
|
442
|
+
@property
|
|
443
|
+
def hyperparameters(self) -> Hyperparameters:
|
|
444
|
+
model = self._model()
|
|
445
|
+
if model is not None:
|
|
446
|
+
return model.hyperparameters
|
|
447
|
+
|
|
448
|
+
return self._hyperparameters
|
|
449
|
+
|
|
450
|
+
@hyperparameters.setter
|
|
451
|
+
def hyperparameters(self, hyperparameters: Hyperparameters) -> None:
|
|
452
|
+
self._model_ref = None
|
|
453
|
+
self._hyperparameters = hyperparameters
|
|
454
|
+
|
|
455
|
+
@property
|
|
456
|
+
def batch_size(self) -> int:
|
|
457
|
+
model = self._model()
|
|
458
|
+
if model is not None:
|
|
459
|
+
return model.batch_size
|
|
460
|
+
|
|
461
|
+
return self._batch_size
|
|
462
|
+
|
|
463
|
+
@batch_size.setter
|
|
464
|
+
def batch_size(self, batch_size: int) -> None:
|
|
465
|
+
self._model_ref = None
|
|
466
|
+
self._batch_size = batch_size
|
|
467
|
+
|
|
440
468
|
@property
|
|
441
469
|
def interprocess_encoding_context(self) -> InterprocessEncodingContext:
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
return model.interprocess_encoding_context
|
|
470
|
+
model = self._model()
|
|
471
|
+
if model is not None:
|
|
472
|
+
return model.interprocess_encoding_context
|
|
446
473
|
|
|
447
474
|
return self._interprocess_encoding_context
|
|
448
475
|
|
|
@@ -141,7 +141,7 @@ class Pipeline:
|
|
|
141
141
|
return self
|
|
142
142
|
|
|
143
143
|
def __repr__(self):
|
|
144
|
-
return f"Pipeline({
|
|
144
|
+
return f"Pipeline(steps={len(self.steps)}, arguments={self.arguments!r})"
|
|
145
145
|
|
|
146
146
|
def __iter__(self):
|
|
147
147
|
stream = self.steps[0]()
|
|
@@ -41,9 +41,7 @@ class Writer(callbacks.BasePredictionWriter):
|
|
|
41
41
|
self.writer: pq.ParquetWriter | None = None
|
|
42
42
|
|
|
43
43
|
@staticmethod
|
|
44
|
-
def _as_struct_frame(
|
|
45
|
-
values_by_address: dict[Address, dict[str, Any]], alias: str, num_rows: int
|
|
46
|
-
) -> pl.DataFrame:
|
|
44
|
+
def _as_struct_frame(values_by_address: dict[Address, dict[str, Any]], alias: str, num_rows: int) -> pl.DataFrame:
|
|
47
45
|
if len(values_by_address) == 0:
|
|
48
46
|
return pl.DataFrame({alias: [None] * num_rows})
|
|
49
47
|
|
|
@@ -64,7 +62,7 @@ class Writer(callbacks.BasePredictionWriter):
|
|
|
64
62
|
batch: TensorDict[Address, TensorFieldBase],
|
|
65
63
|
batch_idx: int,
|
|
66
64
|
dataloader_idx: int,
|
|
67
|
-
) -> None:
|
|
65
|
+
) -> None: # ty:ignore[invalid-method-override]
|
|
68
66
|
num_rows = len(batch["metadata"])
|
|
69
67
|
|
|
70
68
|
supervised: dict[Address, dict[str, Any]]
|
|
@@ -4,7 +4,7 @@ import functools
|
|
|
4
4
|
from collections.abc import Callable
|
|
5
5
|
from enum import StrEnum
|
|
6
6
|
from pathlib import Path
|
|
7
|
-
from typing import Any, TypeAlias
|
|
7
|
+
from typing import Any, TypeAlias, cast
|
|
8
8
|
|
|
9
9
|
import litserve as ls
|
|
10
10
|
import pydantic
|
|
@@ -42,7 +42,7 @@ class Accelerator(StrEnum):
|
|
|
42
42
|
if normalized == "":
|
|
43
43
|
raise ValueError("accelerator must not be blank")
|
|
44
44
|
|
|
45
|
-
return cls._value2member_map_.get(normalized)
|
|
45
|
+
return cast(Accelerator | None, cls._value2member_map_.get(normalized))
|
|
46
46
|
|
|
47
47
|
|
|
48
48
|
class ErrorItem(pydantic.BaseModel):
|
|
@@ -114,8 +114,7 @@ class API(ls.LitAPI):
|
|
|
114
114
|
self,
|
|
115
115
|
request: dict[str, Any] | pydantic.BaseModel,
|
|
116
116
|
context: dict[str, Any] | None = None,
|
|
117
|
-
) -> Input | ErrorItem:
|
|
118
|
-
|
|
117
|
+
) -> Input | ErrorItem: # ty:ignore[invalid-method-override]
|
|
119
118
|
if isinstance(request, pydantic.BaseModel):
|
|
120
119
|
request = request.model_dump()
|
|
121
120
|
|
|
@@ -175,13 +174,13 @@ class API(ls.LitAPI):
|
|
|
175
174
|
return BatchItem(data=data, valid_indices=valid_indices, items=inputs)
|
|
176
175
|
|
|
177
176
|
@beartype
|
|
178
|
-
def unbatch(self, outputs: list[Any]) -> list[Any]:
|
|
177
|
+
def unbatch(self, outputs: list[Any]) -> list[Any]: # ty:ignore[invalid-method-override]
|
|
179
178
|
return list(outputs)
|
|
180
179
|
|
|
181
180
|
@beartype
|
|
182
181
|
def predict(
|
|
183
182
|
self, data: BatchItem | Input | ErrorItem
|
|
184
|
-
) -> list[list[Prediction] | ErrorItem] | list[Prediction] | ErrorItem:
|
|
183
|
+
) -> list[list[Prediction] | ErrorItem] | list[Prediction] | ErrorItem: # ty:ignore[invalid-method-override]
|
|
185
184
|
if isinstance(data, ErrorItem):
|
|
186
185
|
return data
|
|
187
186
|
|
|
@@ -209,7 +208,7 @@ class API(ls.LitAPI):
|
|
|
209
208
|
self,
|
|
210
209
|
response: list[Prediction] | ErrorItem,
|
|
211
210
|
context: dict[str, Any] | None = None,
|
|
212
|
-
) -> dict[str, Any] | pydantic.BaseModel:
|
|
211
|
+
) -> dict[str, Any] | pydantic.BaseModel: # ty:ignore[invalid-method-override]
|
|
213
212
|
if isinstance(response, ErrorItem):
|
|
214
213
|
return {
|
|
215
214
|
"predictions": {},
|
|
@@ -50,8 +50,7 @@ class Preprocessor(pydantic.BaseModel):
|
|
|
50
50
|
def accepted_kwargs(func: Callable[..., Any]) -> tuple[bool, frozenset[str]]:
|
|
51
51
|
signature = inspect.signature(func)
|
|
52
52
|
accepts_variadic_kwargs = any(
|
|
53
|
-
parameter.kind == inspect.Parameter.VAR_KEYWORD
|
|
54
|
-
for parameter in signature.parameters.values()
|
|
53
|
+
parameter.kind == inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values()
|
|
55
54
|
)
|
|
56
55
|
accepted = frozenset(signature.parameters.keys())
|
|
57
56
|
return accepts_variadic_kwargs, accepted
|
|
@@ -66,7 +65,8 @@ class Preprocessor(pydantic.BaseModel):
|
|
|
66
65
|
|
|
67
66
|
@classmethod
|
|
68
67
|
def register(cls, func: Callable[..., Any], *, mode: PreprocessorMode) -> Callable[..., Any]:
|
|
69
|
-
|
|
68
|
+
name = getattr(func, "__name__", type(func).__name__)
|
|
69
|
+
PREPROCESSORS[name] = cls(name=name, func=func, mode=mode)
|
|
70
70
|
return func
|
|
71
71
|
|
|
72
72
|
def __call__(self, observation: dict, **kwargs) -> Any:
|
|
@@ -99,9 +99,7 @@ class Preprocessor(pydantic.BaseModel):
|
|
|
99
99
|
|
|
100
100
|
def require_object(self, output: Any, *, mode: PreprocessorMode) -> dict[str, Any]:
|
|
101
101
|
if not isinstance(output, dict):
|
|
102
|
-
raise TypeError(
|
|
103
|
-
f"{mode} preprocessor '{self.name}' must produce dict objects, got {type(output).__name__}"
|
|
104
|
-
)
|
|
102
|
+
raise TypeError(f"{mode} preprocessor '{self.name}' must produce dict objects, got {type(output).__name__}")
|
|
105
103
|
|
|
106
104
|
return output
|
|
107
105
|
|
|
@@ -1,8 +1,11 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import enum
|
|
2
4
|
from collections.abc import Mapping
|
|
3
|
-
from typing import TypeVar
|
|
5
|
+
from typing import TypeVar, cast
|
|
4
6
|
|
|
5
7
|
T = TypeVar("T")
|
|
8
|
+
DefaultT = TypeVar("DefaultT")
|
|
6
9
|
|
|
7
10
|
|
|
8
11
|
class Tokens(enum.IntEnum):
|
|
@@ -27,10 +30,10 @@ class Strata(enum.StrEnum):
|
|
|
27
30
|
return cls(str(value).strip().lower())
|
|
28
31
|
|
|
29
32
|
@classmethod
|
|
30
|
-
def expand(cls, value: T | Mapping[
|
|
33
|
+
def expand(cls, value: T | Mapping[Strata | str, T], *, default: DefaultT) -> dict[Strata, T | DefaultT]:
|
|
31
34
|
if isinstance(value, Mapping):
|
|
32
|
-
normalized = {strata: default for strata in cls}
|
|
33
|
-
for key, item in value.items():
|
|
35
|
+
normalized: dict[Strata, T | DefaultT] = {strata: default for strata in cls}
|
|
36
|
+
for key, item in cast(Mapping[Strata | str, T], value).items():
|
|
34
37
|
normalized[cls.normalize(key)] = item
|
|
35
38
|
return normalized
|
|
36
39
|
|
|
@@ -87,10 +90,7 @@ class ShardingStrategy(enum.StrEnum):
|
|
|
87
90
|
*,
|
|
88
91
|
default: "ShardingStrategy",
|
|
89
92
|
) -> dict[Strata, "ShardingStrategy"]:
|
|
90
|
-
return {
|
|
91
|
-
strata: cls.normalize(strategy)
|
|
92
|
-
for strata, strategy in Strata.expand(value, default=default).items()
|
|
93
|
-
}
|
|
93
|
+
return {strata: cls.normalize(strategy) for strata, strategy in Strata.expand(value, default=default).items()}
|
|
94
94
|
|
|
95
95
|
|
|
96
96
|
class AttentionMode(enum.StrEnum):
|