json2vec 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- json2vec/__init__.py +0 -0
- json2vec/__main__.py +32 -0
- json2vec/architecture/__init__.py +0 -0
- json2vec/architecture/attention.py +64 -0
- json2vec/architecture/counter.py +37 -0
- json2vec/architecture/encoder.py +88 -0
- json2vec/architecture/node.py +34 -0
- json2vec/architecture/pool.py +61 -0
- json2vec/architecture/root.py +338 -0
- json2vec/architecture/rotary.py +39 -0
- json2vec/data/__init__.py +0 -0
- json2vec/data/datasets.py +539 -0
- json2vec/data/processing.py +152 -0
- json2vec/entrypoints/__init__.py +3 -0
- json2vec/entrypoints/pipeline.py +174 -0
- json2vec/inference/__init__.py +0 -0
- json2vec/inference/callback.py +98 -0
- json2vec/inference/deployment.py +175 -0
- json2vec/logging/__init__.py +0 -0
- json2vec/logging/config.py +27 -0
- json2vec/logging/epoch.py +42 -0
- json2vec/logging/throughput.py +39 -0
- json2vec/logging/tracking.py +152 -0
- json2vec/processors/__init__.py +8 -0
- json2vec/processors/base.py +102 -0
- json2vec/processors/extensions/__init__.py +0 -0
- json2vec/processors/extensions/example.py +6 -0
- json2vec/processors/spec.py +8 -0
- json2vec/structs/__init__.py +0 -0
- json2vec/structs/enums.py +84 -0
- json2vec/structs/environment.py +138 -0
- json2vec/structs/experiment.py +330 -0
- json2vec/structs/packages.py +117 -0
- json2vec/structs/structure.py +70 -0
- json2vec/structs/tree.py +92 -0
- json2vec/tensorfields/__init__.py +8 -0
- json2vec/tensorfields/base.py +210 -0
- json2vec/tensorfields/extensions/__init__.py +0 -0
- json2vec/tensorfields/extensions/category.py +484 -0
- json2vec/tensorfields/extensions/dateparts.py +410 -0
- json2vec/tensorfields/extensions/entity.py +336 -0
- json2vec/tensorfields/extensions/number.py +400 -0
- json2vec/tensorfields/extensions/vector.py +279 -0
- json2vec/tensorfields/spec.py +8 -0
- json2vec-0.1.0.dist-info/METADATA +227 -0
- json2vec-0.1.0.dist-info/RECORD +51 -0
- json2vec-0.1.0.dist-info/WHEEL +5 -0
- json2vec-0.1.0.dist-info/entry_points.txt +2 -0
- json2vec-0.1.0.dist-info/licenses/LICENSE +178 -0
- json2vec-0.1.0.dist-info/licenses/NOTICE +8 -0
- json2vec-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from collections.abc import Callable
|
|
3
|
+
from functools import partial
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from beartype import beartype
|
|
8
|
+
|
|
9
|
+
from json2vec.structs.enums import Tokens
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def apply(
|
|
13
|
+
values: Any,
|
|
14
|
+
function: Callable[..., Any],
|
|
15
|
+
/,
|
|
16
|
+
*args: Any,
|
|
17
|
+
leaf_depth: int | None = None,
|
|
18
|
+
**kwargs: Any,
|
|
19
|
+
) -> Any:
|
|
20
|
+
"""Apply a function recursively to nested list leaves.
|
|
21
|
+
|
|
22
|
+
When ``leaf_depth`` is set, the function is applied exactly at that depth;
|
|
23
|
+
higher-level non-list values are preserved so downstream padding can mark
|
|
24
|
+
them as incomplete.
|
|
25
|
+
"""
|
|
26
|
+
if leaf_depth is not None and leaf_depth < 0:
|
|
27
|
+
raise ValueError("leaf_depth must be >= 0")
|
|
28
|
+
|
|
29
|
+
def walk(node: Any, depth: int) -> Any:
|
|
30
|
+
if leaf_depth is None:
|
|
31
|
+
if isinstance(node, list):
|
|
32
|
+
return [walk(item, depth + 1) for item in node]
|
|
33
|
+
|
|
34
|
+
if node is None:
|
|
35
|
+
return None
|
|
36
|
+
|
|
37
|
+
return function(node, *args, **kwargs)
|
|
38
|
+
|
|
39
|
+
if depth == leaf_depth:
|
|
40
|
+
if node is None:
|
|
41
|
+
return None
|
|
42
|
+
|
|
43
|
+
return function(node, *args, **kwargs)
|
|
44
|
+
|
|
45
|
+
if isinstance(node, list):
|
|
46
|
+
return [walk(item, depth + 1) for item in node]
|
|
47
|
+
|
|
48
|
+
return node
|
|
49
|
+
|
|
50
|
+
return walk(values, depth=0)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def _iter_leaf_nodes(
|
|
54
|
+
nested: Any,
|
|
55
|
+
shape: tuple[int, ...],
|
|
56
|
+
strides: tuple[int, ...],
|
|
57
|
+
):
|
|
58
|
+
ndim = len(shape)
|
|
59
|
+
stack: list[tuple[Any, int, int]] = [(nested, 0, 0)]
|
|
60
|
+
|
|
61
|
+
while stack:
|
|
62
|
+
node, depth, base = stack.pop()
|
|
63
|
+
|
|
64
|
+
if depth == ndim:
|
|
65
|
+
yield base, node
|
|
66
|
+
continue
|
|
67
|
+
|
|
68
|
+
if not isinstance(node, list):
|
|
69
|
+
continue
|
|
70
|
+
|
|
71
|
+
limit = min(len(node), shape[depth])
|
|
72
|
+
step = strides[depth]
|
|
73
|
+
for index in range(limit - 1, -1, -1):
|
|
74
|
+
stack.append((node[index], depth + 1, base + (index * step)))
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _fill_python(
    nested: Any,
    flat_values: np.ndarray,
    flat_flags: np.ndarray,
    shape: tuple[int, ...],
    strides: tuple[int, ...],
) -> None:
    """Scatter the leaves of ``nested`` into pre-allocated flat buffers, in place.

    Each leaf is yielded by ``_iter_leaf_nodes``:

    - A ``None`` leaf only has its flag set to ``Tokens.null``; its value slot
      is left as it was.
    - Any other leaf is written into ``flat_values`` and flagged
      ``Tokens.valued``.

    Positions without a leaf are not touched.
    """
    leaves = _iter_leaf_nodes(nested=nested, shape=shape, strides=strides)
    for position, leaf in leaves:
        if leaf is None:
            flat_flags[position] = Tokens.null.value
            continue
        flat_values[position] = leaf
        flat_flags[position] = Tokens.valued.value
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@beartype
def pad(
    nested: Any, shape: tuple[int, ...], dtype: type | str = object, pad_value: Any = None
) -> tuple[np.ndarray, np.ndarray]:
    """Densify a ragged nested list into a fixed-shape value array plus flags.

    Args:
        nested: Nested lists whose leaves live at depth ``len(shape)``, or the
            value itself when ``shape == ()``.
        shape: Output shape; input lists longer than it are truncated.
        dtype: Anything accepted by ``np.dtype`` for the value array.
        pad_value: Fill value for positions ``nested`` does not provide and
            for ``None`` leaves.

    Returns:
        ``(values, flags)``. ``flags`` is ``int8`` and holds:

        - ``Tokens.padded`` for absent positions,
        - ``Tokens.null`` for ``None`` leaves,
        - ``Tokens.valued`` for populated positions.
    """
    values = np.full(shape, pad_value, dtype=np.dtype(dtype))
    flags = np.full(shape, Tokens.padded.value, dtype=np.int8)

    if not shape:
        # Zero-dimensional case: the input itself is the single leaf.
        if nested is None:
            flags[...] = Tokens.null.value
        else:
            values[...] = nested
            flags[...] = Tokens.valued.value
        return values, flags

    # Row-major element strides: the last axis has stride 1, and every earlier
    # axis spans the product of all extents after it.
    strides: list[int] = []
    running = 1
    for extent in reversed(shape):
        strides.append(running)
        running *= extent
    strides.reverse()

    # values/flags come straight from np.full (C-contiguous), so reshape(-1)
    # yields views and the scatter below writes into the returned arrays.
    _fill_python(
        nested=nested,
        flat_values=values.reshape(-1),
        flat_flags=flags.reshape(-1),
        shape=shape,
        strides=tuple(strides),
    )
    return values, flags
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
@beartype
class Pipeline:
    """Compose callables into a lazily evaluated processing pipeline.

    Keyword arguments passed to the constructor form a shared pool. When a
    step is added with ``|``, every pool entry whose name matches one of the
    step's parameters is bound to it with ``functools.partial``.

    On iteration:

    - The first step is called with no positional arguments.
    - Each later step receives the previous step's output as its only
      positional argument.
    - The final output is returned via ``iter()``.
    """

    def __init__(self, **arguments):
        # Shared keyword-argument pool offered to every step.
        self.arguments: dict[str, Any] = arguments
        # Steps in execution order; each is a partial with its pool args bound.
        self.steps: list[Callable] = []

    def __or__(self, function: Callable) -> "Pipeline":
        """Append ``function`` as a step, binding any matching pool arguments.

        Mutates and returns ``self`` so that steps can be chained.
        """
        parameters = inspect.signature(function).parameters
        bound = {name: value for name, value in self.arguments.items() if name in parameters}
        self.steps.append(partial(function, **bound))
        return self

    def __repr__(self):
        # Previously this referenced ``self.source``, which is never assigned,
        # so repr() always raised AttributeError.
        step_names = [
            getattr(getattr(step, "func", step), "__name__", repr(step))
            for step in self.steps
        ]
        return f"Pipeline(steps={step_names!r}, arguments={self.arguments!r})"

    def __iter__(self):
        """Run all steps in order and return an iterator over the final output.

        Raises:
            ValueError: If no steps have been added.
        """
        if not self.steps:
            raise ValueError("Pipeline has no steps; add a source step with `|` before iterating")

        source, *transforms = self.steps
        stream = source()
        for step in transforms:
            stream = step(stream)
        return iter(stream)
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from lightning.pytorch.callbacks import Callback, EarlyStopping, ModelCheckpoint
|
|
9
|
+
from lightning.pytorch.trainer.trainer import Trainer
|
|
10
|
+
from loguru import logger
|
|
11
|
+
|
|
12
|
+
from json2vec.architecture.root import JSON2Vec
|
|
13
|
+
from json2vec.inference.callback import Writer
|
|
14
|
+
from json2vec.logging.epoch import EpochLifecycleLogger
|
|
15
|
+
from json2vec.logging.throughput import ThroughputLogger
|
|
16
|
+
from json2vec.logging.tracking import LoggerFactory
|
|
17
|
+
from json2vec.structs.enums import Metric, Stage, Strata
|
|
18
|
+
from json2vec.structs.experiment import Experiment, PatchOp, Session
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def build(model: JSON2Vec, callbacks: list[Callback], names: list[str] | None = None) -> Trainer:
    """Construct a Lightning ``Trainer`` for ``model``.

    Behavior:

    - An ``EpochLifecycleLogger`` is appended unless one is already among
      ``callbacks``; the caller's list itself is not modified.
    - When CUDA is available the trainer uses ``accelerator="auto"`` with
      ``"bf16-mixed"`` precision; otherwise it uses ``"cpu"`` with
      ``precision=None``.
    - An experiment logger is created only when ``names`` is given.
    - Any further trainer options come from ``model.session.trainer``.
    """
    has_cuda = torch.cuda.is_available()

    active_callbacks: list[Callback] = [*callbacks]
    has_epoch_logger = any(isinstance(cb, EpochLifecycleLogger) for cb in active_callbacks)
    if not has_epoch_logger:
        active_callbacks.append(EpochLifecycleLogger())

    session = model.session
    logger.bind(
        component="trainer",
        session=session.name,
        stage=session.task,
        callbacks=[type(cb).__name__ for cb in active_callbacks],
    ).info("building lightning trainer")

    experiment_logger = False if names is None else LoggerFactory.create(*names)

    return Trainer(
        accelerator="auto" if has_cuda else "cpu",
        precision="bf16-mixed" if has_cuda else None,
        logger=experiment_logger,
        enable_model_summary=False,
        enable_progress_bar=False,
        callbacks=active_callbacks,
        **session.trainer,
    )
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def fit(
    names: list[str],
    session: Session | None = None,
    checkpoint: str | os.PathLike[str] | None = None,
    patches: list[PatchOp] | None = None,
) -> Path:
    """Train a model and return the path of its best checkpoint.

    The model is loaded or created from ``session``/``checkpoint``, then
    ``patches`` are applied to its session. Training uses checkpointing into
    ``./models`` on the validation-loss metric, plus early stopping when
    ``session.patience`` is set.

    Args:
        names: Logger names forwarded to ``LoggerFactory.create``.
        session: Session describing the training run (may be None when the
            session is restored from ``checkpoint`` — presumably; confirm
            against ``JSON2Vec.get_or_create``).
        checkpoint: Optional checkpoint to resume or initialize from.
        patches: Session patches applied after loading.

    Returns:
        Path of the best checkpoint by the monitored metric. If no checkpoint
        was saved, this is ``Path("")`` (i.e. ``.``) and a warning is logged.
    """
    logger.bind(component="task", task="fit", session=session.name if session else None).info("starting fit task")

    checkpoint_path = str(checkpoint) if checkpoint is not None else None
    model: JSON2Vec = JSON2Vec.get_or_create(session=session, checkpoint=checkpoint_path)
    model.session = model.session.patch(patches)

    monitor = f"{Metric.loss}/{Strata.validate}"
    # The score placeholder must reference the *monitored* metric. The old
    # ``{val_loss:.2f}`` named a metric that is never logged, and
    # ModelCheckpoint substitutes 0 for unknown names, so every file ended in
    # "val_loss=0.00". Because the monitor key contains "/", automatic
    # "name=" insertion is disabled (it would create sub-directories) and the
    # labels are spelled out instead, keeping the previous file-name layout.
    filename: str = (
        f"{model.session.structure.name}-{model.session.name}-"
        + "epoch={epoch}-step={step}-val_loss={"
        + monitor
        + ":.2f}"
    )

    checkpoint_dir = Path("models")
    checkpoint_dir.mkdir(parents=True, exist_ok=True)
    checkpointer: ModelCheckpoint = ModelCheckpoint(
        dirpath=checkpoint_dir,
        filename=filename,
        monitor=monitor,
        auto_insert_metric_name=False,
    )
    callbacks: list[Callback] = [ThroughputLogger(), checkpointer]

    if (patience := model.session.patience) is not None:
        callbacks.append(EarlyStopping(patience=patience, monitor=monitor))

    trainer: Trainer = build(model=model, callbacks=callbacks, names=names)
    trainer.fit(model=model)

    best_model_path = str(checkpointer.best_model_path)
    if not best_model_path:
        # Nothing was saved (e.g. the monitored metric was never logged);
        # downstream sessions would otherwise silently load from ".".
        logger.bind(component="task", task="fit", session=model.session.name).warning(
            "fit finished without saving a checkpoint"
        )

    best_path = Path(best_model_path)
    logger.bind(component="task", task="fit", session=model.session.name, checkpoint=str(best_path)).info(
        "finished fit task"
    )
    return best_path
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def validate(
    names: list[str],
    checkpoint: str | os.PathLike[str] | None,
    session: Session | None = None,
    patches: list[PatchOp] | None = None,
) -> None:
    """Run the Lightning validation loop for a model.

    Args:
        names: Logger names forwarded to ``LoggerFactory.create``.
        checkpoint: Checkpoint to load. ``None`` is forwarded as ``None``
            (``execute`` passes ``experiment.checkpoint``, which may be unset).
        session: Session to run; may be None when restored from ``checkpoint``.
        patches: Session patches applied after loading.
    """
    logger.bind(component="task", task="validate", session=session.name if session else None).info(
        "starting validate task"
    )

    # ``str(None)`` would turn a missing checkpoint into the literal path
    # "None"; mirror ``fit`` and forward None unchanged.
    checkpoint_path = str(checkpoint) if checkpoint is not None else None
    model: JSON2Vec = JSON2Vec.get_or_create(session=session, checkpoint=checkpoint_path)
    model.session = model.session.patch(patches)

    callbacks: list[Callback] = [ThroughputLogger()]
    trainer: Trainer = build(model=model, callbacks=callbacks, names=names)
    trainer.validate(model=model)
    logger.bind(component="task", task="validate", session=model.session.name).info("finished validate task")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def test(
    names: list[str],
    checkpoint: str | os.PathLike[str] | None,
    session: Session | None = None,
    patches: list[PatchOp] | None = None,
) -> None:
    """Run the Lightning test loop for a model.

    Args:
        names: Logger names forwarded to ``LoggerFactory.create``.
        checkpoint: Checkpoint to load. ``None`` is forwarded as ``None``
            (``execute`` passes ``experiment.checkpoint``, which may be unset).
        session: Session to run; may be None when restored from ``checkpoint``.
        patches: Session patches applied after loading.
    """
    logger.bind(component="task", task="test", session=session.name if session else None).info("starting test task")

    # ``str(None)`` would turn a missing checkpoint into the literal path
    # "None"; mirror ``fit`` and forward None unchanged.
    checkpoint_path = str(checkpoint) if checkpoint is not None else None
    model: JSON2Vec = JSON2Vec.get_or_create(session=session, checkpoint=checkpoint_path)
    model.session = model.session.patch(patches)

    callbacks: list[Callback] = [ThroughputLogger()]
    trainer: Trainer = build(model=model, callbacks=callbacks, names=names)
    trainer.test(model=model)
    logger.bind(component="task", task="test", session=model.session.name).info("finished test task")
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def predict(
    session: Session | None,
    names: list[str] | None,
    checkpoint: str | os.PathLike[str] | None,
    patches: list[PatchOp] | None = None,
) -> Path:
    """Run the prediction loop and stream outputs to Parquet files.

    Predictions are written by ``Writer`` into ``tmp/predictions`` (one file
    per process rank). Later predict runs write to the same directory.

    Args:
        session: Session to run; may be None when restored from ``checkpoint``.
        names: Logger names, or None to disable experiment logging.
        checkpoint: Checkpoint to load. ``None`` is forwarded as ``None``
            (``execute`` passes ``experiment.checkpoint``, which may be unset).
        patches: Session patches applied after loading.

    Returns:
        The output directory.
    """
    logger.bind(component="task", task="predict", session=session.name if session else None).info("starting predict task")

    # ``str(None)`` would turn a missing checkpoint into the literal path
    # "None"; mirror ``fit`` and forward None unchanged.
    checkpoint_path = str(checkpoint) if checkpoint is not None else None
    model: JSON2Vec = JSON2Vec.get_or_create(session=session, checkpoint=checkpoint_path)
    model.session = model.session.patch(patches)

    outpath = "tmp/predictions"
    os.makedirs(name=outpath, exist_ok=True)
    callbacks: list[Callback] = [Writer(outpath)]
    trainer: Trainer = build(model=model, callbacks=callbacks, names=names)
    trainer.predict(model=model, return_predictions=False)

    output_path = Path(outpath)
    logger.bind(component="task", task="predict", session=model.session.name, output=str(output_path)).info(
        "finished predict task"
    )
    return output_path
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def execute(experiment: Experiment) -> dict[str, Any]:
    """Run every session of ``experiment`` in order and collect their outputs.

    Each session is dispatched on ``session.task`` to ``fit``, ``validate``,
    ``test`` or ``predict``. All sessions share the logger names
    ``[project, name, notes]``.

    The checkpoint handed to each task starts as ``experiment.checkpoint``.
    It is replaced by the path returned from every ``fit`` session, so later
    sessions use the most recently trained model.

    Returns:
        Mapping of session name to the value returned by that session's task.
    """
    logger.bind(
        component="pipeline",
        project=experiment.project,
        run=experiment.name,
        sessions=len(experiment.sessions),
    ).info("starting experiment execution")

    current_checkpoint: str | os.PathLike[str] | None = experiment.checkpoint
    run_names: list[str] = [experiment.project, experiment.name, experiment.notes]
    dispatch: dict[Stage, Any] = {
        Stage.fit: fit,
        Stage.validate: validate,
        Stage.test: test,
        Stage.predict: predict,
    }

    results: dict[str, Any] = {}
    for session in experiment.sessions:
        logger.bind(component="pipeline", session=session.name, stage=session.task).info("dispatching session")
        runner = dispatch[session.task]

        result = runner(session=session, checkpoint=current_checkpoint, names=run_names)
        results[session.name] = result

        # Only a fit session produces a new checkpoint for subsequent sessions.
        produced_checkpoint = isinstance(result, (str, os.PathLike)) and session.task == Stage.fit
        if produced_checkpoint:
            current_checkpoint = result

    logger.bind(component="pipeline", project=experiment.project, run=experiment.name).info(
        "finished experiment execution"
    )
    return results
|
|
File without changes
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import TYPE_CHECKING, Any
|
|
5
|
+
|
|
6
|
+
import lightning.pytorch as lit
|
|
7
|
+
import polars as pl
|
|
8
|
+
import pyarrow as pa
|
|
9
|
+
import pyarrow.parquet as pq
|
|
10
|
+
from lightning.pytorch import callbacks
|
|
11
|
+
from tensordict import TensorDict
|
|
12
|
+
|
|
13
|
+
from json2vec.structs.packages import Prediction
|
|
14
|
+
from json2vec.structs.tree import Address
|
|
15
|
+
from json2vec.tensorfields.base import TensorFieldBase
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from json2vec.architecture.root import JSON2Vec
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class Writer(callbacks.BasePredictionWriter):
    """Stream per-batch prediction outputs into one Parquet file per process.

    Each process writes ``rank-{global_rank}.parquet`` under ``path``. The
    Arrow schema is fixed by the first batch; later batches with a differing
    schema are cast to it.
    """

    def __init__(self, path: os.PathLike | str, flush_every_n_batches: int | None = None):
        super().__init__(write_interval="batch")

        # Output directory; created lazily on the first written batch.
        self.path: os.PathLike | str = path
        # Flush cadence in batches; None/0 disables periodic flushing.
        self.flush_every_n_batches: int | None = flush_every_n_batches
        self.schema: pa.Schema | None = None
        self.writer: pq.ParquetWriter | None = None

    @staticmethod
    def _as_struct_frame(
        values_by_address: dict[Address, dict[str, Any]], alias: str, num_rows: int
    ) -> pl.DataFrame:
        """Pack ``{address: {column: values}}`` into a single struct column ``alias``.

        Each address becomes a struct field (itself a struct of its columns).
        An empty mapping yields a column of ``num_rows`` nulls.
        """
        if len(values_by_address) == 0:
            return pl.DataFrame({alias: [None] * num_rows})

        columns: list[pl.DataFrame] = []
        for address, values in values_by_address.items():
            field_frame = pl.DataFrame(data=values)
            columns.append(field_frame.select(pl.struct(pl.all()).alias(name=address)))

        nested: pl.DataFrame = pl.concat(items=columns, how="horizontal")
        return nested.select(pl.struct(pl.all()).alias(name=alias))

    def write_on_batch_end(
        self,
        trainer: lit.Trainer,
        pl_module: JSON2Vec,
        output: dict[str, list[Prediction]],
        batch_indices: list[int] | None,
        batch: TensorDict[Address, TensorFieldBase],
        batch_idx: int,
        dataloader_idx: int,
    ) -> None:
        """Append one batch as rows to this rank's Parquet file.

        Columns written:

        - ``inputs``: the batch metadata.
        - ``predictions``: a struct column.
        - ``embeddings``: a struct column, only when the model returned any.
        """
        num_rows = len(batch["metadata"])

        supervised: dict[Address, dict[str, Any]]
        embeddings: dict[Address, dict[str, Any]]

        supervised, embeddings = pl_module.write(predictions=output["predictions"])

        items = [
            pl.from_records(data=batch["metadata"], schema=["inputs"], orient="row"),
            self._as_struct_frame(values_by_address=supervised, alias="predictions", num_rows=num_rows),
        ]

        if len(embeddings) > 0:
            items.append(self._as_struct_frame(values_by_address=embeddings, alias="embeddings", num_rows=num_rows))

        table: pa.Table = pl.concat(items=items, how="horizontal").to_arrow()

        if self.writer is None:
            os.makedirs(self.path, exist_ok=True)
            self.schema = table.schema
            # global_rank (not local_rank) keeps file names unique when
            # processes on several nodes write into a shared directory.
            self.writer = pq.ParquetWriter(
                where=os.path.join(self.path, f"rank-{trainer.global_rank}.parquet"),
                schema=self.schema,
            )

        if table.schema != self.schema:
            table = table.cast(self.schema)

        self.writer.write_table(table)

        if self.flush_every_n_batches and (batch_idx + 1) % self.flush_every_n_batches == 0 and hasattr(self.writer, "flush"):
            self.writer.flush()

    def _close_writer(self) -> None:
        """Close and drop the active Parquet writer, if any."""
        if self.writer is not None:
            self.writer.close()
            self.writer = None

    def on_predict_end(self, trainer: lit.Trainer, pl_module: lit.LightningModule) -> None:
        self._close_writer()

    def on_exception(self, trainer: lit.Trainer, pl_module: lit.LightningModule, exception: BaseException) -> None:
        # Finalize the file when prediction fails as well, so already written
        # row groups stay readable and the file handle is released.
        self._close_writer()
|
|
@@ -0,0 +1,175 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import Any, Type, TypeAlias
|
|
3
|
+
|
|
4
|
+
import litserve as ls
|
|
5
|
+
import pydantic
|
|
6
|
+
import torch
|
|
7
|
+
from beartype import beartype
|
|
8
|
+
from tensordict import TensorDict
|
|
9
|
+
|
|
10
|
+
from json2vec.architecture.root import JSON2Vec
|
|
11
|
+
from json2vec.data.datasets import encode, process
|
|
12
|
+
from json2vec.structs.enums import Strata
|
|
13
|
+
from json2vec.structs.environment import DeploymentEnvironment
|
|
14
|
+
from json2vec.structs.packages import Prediction
|
|
15
|
+
from json2vec.structs.tree import Address
|
|
16
|
+
from json2vec.tensorfields.base import TensorFieldBase
|
|
17
|
+
|
|
18
|
+
# Model input for one request/batch: a TensorDict keyed by field address.
Input: TypeAlias = TensorDict[Address, TensorFieldBase]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass()
class ErrorItem:
    """A request that failed before inference; carried through to its response.

    Attributes (documented here rather than as string statements):
        status_code: Status reported to the client (422 is used for rejected input).
        message: Human-readable explanation of the failure.
    """

    status_code: int
    message: str
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass()
class BatchItem:
    """A server-side batch that keeps failed requests alongside valid ones.

    ``data`` stacks only the valid inputs (None when there are none).
    ``valid_indices[i]`` is the position in ``items`` of the i-th row of
    ``data``. ``items`` holds every request in arrival order, including
    ``ErrorItem`` entries.
    """

    data: Input | None
    valid_indices: list[int]
    items: list[Input | ErrorItem]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Deployment(ls.LitAPI):
    """LitServe API that serves a trained ``JSON2Vec`` checkpoint.

    Each request flows through ``decode_request``, then (for batched serving)
    ``batch``/``unbatch``, then ``predict`` and ``encode_response``. Requests
    that fail pre-processing travel as ``ErrorItem`` values, so each one gets
    its own error response without failing the rest of a batch.
    """

    def __init__(self, checkpoint: str, *args, **kwargs) -> None:
        # Remaining arguments (e.g. max_batch_size, batch_timeout) go to LitAPI.
        super().__init__(*args, **kwargs)
        self.checkpoint = checkpoint

    def setup(self, device: str) -> None:
        """Load the checkpoint onto ``device`` in eval mode and cache its state."""
        self.model: JSON2Vec = JSON2Vec.get_or_create(checkpoint=self.checkpoint).to(device)
        self.model.eval()
        self.state = self.model.state

    @beartype
    def decode_request(self, request: dict[str, Any] | pydantic.BaseModel) -> Input | ErrorItem:
        """Turn one raw request into encoded model input.

        Returns:
            The encoded input, or an ``ErrorItem`` with status 422 in any of
            these cases:

            - processing raises;
            - processing yields no observations, or yields a ``None``;
            - encoding returns ``None``.
        """
        if isinstance(request, pydantic.BaseModel):
            request = request.model_dump()

        try:
            observations: list[Any] = list(
                process(
                    pipe=[request],
                    session=self.model.session,
                    strata=Strata.predict,
                    state=self.state,
                )
            )
        except Exception as exception:
            # Request boundary: report processing failures to the client.
            return ErrorItem(status_code=422, message=str(exception))

        if len(observations) == 0 or any(x is None for x in observations):
            return ErrorItem(status_code=422, message="processor returned no observations for request")

        encoded = encode(
            batch=observations,
            session=self.model.session,
            strata=Strata.predict,
            state=self.state,
        )

        if encoded is None:
            return ErrorItem(status_code=422, message="processor eliminated observation (filter)")

        return encoded

    @beartype
    def batch(self, inputs: list[Input | ErrorItem]) -> BatchItem:
        """Stack the valid inputs along dim 0, remembering where each came from."""
        valid_indices: list[int] = []
        valid_inputs: list[Input] = []

        for index, item in enumerate(inputs):
            if isinstance(item, ErrorItem):
                continue

            valid_indices.append(index)
            valid_inputs.append(item)

        data = torch.stack(valid_inputs, dim=0) if len(valid_inputs) > 0 else None
        return BatchItem(data=data, valid_indices=valid_indices, items=inputs)

    @beartype
    def unbatch(self, outputs: list[Any]) -> list[Any]:
        """Return the per-request outputs already produced by ``predict``."""
        return list(outputs)

    @beartype
    def predict(self, data: BatchItem | Input | ErrorItem) -> list[list[Prediction] | ErrorItem] | list[Prediction] | ErrorItem:
        """Run the model.

        Returns:
            - An ``ErrorItem`` input is returned unchanged.
            - A single ``Input`` yields the raw model output.
            - A ``BatchItem`` yields a list aligned with the original requests,
              where valid positions hold that request's predictions and failed
              positions keep their ``ErrorItem``.
        """
        if isinstance(data, ErrorItem):
            return data

        if isinstance(data, TensorDict):
            with torch.inference_mode():
                return self.model(data.to(self.device))

        outputs: list[Any] = list(data.items)

        if data.data is None:
            return outputs

        with torch.inference_mode():
            predictions = self.model(data.data.to(self.device))

        unbatched = Prediction.unbatch(predictions=predictions)

        # Scatter per-row predictions back to their original request slots.
        for index, item_predictions in zip(data.valid_indices, unbatched):
            outputs[index] = item_predictions

        return outputs

    @beartype
    def encode_response(self, response: list[Prediction] | ErrorItem) -> dict[str, Any] | pydantic.BaseModel:
        """Serialize predictions (and embeddings, when present) or an error payload."""
        if isinstance(response, ErrorItem):
            return {
                "predictions": {},
                "error": {
                    "status_code": response.status_code,
                    "message": response.message,
                },
            }

        predictions, embeddings = self.model.write(predictions=response)

        payload = dict(predictions=predictions)

        if len(embeddings) > 0:
            payload["embeddings"] = embeddings

        return Prediction.denest(payload)

    @classmethod
    @beartype
    def forge(
        cls,
        request: Type[pydantic.BaseModel] | None = None,
        response: Type[pydantic.BaseModel] | None = None,
    ) -> Type["Deployment"]:
        """Attach pydantic request/response models to the API's annotations.

        NOTE(review): this mutates the ``__annotations__`` of the inherited
        methods in place (no subclass is created), so the change is visible to
        every class sharing them. The methods are wrapped by ``@beartype`` at
        class creation, so runtime checks presumably still use the original
        annotations — confirm.
        """
        if request is not None:
            cls.decode_request.__annotations__["request"] = request

        if response is not None:
            cls.encode_response.__annotations__["return"] = response

        return cls

    @classmethod
    def serve(cls):
        """Start a LitServer for ``cls``, configured from ``DeploymentEnvironment``."""
        environment = DeploymentEnvironment()

        server: ls.LitServer = ls.LitServer(
            # Instantiate ``cls`` (not the hard-coded base class) so subclasses'
            # overrides are actually served.
            lit_api=cls(
                checkpoint=environment.checkpoint,
                max_batch_size=environment.max_batch_size,
                batch_timeout=environment.batch_timeout,
            ),
            accelerator=environment.accelerator,
            track_requests=environment.track_requests,
            workers_per_device=environment.workers_per_device,
        )

        server.run(generate_client_file=False)
|
|
File without changes
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os
|
|
3
|
+
import sys
|
|
4
|
+
|
|
5
|
+
from loguru import logger
|
|
6
|
+
from rich.console import Console
|
|
7
|
+
from rich.json import JSON
|
|
8
|
+
|
|
9
|
+
# Rich console bound to standard output; every log record is printed through it.
console = Console(file=sys.stdout)

# Minimum level for the loguru sink, overridable via JSON2VEC_LOG_LEVEL.
LOG_LEVEL: str = os.environ.get("JSON2VEC_LOG_LEVEL", "DEBUG").upper()
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def sink(message):
    """Render a loguru message as one JSON object per line on stdout.

    Key order in the payload:

    1. ``timestamp`` (``%Y-%m-%d %H:%M:%S``);
    2. ``level``;
    3. every bound ``extra`` value, converted with ``str``;
    4. ``message``.

    Extras whose names collide with the reserved keys are emitted as
    ``extra_<name>``, so they no longer overwrite the record's own fields.
    Previously ``logger.bind(level=...)`` replaced the real log level.
    """
    record = message.record
    reserved = {"timestamp", "level", "message"}

    payload = {
        "timestamp": record["time"].strftime("%Y-%m-%d %H:%M:%S"),
        "level": record["level"].name,
    }
    for key, value in record["extra"].items():
        payload[f"extra_{key}" if key in reserved else key] = str(value)
    payload["message"] = record["message"]

    # soft_wrap stops rich from wrapping/cropping output at the console width,
    # so each record stays a single, complete JSON line.
    console.print(JSON(json.dumps(payload), indent=None), soft_wrap=True)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Replace loguru's default handler with the JSON sink defined in this module.
logger.remove()
logger.add(
    sink=sink,
    level=LOG_LEVEL,
    enqueue=True,  # route records through loguru's internal queue
    backtrace=True,
    diagnose=False,  # keep variable values out of rendered tracebacks
)
logger.bind(component="logging", level=LOG_LEVEL).info("configured loguru sink")
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from functools import partialmethod
|
|
4
|
+
from typing import TYPE_CHECKING, Literal
|
|
5
|
+
|
|
6
|
+
from lightning import Callback, Trainer
|
|
7
|
+
from loguru import logger
|
|
8
|
+
|
|
9
|
+
from json2vec.structs.enums import Strata
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from json2vec.architecture.root import JSON2Vec
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class EpochLifecycleLogger(Callback):
    """Emit a structured log record when each epoch starts and ends.

    Covers the train, validation, test and predict loops. Every hook is
    ``info`` with a fixed ``strata`` and ``hook`` bound via ``partialmethod``.
    """

    def info(
        self,
        trainer: Trainer,
        pl_module: JSON2Vec,
        strata: Strata,
        hook: Literal["start", "end"],
    ):
        """Log one lifecycle event, e.g. "starting <strata> epoch <n>"."""
        epoch = pl_module.current_epoch
        context = {
            "source": "lightning",
            "rank": pl_module.global_rank,
            "epoch": epoch,
            "step": pl_module.global_step,
            "hook": hook,
            "strata": str(strata),
        }
        logger.bind(**context).info(f"{hook}ing {strata} epoch {epoch}")

    # Epoch-start hooks.
    on_train_epoch_start = partialmethod(info, strata=Strata.train, hook="start")
    on_validation_epoch_start = partialmethod(info, strata=Strata.validate, hook="start")
    on_test_epoch_start = partialmethod(info, strata=Strata.test, hook="start")
    on_predict_epoch_start = partialmethod(info, strata=Strata.predict, hook="start")

    # Epoch-end hooks.
    on_train_epoch_end = partialmethod(info, strata=Strata.train, hook="end")
    on_validation_epoch_end = partialmethod(info, strata=Strata.validate, hook="end")
    on_test_epoch_end = partialmethod(info, strata=Strata.test, hook="end")
    on_predict_epoch_end = partialmethod(info, strata=Strata.predict, hook="end")