json2vec 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. json2vec/__init__.py +0 -0
  2. json2vec/__main__.py +32 -0
  3. json2vec/architecture/__init__.py +0 -0
  4. json2vec/architecture/attention.py +64 -0
  5. json2vec/architecture/counter.py +37 -0
  6. json2vec/architecture/encoder.py +88 -0
  7. json2vec/architecture/node.py +34 -0
  8. json2vec/architecture/pool.py +61 -0
  9. json2vec/architecture/root.py +338 -0
  10. json2vec/architecture/rotary.py +39 -0
  11. json2vec/data/__init__.py +0 -0
  12. json2vec/data/datasets.py +539 -0
  13. json2vec/data/processing.py +152 -0
  14. json2vec/entrypoints/__init__.py +3 -0
  15. json2vec/entrypoints/pipeline.py +174 -0
  16. json2vec/inference/__init__.py +0 -0
  17. json2vec/inference/callback.py +98 -0
  18. json2vec/inference/deployment.py +175 -0
  19. json2vec/logging/__init__.py +0 -0
  20. json2vec/logging/config.py +27 -0
  21. json2vec/logging/epoch.py +42 -0
  22. json2vec/logging/throughput.py +39 -0
  23. json2vec/logging/tracking.py +152 -0
  24. json2vec/processors/__init__.py +8 -0
  25. json2vec/processors/base.py +102 -0
  26. json2vec/processors/extensions/__init__.py +0 -0
  27. json2vec/processors/extensions/example.py +6 -0
  28. json2vec/processors/spec.py +8 -0
  29. json2vec/structs/__init__.py +0 -0
  30. json2vec/structs/enums.py +84 -0
  31. json2vec/structs/environment.py +138 -0
  32. json2vec/structs/experiment.py +330 -0
  33. json2vec/structs/packages.py +117 -0
  34. json2vec/structs/structure.py +70 -0
  35. json2vec/structs/tree.py +92 -0
  36. json2vec/tensorfields/__init__.py +8 -0
  37. json2vec/tensorfields/base.py +210 -0
  38. json2vec/tensorfields/extensions/__init__.py +0 -0
  39. json2vec/tensorfields/extensions/category.py +484 -0
  40. json2vec/tensorfields/extensions/dateparts.py +410 -0
  41. json2vec/tensorfields/extensions/entity.py +336 -0
  42. json2vec/tensorfields/extensions/number.py +400 -0
  43. json2vec/tensorfields/extensions/vector.py +279 -0
  44. json2vec/tensorfields/spec.py +8 -0
  45. json2vec-0.1.0.dist-info/METADATA +227 -0
  46. json2vec-0.1.0.dist-info/RECORD +51 -0
  47. json2vec-0.1.0.dist-info/WHEEL +5 -0
  48. json2vec-0.1.0.dist-info/entry_points.txt +2 -0
  49. json2vec-0.1.0.dist-info/licenses/LICENSE +178 -0
  50. json2vec-0.1.0.dist-info/licenses/NOTICE +8 -0
  51. json2vec-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,152 @@
1
+ import inspect
2
+ from collections.abc import Callable
3
+ from functools import partial
4
+ from typing import Any
5
+
6
+ import numpy as np
7
+ from beartype import beartype
8
+
9
+ from json2vec.structs.enums import Tokens
10
+
11
+
12
+ def apply(
13
+ values: Any,
14
+ function: Callable[..., Any],
15
+ /,
16
+ *args: Any,
17
+ leaf_depth: int | None = None,
18
+ **kwargs: Any,
19
+ ) -> Any:
20
+ """Apply a function recursively to nested list leaves.
21
+
22
+ When ``leaf_depth`` is set, the function is applied exactly at that depth;
23
+ higher-level non-list values are preserved so downstream padding can mark
24
+ them as incomplete.
25
+ """
26
+ if leaf_depth is not None and leaf_depth < 0:
27
+ raise ValueError("leaf_depth must be >= 0")
28
+
29
+ def walk(node: Any, depth: int) -> Any:
30
+ if leaf_depth is None:
31
+ if isinstance(node, list):
32
+ return [walk(item, depth + 1) for item in node]
33
+
34
+ if node is None:
35
+ return None
36
+
37
+ return function(node, *args, **kwargs)
38
+
39
+ if depth == leaf_depth:
40
+ if node is None:
41
+ return None
42
+
43
+ return function(node, *args, **kwargs)
44
+
45
+ if isinstance(node, list):
46
+ return [walk(item, depth + 1) for item in node]
47
+
48
+ return node
49
+
50
+ return walk(values, depth=0)
51
+
52
+
53
+ def _iter_leaf_nodes(
54
+ nested: Any,
55
+ shape: tuple[int, ...],
56
+ strides: tuple[int, ...],
57
+ ):
58
+ ndim = len(shape)
59
+ stack: list[tuple[Any, int, int]] = [(nested, 0, 0)]
60
+
61
+ while stack:
62
+ node, depth, base = stack.pop()
63
+
64
+ if depth == ndim:
65
+ yield base, node
66
+ continue
67
+
68
+ if not isinstance(node, list):
69
+ continue
70
+
71
+ limit = min(len(node), shape[depth])
72
+ step = strides[depth]
73
+ for index in range(limit - 1, -1, -1):
74
+ stack.append((node[index], depth + 1, base + (index * step)))
75
+
76
+
77
def _fill_python(
    nested: Any,
    flat_values: np.ndarray,
    flat_flags: np.ndarray,
    shape: tuple[int, ...],
    strides: tuple[int, ...],
) -> None:
    """Copy the leaves of ``nested`` into the flat value/flag buffers in place.

    Each leaf reported by ``_iter_leaf_nodes`` is written at its flat index:
    a ``None`` leaf only sets the flag to ``Tokens.null``; any other leaf is
    stored in ``flat_values`` and flagged ``Tokens.valued``. Positions that
    receive no leaf are left untouched.
    """
    for position, leaf in _iter_leaf_nodes(nested=nested, shape=shape, strides=strides):
        if leaf is not None:
            flat_values[position] = leaf
            flat_flags[position] = Tokens.valued.value
            continue

        flat_flags[position] = Tokens.null.value
90
+
91
+
92
@beartype
def pad(
    nested: Any, shape: tuple[int, ...], dtype: type | str = object, pad_value: Any = None
) -> tuple[np.ndarray, np.ndarray]:
    """Pad (and truncate) a nested list into fixed-shape value and flag arrays.

    Args:
        nested: Nested list structure, a scalar, or ``None``.
        shape: Target shape of both returned arrays.
        dtype: Element type of the values array.
        pad_value: Fill value for positions that receive no leaf.

    Returns:
        ``(values, flags)``: ``values`` has the given ``dtype``; ``flags`` is
        ``int8`` and starts as ``Tokens.padded`` everywhere, then is set to
        ``Tokens.null`` or ``Tokens.valued`` where a leaf was found.
    """
    values = np.full(shape, pad_value, dtype=np.dtype(dtype))
    flags = np.full(shape, Tokens.padded.value, dtype=np.int8)

    if not shape:
        # Zero-dimensional target: ``nested`` itself is the single leaf.
        if nested is None:
            flags[...] = Tokens.null.value
        else:
            values[...] = nested
            flags[...] = Tokens.valued.value
        return values, flags

    # Row-major element strides, accumulated from the innermost axis outwards.
    innermost_first: list[int] = []
    span = 1
    for extent in reversed(shape):
        innermost_first.append(span)
        span *= extent

    # The flattened arrays are filled in place; ``values``/``flags`` are
    # returned afterwards, as in the reshape(-1)-then-fill pattern.
    _fill_python(
        nested=nested,
        flat_values=values.reshape(-1),
        flat_flags=flags.reshape(-1),
        shape=shape,
        strides=tuple(reversed(innermost_first)),
    )

    return values, flags
126
+
127
+
128
@beartype
class Pipeline:
    """Chain of callables composed with ``|`` and run lazily on iteration.

    Keyword arguments given at construction are offered to every step: a step
    receives (pre-bound via ``functools.partial``) each argument whose name
    matches one of its parameters. The first step is called with no positional
    input and produces the initial stream; each later step is called with the
    previous step's result.
    """

    def __init__(self, **arguments):
        # Shared keyword arguments that steps may opt into by parameter name.
        self.arguments: dict[str, Any] = arguments
        # Steps in execution order, each already bound to its matching arguments.
        self.steps: list[Callable] = []

    def __or__(self, function: Callable) -> "Pipeline":
        """Append ``function`` as the next step and return this pipeline."""
        parameters = inspect.signature(function).parameters
        bound = {name: self.arguments[name] for name in parameters if name in self.arguments}

        self.steps.append(partial(function, **bound))

        return self

    def __repr__(self) -> str:
        # The previous implementation read ``self.source``, which is never
        # assigned, so repr() always raised AttributeError.
        return f"Pipeline({self.steps!r}, {self.arguments!r})"

    def __iter__(self):
        """Run every step in order and iterate over the final stream.

        Raises:
            IndexError: If no step has been added yet.
        """
        if not self.steps:
            # IndexError is what ``self.steps[0]`` raised before; only the
            # message is new.
            raise IndexError("cannot iterate a Pipeline with no steps")

        source, *transforms = self.steps
        stream = source()

        for step in transforms:
            stream = step(stream)

        return iter(stream)
@@ -0,0 +1,3 @@
1
+ from json2vec.entrypoints.pipeline import build, execute, fit, predict, test, validate
2
+
3
# Public API of the entrypoints package: the names imported above from
# json2vec.entrypoints.pipeline.
__all__ = ["build", "fit", "validate", "test", "predict", "execute"]
@@ -0,0 +1,174 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Any
6
+
7
+ import torch
8
+ from lightning.pytorch.callbacks import Callback, EarlyStopping, ModelCheckpoint
9
+ from lightning.pytorch.trainer.trainer import Trainer
10
+ from loguru import logger
11
+
12
+ from json2vec.architecture.root import JSON2Vec
13
+ from json2vec.inference.callback import Writer
14
+ from json2vec.logging.epoch import EpochLifecycleLogger
15
+ from json2vec.logging.throughput import ThroughputLogger
16
+ from json2vec.logging.tracking import LoggerFactory
17
+ from json2vec.structs.enums import Metric, Stage, Strata
18
+ from json2vec.structs.experiment import Experiment, PatchOp, Session
19
+
20
+
21
def build(model: JSON2Vec, callbacks: list[Callback], names: list[str] | None = None) -> Trainer:
    """Create the Lightning ``Trainer`` used by every task.

    Args:
        model: Model whose ``session`` supplies the name, the task and extra
            ``Trainer`` keyword arguments (``model.session.trainer``).
        callbacks: Callbacks to install. The list is copied, and an
            ``EpochLifecycleLogger`` is appended when none is present.
        names: Positional arguments for ``LoggerFactory.create``; ``None``
            disables the trainer logger.

    Returns:
        A ``Trainer`` with the model summary and progress bar disabled. When
        CUDA is available it uses ``accelerator="auto"`` with ``bf16-mixed``
        precision; otherwise it is pinned to CPU with precision left unset.
    """
    active_callbacks: list[Callback] = [*callbacks]
    if all(not isinstance(candidate, EpochLifecycleLogger) for candidate in active_callbacks):
        active_callbacks.append(EpochLifecycleLogger())

    callback_names = [type(candidate).__name__ for candidate in active_callbacks]
    logger.bind(
        component="trainer",
        session=model.session.name,
        stage=model.session.task,
        callbacks=callback_names,
    ).info("building lightning trainer")

    has_cuda = torch.cuda.is_available()
    run_logger = False if names is None else LoggerFactory.create(*names)

    return Trainer(
        accelerator="auto" if has_cuda else "cpu",
        precision="bf16-mixed" if has_cuda else None,
        logger=run_logger,
        enable_model_summary=False,
        enable_progress_bar=False,
        callbacks=active_callbacks,
        **model.session.trainer,
    )
42
+
43
+
44
def fit(
    names: list[str],
    session: Session | None = None,
    checkpoint: str | os.PathLike[str] | None = None,
    patches: list[PatchOp] | None = None,
) -> Path:
    """Train a model and return the path of the best checkpoint.

    Args:
        names: Forwarded to ``build`` for logger creation.
        session: Session used to obtain the model via ``JSON2Vec.get_or_create``.
        checkpoint: Optional checkpoint forwarded to ``get_or_create``
            (as a string).
        patches: Patch operations applied to the model's session before training.

    Returns:
        ``ModelCheckpoint.best_model_path`` wrapped in a ``Path``. Checkpoints
        are written under the relative directory ``models``.
    """
    session_name = session.name if session else None
    logger.bind(component="task", task="fit", session=session_name).info("starting fit task")

    checkpoint_path = None if checkpoint is None else str(checkpoint)
    model: JSON2Vec = JSON2Vec.get_or_create(session=session, checkpoint=checkpoint_path)
    model.session = model.session.patch(patches)

    monitor = f"{Metric.loss}/{Strata.validate}"
    # NOTE(review): the template formats a metric called ``val_loss`` while the
    # monitored key is built from Metric.loss/Strata.validate — confirm a
    # metric named ``val_loss`` is actually logged.
    prefix = f"{model.session.structure.name}-{model.session.name}-"
    filename: str = prefix + "{epoch}-{step}-{val_loss:.2f}"

    checkpoint_dir = Path("models")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    checkpointer: ModelCheckpoint = ModelCheckpoint(dirpath=checkpoint_dir, filename=filename, monitor=monitor)

    callbacks: list[Callback] = [ThroughputLogger(), checkpointer]
    patience = model.session.patience
    if patience is not None:
        # Early stopping watches the same metric as the checkpointer.
        callbacks.append(EarlyStopping(patience=patience, monitor=monitor))

    trainer: Trainer = build(model=model, callbacks=callbacks, names=names)
    trainer.fit(model=model)

    best_path = Path(str(checkpointer.best_model_path))
    logger.bind(
        component="task",
        task="fit",
        session=model.session.name,
        checkpoint=str(best_path),
    ).info("finished fit task")
    return best_path
75
+
76
+
77
def validate(
    names: list[str],
    checkpoint: str | os.PathLike[str],
    session: Session | None = None,
    patches: list[PatchOp] | None = None,
) -> None:
    """Run the Lightning validation loop for a model.

    Args:
        names: Forwarded to ``build`` for logger creation.
        checkpoint: Checkpoint forwarded to ``JSON2Vec.get_or_create`` as a string.
        session: Optional session forwarded to ``get_or_create``.
        patches: Patch operations applied to the model's session beforehand.
    """
    session_name = session.name if session else None
    logger.bind(component="task", task="validate", session=session_name).info("starting validate task")

    # NOTE(review): str() would turn a None checkpoint into the literal
    # "None" — confirm callers always supply one.
    checkpoint_path = str(checkpoint)
    model: JSON2Vec = JSON2Vec.get_or_create(session=session, checkpoint=checkpoint_path)
    model.session = model.session.patch(patches)

    trainer: Trainer = build(model=model, callbacks=[ThroughputLogger()], names=names)
    trainer.validate(model=model)

    logger.bind(component="task", task="validate", session=model.session.name).info("finished validate task")
94
+
95
+
96
def test(
    names: list[str],
    checkpoint: str | os.PathLike[str],
    session: Session | None = None,
    patches: list[PatchOp] | None = None,
) -> None:
    """Run the Lightning test loop for a model.

    Args:
        names: Forwarded to ``build`` for logger creation.
        checkpoint: Checkpoint forwarded to ``JSON2Vec.get_or_create`` as a string.
        session: Optional session forwarded to ``get_or_create``.
        patches: Patch operations applied to the model's session beforehand.
    """
    session_name = session.name if session else None
    logger.bind(component="task", task="test", session=session_name).info("starting test task")

    checkpoint_path = str(checkpoint)
    model: JSON2Vec = JSON2Vec.get_or_create(session=session, checkpoint=checkpoint_path)
    model.session = model.session.patch(patches)

    trainer: Trainer = build(model=model, callbacks=[ThroughputLogger()], names=names)
    trainer.test(model=model)

    logger.bind(component="task", task="test", session=model.session.name).info("finished test task")
111
+
112
+
113
def predict(
    session: Session | None,
    names: list[str] | None,
    checkpoint: str | os.PathLike[str],
    patches: list[PatchOp] | None = None,
) -> Path:
    """Run the Lightning prediction loop and return the output directory.

    Predictions are not returned in memory (``return_predictions=False``);
    a ``Writer`` callback is installed with the relative directory
    ``tmp/predictions`` as its target.

    Args:
        session: Optional session forwarded to ``JSON2Vec.get_or_create``.
        names: Forwarded to ``build``; ``None`` disables the trainer logger.
        checkpoint: Checkpoint forwarded to ``get_or_create`` as a string.
        patches: Patch operations applied to the model's session beforehand.

    Returns:
        ``Path("tmp/predictions")``.
    """
    session_name = session.name if session else None
    logger.bind(component="task", task="predict", session=session_name).info("starting predict task")

    model: JSON2Vec = JSON2Vec.get_or_create(session=session, checkpoint=str(checkpoint))
    model.session = model.session.patch(patches)

    outpath = "tmp/predictions"
    os.makedirs(name=outpath, exist_ok=True)

    trainer: Trainer = build(model=model, callbacks=[Writer(outpath)], names=names)
    trainer.predict(model=model, return_predictions=False)

    output_path = Path(outpath)
    logger.bind(
        component="task",
        task="predict",
        session=model.session.name,
        output=str(output_path),
    ).info("finished predict task")
    return output_path
134
+
135
+
136
def execute(experiment: Experiment) -> dict[str, Any]:
    """Run every session of an experiment in order.

    Each session is dispatched to the task function matching its ``task``
    stage. The checkpoint starts as ``experiment.checkpoint`` and is replaced
    by the output of each ``fit`` session, so later sessions use the most
    recently trained model.

    Args:
        experiment: Experiment providing project/name/notes, the initial
            checkpoint and the ordered sessions.

    Returns:
        Mapping from session name to that session's task output.
    """
    logger.bind(
        component="pipeline",
        project=experiment.project,
        run=experiment.name,
        sessions=len(experiment.sessions),
    ).info("starting experiment execution")

    dispatch: dict[Stage, Any] = {
        Stage.fit: fit,
        Stage.validate: validate,
        Stage.test: test,
        Stage.predict: predict,
    }
    checkpoint: str | os.PathLike[str] | None = experiment.checkpoint
    names: list[str] = [experiment.project, experiment.name, experiment.notes]
    outputs: dict[str, Any] = {}

    for session in experiment.sessions:
        logger.bind(component="pipeline", session=session.name, stage=session.task).info("dispatching session")

        runner = dispatch[session.task]
        result = runner(session=session, checkpoint=checkpoint, names=names)
        outputs[session.name] = result

        # A path-like result from a fit session becomes the next checkpoint.
        if session.task == Stage.fit and isinstance(result, (str, os.PathLike)):
            checkpoint = result

    logger.bind(
        component="pipeline",
        project=experiment.project,
        run=experiment.name,
    ).info("finished experiment execution")

    return outputs
File without changes
@@ -0,0 +1,98 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import TYPE_CHECKING, Any
5
+
6
+ import lightning.pytorch as lit
7
+ import polars as pl
8
+ import pyarrow as pa
9
+ import pyarrow.parquet as pq
10
+ from lightning.pytorch import callbacks
11
+ from tensordict import TensorDict
12
+
13
+ from json2vec.structs.packages import Prediction
14
+ from json2vec.structs.tree import Address
15
+ from json2vec.tensorfields.base import TensorFieldBase
16
+
17
+ if TYPE_CHECKING:
18
+ from json2vec.architecture.root import JSON2Vec
19
+
20
+
21
class Writer(callbacks.BasePredictionWriter):
    """Prediction writer that appends every predicted batch to a Parquet file.

    One file named ``rank-<local_rank>.parquet`` is created under ``path`` on
    the first batch. Its schema is taken from that first batch and later
    batches are cast to it. The writer is closed when prediction ends and also
    when the trainer reports an exception, so the file is finalized either way.
    """

    def __init__(self, path: os.PathLike | str, flush_every_n_batches: int | None = None):
        """Configure the output directory and the optional flush interval.

        Args:
            path: Directory that receives the per-rank Parquet file.
            flush_every_n_batches: If set, ``flush()`` is called on the
                underlying writer every that many batches, provided the
                writer object exposes a ``flush`` method.
        """
        super().__init__(write_interval="batch")

        self.path: os.PathLike | str = path
        self.flush_every_n_batches: int | None = flush_every_n_batches
        # Both are created lazily by the first call to write_on_batch_end.
        self.schema: pa.Schema | None = None
        self.writer: pq.ParquetWriter | None = None

    @staticmethod
    def _as_struct_frame(
        values_by_address: dict[Address, dict[str, Any]], alias: str, num_rows: int
    ) -> pl.DataFrame:
        """Pack per-address value dicts into a single struct column named ``alias``.

        Each address becomes one struct field whose own fields are the keys of
        its value dict. An empty mapping yields a column of ``num_rows`` nulls.
        """
        if not values_by_address:
            return pl.DataFrame({alias: [None] * num_rows})

        columns: list[pl.DataFrame] = []
        for address, values in values_by_address.items():
            field_frame = pl.DataFrame(data=values)
            columns.append(field_frame.select(pl.struct(pl.all()).alias(name=address)))

        nested: pl.DataFrame = pl.concat(items=columns, how="horizontal")
        return nested.select(pl.struct(pl.all()).alias(name=alias))

    def write_on_batch_end(
        self,
        trainer: lit.Trainer,
        pl_module: JSON2Vec,
        output: dict[str, list[Prediction]],
        batch_indices: list[int] | None,
        batch: TensorDict[Address, TensorFieldBase],
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Append one batch (inputs, predictions, optional embeddings) to the file."""
        num_rows = len(batch["metadata"])

        # These annotations previously used an undefined name (``TensorKey``);
        # ``str`` matches the signature of ``_as_struct_frame``.
        supervised: dict[Address, dict[str, Any]]
        embeddings: dict[Address, dict[str, Any]]

        supervised, embeddings = pl_module.write(predictions=output["predictions"])

        items = [
            pl.from_records(data=batch["metadata"], schema=["inputs"], orient="row"),
            self._as_struct_frame(values_by_address=supervised, alias="predictions", num_rows=num_rows),
        ]

        if len(embeddings) > 0:
            items.append(self._as_struct_frame(values_by_address=embeddings, alias="embeddings", num_rows=num_rows))

        table: pa.Table = pl.concat(items=items, how="horizontal").to_arrow()

        if self.writer is None:
            # First batch: fix the schema and open the per-rank file.
            os.makedirs(self.path, exist_ok=True)
            self.schema = table.schema

            self.writer = pq.ParquetWriter(
                where=os.path.join(self.path, f"rank-{trainer.local_rank}.parquet"),
                schema=self.schema,
            )

        if table.schema != self.schema:
            # Later batches must match the schema fixed by the first one.
            table = table.cast(self.schema)

        self.writer.write_table(table)

        if (
            self.flush_every_n_batches
            and (batch_idx + 1) % self.flush_every_n_batches == 0
            and hasattr(self.writer, "flush")
        ):
            self.writer.flush()

    def _close_writer(self) -> None:
        """Close the Parquet writer if one is open; safe to call repeatedly."""
        if self.writer is not None:
            self.writer.close()
            self.writer = None

    def on_predict_end(self, trainer: lit.Trainer, pl_module: lit.LightningModule) -> None:
        """Finalize the Parquet file once prediction has finished."""
        self._close_writer()

    def on_exception(self, trainer: lit.Trainer, pl_module: lit.LightningModule, exception: BaseException) -> None:
        """Finalize the Parquet file when the run aborts.

        Without this the writer stays open after a failure, since
        ``on_predict_end`` is the only other place it is closed.
        """
        self._close_writer()
@@ -0,0 +1,175 @@
1
+ from dataclasses import dataclass
2
+ from typing import Any, Type, TypeAlias
3
+
4
+ import litserve as ls
5
+ import pydantic
6
+ import torch
7
+ from beartype import beartype
8
+ from tensordict import TensorDict
9
+
10
+ from json2vec.architecture.root import JSON2Vec
11
+ from json2vec.data.datasets import encode, process
12
+ from json2vec.structs.enums import Strata
13
+ from json2vec.structs.environment import DeploymentEnvironment
14
+ from json2vec.structs.packages import Prediction
15
+ from json2vec.structs.tree import Address
16
+ from json2vec.tensorfields.base import TensorFieldBase
17
+
18
# Encoded model input as returned by ``Deployment.decode_request``: a
# TensorDict subscripted with Address keys and TensorFieldBase values.
Input: TypeAlias = TensorDict[Address, TensorFieldBase]
19
+
20
+
21
@dataclass
class ErrorItem:
    """Error placeholder that takes the place of an input that failed to decode.

    ``Deployment.decode_request`` returns it on failure, and
    ``Deployment.encode_response`` copies both fields into the response's
    ``"error"`` entry.
    """

    # Status code for the failure (decode_request uses 422).
    status_code: int
    # Human-readable description of what went wrong.
    message: str
25
+
26
+
27
@dataclass
class BatchItem:
    """Batch assembled by ``Deployment.batch`` from decoded requests."""

    # Valid inputs stacked along dim 0, or None when the batch holds no valid input.
    data: Input | None
    # Positions within ``items`` of the inputs that were stacked into ``data``.
    valid_indices: list[int]
    # Every decoded request in its original order, ErrorItem entries included.
    items: list[Input | ErrorItem]
32
+
33
+
34
class Deployment(ls.LitAPI):
    """LitServe API that serves predictions from a JSON2Vec checkpoint.

    Requests that cannot be decoded become ``ErrorItem`` values. These travel
    through batching and prediction untouched and are reported per request,
    so one bad request does not fail the batch it shares with valid ones.
    """

    def __init__(self, checkpoint: str, *args, **kwargs) -> None:
        """Remember the checkpoint; all other arguments go to ``ls.LitAPI``."""
        super().__init__(*args, **kwargs)
        self.checkpoint = checkpoint

    def setup(self, device: str) -> None:
        """Load the model from the checkpoint onto ``device`` in eval mode."""
        self.model: JSON2Vec = JSON2Vec.get_or_create(checkpoint=self.checkpoint).to(device)
        self.model.eval()
        self.state = self.model.state

    @beartype
    def decode_request(self, request: dict[str, Any] | pydantic.BaseModel) -> Input | ErrorItem:
        """Turn one raw request into an encoded model input.

        Returns:
            The encoded input, or an ``ErrorItem`` with status 422 when
            processing raises, yields nothing (or a ``None`` observation), or
            encoding returns ``None``.
        """
        if isinstance(request, pydantic.BaseModel):
            request = request.model_dump()

        try:
            observations: list[Any] = list(
                process(
                    pipe=[request],
                    session=self.model.session,
                    strata=Strata.predict,
                    state=self.state,
                )
            )
        except Exception as exception:
            # Request boundary: report the failure to the client instead of
            # letting it propagate.
            return ErrorItem(status_code=422, message=str(exception))

        if len(observations) == 0 or any(x is None for x in observations):
            return ErrorItem(status_code=422, message="processor returned no observations for request")

        encoded = encode(
            batch=observations,
            session=self.model.session,
            strata=Strata.predict,
            state=self.state,
        )

        if encoded is None:
            return ErrorItem(status_code=422, message="processor eliminated observation (filter)")

        return encoded

    @beartype
    def batch(self, inputs: list[Input | ErrorItem]) -> BatchItem:
        """Stack the valid inputs and remember where they sat in ``inputs``."""
        valid_indices: list[int] = []
        valid_inputs: list[Input] = []

        for index, item in enumerate(inputs):
            if isinstance(item, ErrorItem):
                continue

            valid_indices.append(index)
            valid_inputs.append(item)

        data = torch.stack(valid_inputs, dim=0) if len(valid_inputs) > 0 else None
        return BatchItem(data=data, valid_indices=valid_indices, items=inputs)

    @beartype
    def unbatch(self, outputs: list[Any]) -> list[Any]:
        """Return the per-request outputs as a new list."""
        return list(outputs)

    @beartype
    def predict(self, data: BatchItem | Input | ErrorItem) -> list[list[Prediction] | ErrorItem] | list[Prediction] | ErrorItem:
        """Run the model on a single input or on a batch.

        An ``ErrorItem`` is returned unchanged, and a single input is run
        through the model directly. For a ``BatchItem`` the result has one
        entry per original request: the predictions for valid inputs, and the
        original ``ErrorItem`` for the rest.
        """
        if isinstance(data, ErrorItem):
            return data

        if isinstance(data, TensorDict):
            with torch.inference_mode():
                return self.model(data.to(self.device))

        # Start from the original items so error entries stay in place.
        outputs: list[Any] = list(data.items)

        if data.data is None:
            return outputs

        with torch.inference_mode():
            predictions = self.model(data.data.to(self.device))

        unbatched = Prediction.unbatch(predictions=predictions)

        for index, item_predictions in zip(data.valid_indices, unbatched):
            outputs[index] = item_predictions

        return outputs

    @beartype
    def encode_response(self, response: list[Prediction] | ErrorItem) -> dict[str, Any] | pydantic.BaseModel:
        """Build the response payload for one request.

        An ``ErrorItem`` yields empty predictions plus an ``"error"`` entry.
        Otherwise the payload holds the model's written predictions, plus
        embeddings when there are any, passed through ``Prediction.denest``.
        """
        if isinstance(response, ErrorItem):
            return {
                "predictions": {},
                "error": {
                    "status_code": response.status_code,
                    "message": response.message,
                },
            }

        predictions, embeddings = self.model.write(predictions=response)

        payload = {"predictions": predictions}

        if len(embeddings) > 0:
            payload["embeddings"] = embeddings

        return Prediction.denest(payload)

    @classmethod
    @beartype
    def forge(
        cls,
        request: Type[pydantic.BaseModel] | None = None,
        response: Type[pydantic.BaseModel] | None = None,
    ) -> Type["Deployment"]:
        """Override the request/response annotations of the API methods.

        The ``__annotations__`` of ``decode_request`` and ``encode_response``
        are modified in place on ``cls``, and ``cls`` itself is returned (no
        new class is created).
        """
        if request is not None:
            cls.decode_request.__annotations__["request"] = request

        if response is not None:
            cls.encode_response.__annotations__["return"] = response

        return cls

    @classmethod
    def serve(cls):
        """Start a LitServe server configured from ``DeploymentEnvironment``."""
        environment = DeploymentEnvironment()

        server: ls.LitServer = ls.LitServer(
            # Instantiate ``cls`` rather than the hard-coded base class, so a
            # subclass calling ``serve()`` serves itself.
            lit_api=cls(
                checkpoint=environment.checkpoint,
                max_batch_size=environment.max_batch_size,
                batch_timeout=environment.batch_timeout,
            ),
            accelerator=environment.accelerator,
            track_requests=environment.track_requests,
            workers_per_device=environment.workers_per_device,
        )

        server.run(generate_client_file=False)
File without changes
@@ -0,0 +1,27 @@
1
+ import json
2
+ import os
3
+ import sys
4
+
5
+ from loguru import logger
6
+ from rich.console import Console
7
+ from rich.json import JSON
8
+
9
# Rich console bound to stdout; the loguru sink below prints through it.
console = Console(file=sys.stdout)
# Log level read from JSON2VEC_LOG_LEVEL (default "DEBUG"), upper-cased.
LOG_LEVEL: str = os.getenv("JSON2VEC_LOG_LEVEL", "DEBUG").upper()
11
+
12
+
13
def sink(message):
    """Loguru sink that prints each record as one line of JSON.

    The object contains ``timestamp`` and ``level``, then every bound extra
    converted with ``str``, then ``message``. An extra named ``timestamp`` or
    ``level`` replaces the built-in value, while the record's message always
    wins over an extra named ``message``.
    """
    record = message.record

    payload = {
        "timestamp": record["time"].strftime("%Y-%m-%d %H:%M:%S"),
        "level": record["level"].name,
    }
    payload.update((key, str(value)) for key, value in record["extra"].items())
    payload["message"] = record["message"]

    console.print(JSON(json.dumps(payload), indent=None))
23
+
24
+
25
+ logger.remove()
26
+ logger.add(sink=sink, level=LOG_LEVEL, enqueue=True, backtrace=True, diagnose=False)
27
+ logger.bind(component="logging", level=LOG_LEVEL).info("configured loguru sink")
@@ -0,0 +1,42 @@
1
+ from __future__ import annotations
2
+
3
+ from functools import partialmethod
4
+ from typing import TYPE_CHECKING, Literal
5
+
6
+ from lightning import Callback, Trainer
7
+ from loguru import logger
8
+
9
+ from json2vec.structs.enums import Strata
10
+
11
+ if TYPE_CHECKING:
12
+ from json2vec.architecture.root import JSON2Vec
13
+
14
+
15
class EpochLifecycleLogger(Callback):
    """Callback that logs the start and end of every epoch in each loop.

    All eight ``on_<loop>_epoch_<start|end>`` hooks are ``partialmethod``
    bindings of ``info`` with the loop's ``Strata`` and the hook name filled in.
    """

    def info(
        self,
        trainer: Trainer,
        pl_module: JSON2Vec,
        strata: Strata,
        hook: Literal["start", "end"],
    ):
        """Log one epoch boundary with rank, epoch and step bound as context."""
        epoch = pl_module.current_epoch
        bound_logger = logger.bind(
            source="lightning",
            rank=pl_module.global_rank,
            epoch=epoch,
            step=pl_module.global_step,
            hook=hook,
            strata=str(strata),
        )
        # "start"/"end" + "ing" gives "starting"/"ending".
        bound_logger.info(f"{hook}ing {strata} epoch {epoch}")

    # Training loop.
    on_train_epoch_start = partialmethod(info, strata=Strata.train, hook="start")
    on_train_epoch_end = partialmethod(info, strata=Strata.train, hook="end")

    # Validation loop.
    on_validation_epoch_start = partialmethod(info, strata=Strata.validate, hook="start")
    on_validation_epoch_end = partialmethod(info, strata=Strata.validate, hook="end")

    # Test loop.
    on_test_epoch_start = partialmethod(info, strata=Strata.test, hook="start")
    on_test_epoch_end = partialmethod(info, strata=Strata.test, hook="end")

    # Prediction loop.
    on_predict_epoch_start = partialmethod(info, strata=Strata.predict, hook="start")
    on_predict_epoch_end = partialmethod(info, strata=Strata.predict, hook="end")