json2vec 0.4.0__tar.gz → 0.4.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {json2vec-0.4.0/src/json2vec.egg-info → json2vec-0.4.3}/PKG-INFO +2 -2
  2. {json2vec-0.4.0 → json2vec-0.4.3}/pyproject.toml +2 -2
  3. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/architecture/encoder.py +1 -0
  4. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/architecture/node.py +2 -2
  5. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/architecture/plot.py +14 -11
  6. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/architecture/root.py +25 -8
  7. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/architecture/rotary.py +1 -0
  8. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/data/datasets/polars.py +42 -14
  9. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/data/datasets/streaming.py +39 -12
  10. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/data/processing.py +1 -1
  11. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/inference/callback.py +2 -4
  12. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/inference/deployment.py +6 -7
  13. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/logging/config.py +1 -0
  14. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/preprocessors/base.py +4 -6
  15. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/structs/enums.py +8 -8
  16. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/structs/experiment.py +12 -8
  17. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/structs/packages.py +7 -12
  18. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/structs/structure.py +8 -5
  19. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/structs/tree.py +0 -2
  20. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/base.py +61 -13
  21. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/category.py +17 -23
  22. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/dateparts.py +3 -7
  23. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/entity.py +11 -11
  24. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/number.py +12 -23
  25. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/set.py +7 -12
  26. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/text.py +4 -3
  27. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/vector.py +1 -0
  28. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/shared/counter.py +17 -8
  29. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/shared/vocabulary.py +15 -18
  30. {json2vec-0.4.0 → json2vec-0.4.3/src/json2vec.egg-info}/PKG-INFO +2 -2
  31. {json2vec-0.4.0 → json2vec-0.4.3}/LICENSE +0 -0
  32. {json2vec-0.4.0 → json2vec-0.4.3}/NOTICE +0 -0
  33. {json2vec-0.4.0 → json2vec-0.4.3}/README.md +0 -0
  34. {json2vec-0.4.0 → json2vec-0.4.3}/setup.cfg +0 -0
  35. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/__init__.py +0 -0
  36. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/architecture/__init__.py +0 -0
  37. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/architecture/attention.py +0 -0
  38. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/architecture/pool.py +0 -0
  39. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/data/__init__.py +0 -0
  40. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/data/datasets/__init__.py +0 -0
  41. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/data/datasets/base.py +0 -0
  42. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/data/iterables.py +0 -0
  43. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/distributed.py +0 -0
  44. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/inference/__init__.py +0 -0
  45. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/logging/__init__.py +0 -0
  46. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/logging/epoch.py +0 -0
  47. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/logging/throughput.py +0 -0
  48. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/preprocessors/__init__.py +0 -0
  49. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/preprocessors/extensions/__init__.py +0 -0
  50. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/preprocessors/spec.py +0 -0
  51. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/structs/__init__.py +0 -0
  52. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/__init__.py +0 -0
  53. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/extensions/__init__.py +0 -0
  54. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/shared/__init__.py +0 -0
  55. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec/tensorfields/spec.py +0 -0
  56. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec.egg-info/SOURCES.txt +0 -0
  57. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec.egg-info/dependency_links.txt +0 -0
  58. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec.egg-info/requires.txt +0 -0
  59. {json2vec-0.4.0 → json2vec-0.4.3}/src/json2vec.egg-info/top_level.txt +0 -0
  60. {json2vec-0.4.0 → json2vec-0.4.3}/tests/test_callbacks.py +0 -0
  61. {json2vec-0.4.0 → json2vec-0.4.3}/tests/test_public_api.py +0 -0
@@ -1,7 +1,7 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: json2vec
3
- Version: 0.4.0
4
- Summary: JSON -> [*]
3
+ Version: 0.4.3
4
+ Summary: {...} -> [*]
5
5
  License-Expression: Apache-2.0
6
6
  Requires-Python: >=3.12
7
7
  Description-Content-Type: text/markdown
@@ -1,7 +1,7 @@
1
1
  [project]
2
2
  name = "json2vec"
3
- version = "0.4.0"
4
- description = "JSON -> [*]"
3
+ version = "0.4.3"
4
+ description = "{...} -> [*]"
5
5
  readme = "README.md"
6
6
  license = "Apache-2.0"
7
7
  requires-python = ">=3.12"
@@ -13,6 +13,7 @@ from json2vec.structs.tree import Address
13
13
  if TYPE_CHECKING:
14
14
  from json2vec.structs.experiment import Hyperparameters
15
15
 
16
+
16
17
  class RotaryTransformerEncoderLayer(torch.nn.Module):
17
18
  def __init__(
18
19
  self,
@@ -1,7 +1,7 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import inspect
4
- from typing import TYPE_CHECKING
4
+ from typing import TYPE_CHECKING, Any
5
5
 
6
6
  import torch
7
7
 
@@ -25,7 +25,7 @@ class NodeModule(torch.nn.Module):
25
25
  if address in hyperparameters.requests:
26
26
  request: Node = hyperparameters.requests[address]
27
27
  plugin: Plugin = TENSORFIELDS[request.type]
28
- embedder_kwargs = dict(hyperparameters=hyperparameters, address=address)
28
+ embedder_kwargs: dict[str, Any] = dict(hyperparameters=hyperparameters, address=address)
29
29
  if "batch_size" in inspect.signature(plugin.Embedder.__init__).parameters:
30
30
  embedder_kwargs["batch_size"] = batch_size
31
31
 
@@ -49,7 +49,7 @@ def plot(
49
49
  ) -> None:
50
50
  """Print a Rich model visualization and optionally write it as text."""
51
51
  renderable = build_plot(module=module, address=address, detail=detail, mode=mode)
52
- Console(width=PLOT_WIDTH, force_jupyter=False).print(renderable)
52
+ Console(width=PLOT_WIDTH).print(renderable)
53
53
 
54
54
  if out is None:
55
55
  return
@@ -112,7 +112,9 @@ def render_flow_plot(module: "Model", address: Address | str | None) -> Renderab
112
112
  table.add_column("Count", justify="right")
113
113
  table.add_row("JSON", "Raw nested records enter with the shape described by the schema.", "")
114
114
  table.add_row("Tensorfields", "Typed requests read values with JMESPath queries.", str(len(fields)))
115
- table.add_row("Encoders", "Array nodes pool child embeddings into parent contexts.", str(len(hyperparameters.arrays)))
115
+ table.add_row(
116
+ "Encoders", "Array nodes pool child embeddings into parent contexts.", str(len(hyperparameters.arrays))
117
+ )
116
118
  table.add_row("Targets", "Target fields produce supervised predictions.", str(target_count))
117
119
  table.add_row("Embeddings", "Selected nodes expose reusable embeddings.", str(embed_count))
118
120
 
@@ -272,11 +274,7 @@ def node_metadata_keys(node: Node, values: dict[str, Any], state_focus: bool) ->
272
274
 
273
275
 
274
276
  def should_hide_metadata(key: str, value: Any) -> bool:
275
- return (
276
- (key == "active" and value is True)
277
- or (key == "embed" and value is False)
278
- or key == "description"
279
- )
277
+ return (key == "active" and value is True) or (key == "embed" and value is False) or key == "description"
280
278
 
281
279
 
282
280
  def format_metadata_value(value: Any) -> str:
@@ -317,7 +315,9 @@ def format_detail_inline(value: Any) -> str:
317
315
  return truncate(value, width=100)
318
316
 
319
317
  if isinstance(value, list):
320
- return truncate(format_inline_sequence(value) or pformat(value, compact=True, sort_dicts=False, width=88), width=100)
318
+ return truncate(
319
+ format_inline_sequence(value) or pformat(value, compact=True, sort_dicts=False, width=88), width=100
320
+ )
321
321
 
322
322
  return truncate(pformat(value, compact=True, sort_dicts=False, width=88), width=100)
323
323
 
@@ -332,7 +332,9 @@ def summarize_value(value: Any, max_items: int = 8) -> Any:
332
332
  if isinstance(value, list):
333
333
  if len(value) <= max_items:
334
334
  return [summarize_value(item, max_items=max_items) for item in value]
335
- return [summarize_value(item, max_items=max_items) for item in value[:max_items]] + [f"... {len(value) - max_items} more"]
335
+ return [summarize_value(item, max_items=max_items) for item in value[:max_items]] + [
336
+ f"... {len(value) - max_items} more"
337
+ ]
336
338
 
337
339
  return value
338
340
 
@@ -410,8 +412,9 @@ def format_compact_number(value: Any) -> str:
410
412
 
411
413
  def resolve_node(hyperparameters: "Hyperparameters", address: Address | str) -> Node:
412
414
  key = Address(str(address))
413
- leaves = {node.address: node for node in hyperparameters.descendants if isinstance(node, Leaf)}
414
- nodes: dict[Address, Node] = hyperparameters.arrays | leaves
415
+ leaves: dict[Address, Node] = {node.address: node for node in hyperparameters.descendants if isinstance(node, Leaf)}
416
+ nodes: dict[Address, Node] = dict(hyperparameters.arrays)
417
+ nodes.update(leaves)
415
418
 
416
419
  if key not in nodes:
417
420
  raise ValueError(f"address '{address}' was not found in the hyperparameters")
@@ -18,6 +18,7 @@ from tensordict import TensorDict
18
18
 
19
19
  from json2vec.architecture.encoder import ArrayEncoder
20
20
  from json2vec.architecture.node import NodeModule
21
+ from json2vec.architecture.plot import PlotMode
21
22
  from json2vec.data.datasets.base import EncodedBatch
22
23
  from json2vec.data.iterables import encode, mock
23
24
  from json2vec.structs.enums import AttentionMode, Metric, Strata, TensorKey
@@ -296,7 +297,6 @@ class Model(lit.LightningModule):
296
297
  optimizer: OptimizerConfig | None = None,
297
298
  scheduler: SchedulerConfig | None = None,
298
299
  ):
299
-
300
300
  super().__init__()
301
301
  if batch_size <= 0:
302
302
  raise ValueError("batch_size must be > 0")
@@ -330,11 +330,16 @@ class Model(lit.LightningModule):
330
330
  self.example_input_array = mock(hyperparameters=self.hyperparameters, batch_size=self.batch_size)
331
331
 
332
332
  def _rebuild(self) -> None:
333
+ self.hyperparameters._clear_tree_caches()
334
+ was_training = self.training
335
+ device = self.device
333
336
  previous = {
334
337
  name: value.detach().clone() if isinstance(value, torch.Tensor) else deepcopy(value)
335
338
  for name, value in self.state_dict().items()
336
339
  }
337
340
  self._build()
341
+ if isinstance(device, torch.device):
342
+ self.to(device=device)
338
343
  current = self.state_dict()
339
344
  compatible = {}
340
345
  for name, value in previous.items():
@@ -351,6 +356,7 @@ class Model(lit.LightningModule):
351
356
  compatible[name] = value
352
357
 
353
358
  self.load_state_dict(compatible, strict=False)
359
+ self.train(was_training)
354
360
 
355
361
  def select(
356
362
  self,
@@ -372,6 +378,7 @@ class Model(lit.LightningModule):
372
378
  allow_extra: bool = False,
373
379
  include_root: bool = True,
374
380
  validate: bool = True,
381
+ use_cache: bool = False,
375
382
  **values: Any,
376
383
  ) -> None:
377
384
  """Mutate selected schema nodes and rebuild compatible modules.
@@ -386,6 +393,8 @@ class Model(lit.LightningModule):
386
393
  allow unknown fields.
387
394
  include_root: Include the root node in predicate matching.
388
395
  validate: Validate each node after applying candidate values.
396
+ use_cache: Permit cached selector results. Mutations default this to
397
+ `False` so updates always evaluate against current schema state.
389
398
  **values: Schema attributes to update.
390
399
  """
391
400
  self._assert_mutation_allowed("update")
@@ -395,6 +404,7 @@ class Model(lit.LightningModule):
395
404
  allow_extra=allow_extra,
396
405
  include_root=include_root,
397
406
  validate=validate,
407
+ use_cache=use_cache,
398
408
  **values,
399
409
  )
400
410
  self._rebuild()
@@ -468,6 +478,7 @@ class Model(lit.LightningModule):
468
478
  allow_extra: bool = False,
469
479
  include_root: bool = True,
470
480
  validate: bool = True,
481
+ use_cache: bool = False,
471
482
  **values: Any,
472
483
  ) -> Iterator[None]:
473
484
  """Temporarily mutate selected schema nodes and keep runtime modules synchronized."""
@@ -479,6 +490,7 @@ class Model(lit.LightningModule):
479
490
  allow_extra=allow_extra,
480
491
  include_root=include_root,
481
492
  validate=validate,
493
+ use_cache=use_cache,
482
494
  **values,
483
495
  ):
484
496
  self._rebuild()
@@ -563,7 +575,7 @@ class Model(lit.LightningModule):
563
575
  address: Address | str | None = None,
564
576
  detail: bool = False,
565
577
  out: str | Path | None = None,
566
- mode: str = "schema",
578
+ mode: PlotMode = "schema",
567
579
  ) -> None:
568
580
  """Print a Rich model visualization.
569
581
 
@@ -602,7 +614,8 @@ class Model(lit.LightningModule):
602
614
 
603
615
  embedder: EmbedderBase = self.nodes[address].embedder
604
616
  embedding: Parcel = embedder(tensorfield)
605
- processed[embedding.destination].append(embedding)
617
+ if embedding.destination is not None:
618
+ processed[embedding.destination].append(embedding)
606
619
  outgoing[embedding.origin] = embedding
607
620
 
608
621
  if address in self.hyperparameters.embed:
@@ -617,7 +630,8 @@ class Model(lit.LightningModule):
617
630
 
618
631
  encoder: ArrayEncoder = self.nodes[address].encoder
619
632
  encoding: Parcel = encoder(processed[address])
620
- processed[encoding.destination].append(encoding)
633
+ if encoding.destination is not None:
634
+ processed[encoding.destination].append(encoding)
621
635
  outgoing[encoding.origin] = encoding
622
636
 
623
637
  if address in self.hyperparameters.embed:
@@ -696,7 +710,6 @@ class Model(lit.LightningModule):
696
710
  def write(
697
711
  self, predictions: list[Prediction]
698
712
  ) -> tuple[dict[Address, dict[str, Any]], dict[Address, dict[str, Any]]]:
699
-
700
713
  supervised: dict[Address, dict[str, Any]] = {}
701
714
  embeddings: dict[Address, dict[str, Any]] = {}
702
715
 
@@ -738,17 +751,21 @@ class Model(lit.LightningModule):
738
751
 
739
752
  if preprocess is not None:
740
753
  observations: EncodedBatch = []
741
- for request in batch:
754
+ for request in cast(list[dict[str, Any]], batch):
742
755
  observation = preprocess(request)
743
756
  if not isinstance(observation, dict):
744
757
  raise TypeError(f"preprocessor must return a dict object, got {type(observation).__name__}")
745
758
 
746
759
  observations.append([observation])
747
760
 
748
- batch = observations
761
+ encoded_batch = observations
762
+ elif batch and isinstance(batch[0], dict):
763
+ encoded_batch = [[request] for request in cast(list[dict[str, Any]], batch)]
764
+ else:
765
+ encoded_batch = cast(EncodedBatch, batch)
749
766
 
750
767
  inputs = encode(
751
- batch=batch,
768
+ batch=encoded_batch,
752
769
  hyperparameters=self.hyperparameters,
753
770
  strata=Strata.predict,
754
771
  interprocess_encoding_context=self.interprocess_encoding_context,
@@ -13,6 +13,7 @@ class RotaryEmbedding(torch.nn.Module):
13
13
  self.base = base
14
14
 
15
15
  index = torch.arange(0, self.rotary_dim, 2, dtype=torch.float32)
16
+ self.inv_freq: torch.Tensor
16
17
  self.register_buffer("inv_freq", base ** (-index / self.rotary_dim), persistent=False)
17
18
 
18
19
  def forward(self, inputs: torch.Tensor) -> torch.Tensor:
@@ -57,7 +57,7 @@ def _dataframes_by_strata(dataframe: pl.DataFrame | DataFrameMap) -> dict[Strata
57
57
  return {strata: dataframe for strata in Strata}
58
58
 
59
59
  normalized: dict[Strata, pl.DataFrame] = {}
60
- for key, frame in dataframe.items():
60
+ for key, frame in cast(DataFrameMap, dataframe).items():
61
61
  if not isinstance(frame, pl.DataFrame):
62
62
  raise TypeError(f"dataframe for strata '{key}' must be a polars DataFrame")
63
63
  normalized[Strata.normalize(key)] = frame
@@ -79,7 +79,7 @@ def observe_polars(
79
79
  world_size: int | None = None,
80
80
  ) -> Iterator[RawObservation]:
81
81
  if replacement:
82
- rows = cast(list[RawObservation], dataframe.to_dicts())
82
+ rows = dataframe.to_dicts()
83
83
  if not rows:
84
84
  raise ValueError("no dataframe rows available for replacement sampling")
85
85
 
@@ -103,7 +103,7 @@ def observe_polars(
103
103
  worker_id=worker_id,
104
104
  num_workers=num_workers,
105
105
  ):
106
- yield cast(RawObservation, row)
106
+ yield row
107
107
  return
108
108
 
109
109
  for chunk_index, offset in enumerate(range(0, dataframe.height, chunk_batch_size)):
@@ -115,7 +115,7 @@ def observe_polars(
115
115
  ):
116
116
  continue
117
117
 
118
- yield from cast(list[RawObservation], dataframe.slice(offset, chunk_batch_size).to_dicts())
118
+ yield from dataframe.slice(offset, chunk_batch_size).to_dicts()
119
119
 
120
120
 
121
121
  class PolarsBatchDataset(IterableDataset):
@@ -281,7 +281,6 @@ class PolarsDataModule(lit.LightningDataModule):
281
281
  else:
282
282
  dataframes = _dataframes_by_strata(dataframe)
283
283
 
284
- self.hyperparameters = model.hyperparameters
285
284
  self.dataframes = dataframes
286
285
  self.preprocessor = PreprocessorConfig.normalize(preprocessor)
287
286
  self.preprocessor_kwargs = dict(kwargs)
@@ -289,26 +288,55 @@ class PolarsDataModule(lit.LightningDataModule):
289
288
  self._model_ref = weakref.ref(model)
290
289
  except TypeError:
291
290
  self._model_ref = None
291
+ self._hyperparameters = model.hyperparameters
292
292
  self._interprocess_encoding_context = model.interprocess_encoding_context
293
- self.batch_size = model.batch_size
293
+ self._batch_size = model.batch_size
294
294
  self.num_workers = Strata.expand(num_workers, default=None)
295
295
  self.persistent_workers = Strata.expand(persistent_workers, default=True)
296
296
  self.pin_memory = Strata.expand(pin_memory, default=True)
297
297
  self.sharding = ShardingStrategy.expand(sharding, default=ShardingStrategy.chunk)
298
298
  self.chunk_batch_size = Strata.expand(chunk_batch_size, default=4096)
299
299
  self.observation_buffer_size = Strata.expand(observation_buffer_size, default=1)
300
- self.sample_rate = {
301
- strata: float(rate)
302
- for strata, rate in Strata.expand(sample_rate, default=1.0).items()
303
- }
300
+ self.sample_rate = {strata: float(rate) for strata, rate in Strata.expand(sample_rate, default=1.0).items()}
304
301
  self.replacement = Strata.expand(replacement, default=False)
305
302
 
303
+ def _model(self) -> Model | None:
304
+ if self._model_ref is None:
305
+ return None
306
+
307
+ return self._model_ref()
308
+
309
+ @property
310
+ def hyperparameters(self) -> Hyperparameters:
311
+ model = self._model()
312
+ if model is not None:
313
+ return model.hyperparameters
314
+
315
+ return self._hyperparameters
316
+
317
+ @hyperparameters.setter
318
+ def hyperparameters(self, hyperparameters: Hyperparameters) -> None:
319
+ self._model_ref = None
320
+ self._hyperparameters = hyperparameters
321
+
322
+ @property
323
+ def batch_size(self) -> int:
324
+ model = self._model()
325
+ if model is not None:
326
+ return model.batch_size
327
+
328
+ return self._batch_size
329
+
330
+ @batch_size.setter
331
+ def batch_size(self, batch_size: int) -> None:
332
+ self._model_ref = None
333
+ self._batch_size = batch_size
334
+
306
335
  @property
307
336
  def interprocess_encoding_context(self) -> InterprocessEncodingContext:
308
- if self._model_ref is not None:
309
- model = self._model_ref()
310
- if model is not None:
311
- return model.interprocess_encoding_context
337
+ model = self._model()
338
+ if model is not None:
339
+ return model.interprocess_encoding_context
312
340
 
313
341
  return self._interprocess_encoding_context
314
342
 
@@ -118,8 +118,7 @@ def observe(
118
118
  sampled_paths = list(paths)
119
119
  if not sampled_paths:
120
120
  raise ValueError(
121
- "no matching files available for replacement sampling; "
122
- "check the streaming root and split pattern"
121
+ "no matching files available for replacement sampling; check the streaming root and split pattern"
123
122
  )
124
123
 
125
124
  def choices() -> Iterator[str]:
@@ -405,7 +404,6 @@ class StreamingDataModule(lit.LightningDataModule):
405
404
  ):
406
405
  super().__init__()
407
406
 
408
- self.hyperparameters = model.hyperparameters
409
407
  self.root = root
410
408
  self.suffix = Suffix(suffix)
411
409
  self.train = train
@@ -418,8 +416,9 @@ class StreamingDataModule(lit.LightningDataModule):
418
416
  self._model_ref = weakref.ref(model)
419
417
  except TypeError:
420
418
  self._model_ref = None
419
+ self._hyperparameters = model.hyperparameters
421
420
  self._interprocess_encoding_context = model.interprocess_encoding_context
422
- self.batch_size = model.batch_size
421
+ self._batch_size = model.batch_size
423
422
  self.num_workers = Strata.expand(num_workers, default=None)
424
423
  self.persistent_workers = Strata.expand(persistent_workers, default=True)
425
424
  self.pin_memory = Strata.expand(pin_memory, default=True)
@@ -427,22 +426,50 @@ class StreamingDataModule(lit.LightningDataModule):
427
426
  self.chunk_batch_size = Strata.expand(chunk_batch_size, default=4096)
428
427
  self.file_buffer_size = Strata.expand(file_buffer_size, default=1)
429
428
  self.observation_buffer_size = Strata.expand(observation_buffer_size, default=1)
430
- self.sample_rate = {
431
- strata: float(rate)
432
- for strata, rate in Strata.expand(sample_rate, default=1.0).items()
433
- }
429
+ self.sample_rate = {strata: float(rate) for strata, rate in Strata.expand(sample_rate, default=1.0).items()}
434
430
  self.replacement = (
435
431
  {strata: strata == Strata.train for strata in Strata}
436
432
  if replacement is None
437
433
  else Strata.expand(replacement, default=False)
438
434
  )
439
435
 
436
+ def _model(self) -> Model | None:
437
+ if self._model_ref is None:
438
+ return None
439
+
440
+ return self._model_ref()
441
+
442
+ @property
443
+ def hyperparameters(self) -> Hyperparameters:
444
+ model = self._model()
445
+ if model is not None:
446
+ return model.hyperparameters
447
+
448
+ return self._hyperparameters
449
+
450
+ @hyperparameters.setter
451
+ def hyperparameters(self, hyperparameters: Hyperparameters) -> None:
452
+ self._model_ref = None
453
+ self._hyperparameters = hyperparameters
454
+
455
+ @property
456
+ def batch_size(self) -> int:
457
+ model = self._model()
458
+ if model is not None:
459
+ return model.batch_size
460
+
461
+ return self._batch_size
462
+
463
+ @batch_size.setter
464
+ def batch_size(self, batch_size: int) -> None:
465
+ self._model_ref = None
466
+ self._batch_size = batch_size
467
+
440
468
  @property
441
469
  def interprocess_encoding_context(self) -> InterprocessEncodingContext:
442
- if self._model_ref is not None:
443
- model = self._model_ref()
444
- if model is not None:
445
- return model.interprocess_encoding_context
470
+ model = self._model()
471
+ if model is not None:
472
+ return model.interprocess_encoding_context
446
473
 
447
474
  return self._interprocess_encoding_context
448
475
 
@@ -141,7 +141,7 @@ class Pipeline:
141
141
  return self
142
142
 
143
143
  def __repr__(self):
144
- return f"Pipeline({repr(self.source)}, {repr(self.arguments)})"
144
+ return f"Pipeline(steps={len(self.steps)}, arguments={self.arguments!r})"
145
145
 
146
146
  def __iter__(self):
147
147
  stream = self.steps[0]()
@@ -41,9 +41,7 @@ class Writer(callbacks.BasePredictionWriter):
41
41
  self.writer: pq.ParquetWriter | None = None
42
42
 
43
43
  @staticmethod
44
- def _as_struct_frame(
45
- values_by_address: dict[Address, dict[str, Any]], alias: str, num_rows: int
46
- ) -> pl.DataFrame:
44
+ def _as_struct_frame(values_by_address: dict[Address, dict[str, Any]], alias: str, num_rows: int) -> pl.DataFrame:
47
45
  if len(values_by_address) == 0:
48
46
  return pl.DataFrame({alias: [None] * num_rows})
49
47
 
@@ -64,7 +62,7 @@ class Writer(callbacks.BasePredictionWriter):
64
62
  batch: TensorDict[Address, TensorFieldBase],
65
63
  batch_idx: int,
66
64
  dataloader_idx: int,
67
- ) -> None:
65
+ ) -> None: # ty:ignore[invalid-method-override]
68
66
  num_rows = len(batch["metadata"])
69
67
 
70
68
  supervised: dict[Address, dict[str, Any]]
@@ -4,7 +4,7 @@ import functools
4
4
  from collections.abc import Callable
5
5
  from enum import StrEnum
6
6
  from pathlib import Path
7
- from typing import Any, TypeAlias
7
+ from typing import Any, TypeAlias, cast
8
8
 
9
9
  import litserve as ls
10
10
  import pydantic
@@ -42,7 +42,7 @@ class Accelerator(StrEnum):
42
42
  if normalized == "":
43
43
  raise ValueError("accelerator must not be blank")
44
44
 
45
- return cls._value2member_map_.get(normalized)
45
+ return cast(Accelerator | None, cls._value2member_map_.get(normalized))
46
46
 
47
47
 
48
48
  class ErrorItem(pydantic.BaseModel):
@@ -114,8 +114,7 @@ class API(ls.LitAPI):
114
114
  self,
115
115
  request: dict[str, Any] | pydantic.BaseModel,
116
116
  context: dict[str, Any] | None = None,
117
- ) -> Input | ErrorItem:
118
-
117
+ ) -> Input | ErrorItem: # ty:ignore[invalid-method-override]
119
118
  if isinstance(request, pydantic.BaseModel):
120
119
  request = request.model_dump()
121
120
 
@@ -175,13 +174,13 @@ class API(ls.LitAPI):
175
174
  return BatchItem(data=data, valid_indices=valid_indices, items=inputs)
176
175
 
177
176
  @beartype
178
- def unbatch(self, outputs: list[Any]) -> list[Any]:
177
+ def unbatch(self, outputs: list[Any]) -> list[Any]: # ty:ignore[invalid-method-override]
179
178
  return list(outputs)
180
179
 
181
180
  @beartype
182
181
  def predict(
183
182
  self, data: BatchItem | Input | ErrorItem
184
- ) -> list[list[Prediction] | ErrorItem] | list[Prediction] | ErrorItem:
183
+ ) -> list[list[Prediction] | ErrorItem] | list[Prediction] | ErrorItem: # ty:ignore[invalid-method-override]
185
184
  if isinstance(data, ErrorItem):
186
185
  return data
187
186
 
@@ -209,7 +208,7 @@ class API(ls.LitAPI):
209
208
  self,
210
209
  response: list[Prediction] | ErrorItem,
211
210
  context: dict[str, Any] | None = None,
212
- ) -> dict[str, Any] | pydantic.BaseModel:
211
+ ) -> dict[str, Any] | pydantic.BaseModel: # ty:ignore[invalid-method-override]
213
212
  if isinstance(response, ErrorItem):
214
213
  return {
215
214
  "predictions": {},
@@ -10,6 +10,7 @@ console = Console(file=sys.stdout)
10
10
 
11
11
  LOG_LEVEL: str = os.getenv("JSON2VEC_LOG_LEVEL", "DEBUG").upper()
12
12
 
13
+
13
14
  def sink(message):
14
15
  record = message.record
15
16
  extras = {k: str(v) for k, v in record["extra"].items()}
@@ -50,8 +50,7 @@ class Preprocessor(pydantic.BaseModel):
50
50
  def accepted_kwargs(func: Callable[..., Any]) -> tuple[bool, frozenset[str]]:
51
51
  signature = inspect.signature(func)
52
52
  accepts_variadic_kwargs = any(
53
- parameter.kind == inspect.Parameter.VAR_KEYWORD
54
- for parameter in signature.parameters.values()
53
+ parameter.kind == inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values()
55
54
  )
56
55
  accepted = frozenset(signature.parameters.keys())
57
56
  return accepts_variadic_kwargs, accepted
@@ -66,7 +65,8 @@ class Preprocessor(pydantic.BaseModel):
66
65
 
67
66
  @classmethod
68
67
  def register(cls, func: Callable[..., Any], *, mode: PreprocessorMode) -> Callable[..., Any]:
69
- PREPROCESSORS[func.__name__] = cls(name=func.__name__, func=func, mode=mode)
68
+ name = getattr(func, "__name__", type(func).__name__)
69
+ PREPROCESSORS[name] = cls(name=name, func=func, mode=mode)
70
70
  return func
71
71
 
72
72
  def __call__(self, observation: dict, **kwargs) -> Any:
@@ -99,9 +99,7 @@ class Preprocessor(pydantic.BaseModel):
99
99
 
100
100
  def require_object(self, output: Any, *, mode: PreprocessorMode) -> dict[str, Any]:
101
101
  if not isinstance(output, dict):
102
- raise TypeError(
103
- f"{mode} preprocessor '{self.name}' must produce dict objects, got {type(output).__name__}"
104
- )
102
+ raise TypeError(f"{mode} preprocessor '{self.name}' must produce dict objects, got {type(output).__name__}")
105
103
 
106
104
  return output
107
105
 
@@ -1,8 +1,11 @@
1
+ from __future__ import annotations
2
+
1
3
  import enum
2
4
  from collections.abc import Mapping
3
- from typing import TypeVar
5
+ from typing import TypeVar, cast
4
6
 
5
7
  T = TypeVar("T")
8
+ DefaultT = TypeVar("DefaultT")
6
9
 
7
10
 
8
11
  class Tokens(enum.IntEnum):
@@ -27,10 +30,10 @@ class Strata(enum.StrEnum):
27
30
  return cls(str(value).strip().lower())
28
31
 
29
32
  @classmethod
30
- def expand(cls, value: T | Mapping["Strata | str", T], *, default: T) -> dict["Strata", T]:
33
+ def expand(cls, value: T | Mapping[Strata | str, T], *, default: DefaultT) -> dict[Strata, T | DefaultT]:
31
34
  if isinstance(value, Mapping):
32
- normalized = {strata: default for strata in cls}
33
- for key, item in value.items():
35
+ normalized: dict[Strata, T | DefaultT] = {strata: default for strata in cls}
36
+ for key, item in cast(Mapping[Strata | str, T], value).items():
34
37
  normalized[cls.normalize(key)] = item
35
38
  return normalized
36
39
 
@@ -87,10 +90,7 @@ class ShardingStrategy(enum.StrEnum):
87
90
  *,
88
91
  default: "ShardingStrategy",
89
92
  ) -> dict[Strata, "ShardingStrategy"]:
90
- return {
91
- strata: cls.normalize(strategy)
92
- for strata, strategy in Strata.expand(value, default=default).items()
93
- }
93
+ return {strata: cls.normalize(strategy) for strata, strategy in Strata.expand(value, default=default).items()}
94
94
 
95
95
 
96
96
  class AttentionMode(enum.StrEnum):