json2vec 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- json2vec/__init__.py +0 -0
- json2vec/__main__.py +32 -0
- json2vec/architecture/__init__.py +0 -0
- json2vec/architecture/attention.py +64 -0
- json2vec/architecture/counter.py +37 -0
- json2vec/architecture/encoder.py +88 -0
- json2vec/architecture/node.py +34 -0
- json2vec/architecture/pool.py +61 -0
- json2vec/architecture/root.py +338 -0
- json2vec/architecture/rotary.py +39 -0
- json2vec/data/__init__.py +0 -0
- json2vec/data/datasets.py +539 -0
- json2vec/data/processing.py +152 -0
- json2vec/entrypoints/__init__.py +3 -0
- json2vec/entrypoints/pipeline.py +174 -0
- json2vec/inference/__init__.py +0 -0
- json2vec/inference/callback.py +98 -0
- json2vec/inference/deployment.py +175 -0
- json2vec/logging/__init__.py +0 -0
- json2vec/logging/config.py +27 -0
- json2vec/logging/epoch.py +42 -0
- json2vec/logging/throughput.py +39 -0
- json2vec/logging/tracking.py +152 -0
- json2vec/processors/__init__.py +8 -0
- json2vec/processors/base.py +102 -0
- json2vec/processors/extensions/__init__.py +0 -0
- json2vec/processors/extensions/example.py +6 -0
- json2vec/processors/spec.py +8 -0
- json2vec/structs/__init__.py +0 -0
- json2vec/structs/enums.py +84 -0
- json2vec/structs/environment.py +138 -0
- json2vec/structs/experiment.py +330 -0
- json2vec/structs/packages.py +117 -0
- json2vec/structs/structure.py +70 -0
- json2vec/structs/tree.py +92 -0
- json2vec/tensorfields/__init__.py +8 -0
- json2vec/tensorfields/base.py +210 -0
- json2vec/tensorfields/extensions/__init__.py +0 -0
- json2vec/tensorfields/extensions/category.py +484 -0
- json2vec/tensorfields/extensions/dateparts.py +410 -0
- json2vec/tensorfields/extensions/entity.py +336 -0
- json2vec/tensorfields/extensions/number.py +400 -0
- json2vec/tensorfields/extensions/vector.py +279 -0
- json2vec/tensorfields/spec.py +8 -0
- json2vec-0.1.0.dist-info/METADATA +227 -0
- json2vec-0.1.0.dist-info/RECORD +51 -0
- json2vec-0.1.0.dist-info/WHEEL +5 -0
- json2vec-0.1.0.dist-info/entry_points.txt +2 -0
- json2vec-0.1.0.dist-info/licenses/LICENSE +178 -0
- json2vec-0.1.0.dist-info/licenses/NOTICE +8 -0
- json2vec-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import datetime
|
|
4
|
+
from collections import defaultdict
|
|
5
|
+
from functools import partialmethod
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from lightning import Callback, Trainer
|
|
10
|
+
|
|
11
|
+
from json2vec.structs.enums import Metric, Strata
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from json2vec.architecture.root import JSON2Vec
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class ThroughputLogger(Callback):
    """Lightning callback that reports per-batch throughput for each data stratum.

    ``start`` records the wall-clock time a batch begins; ``end`` divides
    ``pl_module.session.structure.batch_size`` by the elapsed seconds and reports the result
    through ``pl_module.track`` under the key ``(Metric.throughput, strata)``.
    """

    def __init__(self):
        super().__init__()

        # Most recent batch-start time per stratum. Kept as a defaultdict for compatibility,
        # but ``end`` reads it with ``.get`` so a missing start is detected instead of the
        # factory fabricating a timestamp later than ``now`` (negative / zero interval).
        self.timestamp: dict[Strata, datetime.datetime] = defaultdict(lambda: datetime.datetime.now())

    def start(self, trainer: Trainer, pl_module: JSON2Vec, batch, batch_idx, strata: Strata):
        """Remember when the current batch of ``strata`` started."""
        self.timestamp[strata] = datetime.datetime.now()

    def end(self, trainer: Trainer, pl_module: JSON2Vec, outputs, batch, batch_idx, strata: Strata):
        """Track samples per second for the batch of ``strata`` that just finished.

        Nothing is tracked when no start time was recorded for ``strata`` or when the measured
        interval is not positive (e.g. clock resolution or a wall-clock adjustment, since
        ``datetime.now()`` is not monotonic), because no meaningful rate exists then.
        """
        now = datetime.datetime.now()
        then = self.timestamp.get(strata)
        if then is None:
            return

        elapsed = (now - then).total_seconds()
        if elapsed <= 0:
            return

        # NOTE(review): presumably the configured batch size, so a smaller final batch is
        # over-counted — confirm whether the actual batch length should be used.
        throughput = pl_module.session.structure.batch_size / elapsed
        pl_module.track((Metric.throughput, strata), torch.tensor(throughput))

    # Lightning invokes these hooks with positional arguments and, for evaluation loops over
    # several dataloaders, appends ``dataloader_idx``. Explicit signatures absorb that extra
    # argument; a ``partialmethod(..., strata=...)`` would bind it to ``strata`` and raise
    # ``TypeError`` (multiple values for 'strata').
    def on_train_batch_start(self, trainer: Trainer, pl_module: JSON2Vec, batch, batch_idx):
        self.start(trainer, pl_module, batch, batch_idx, strata=Strata.train)

    def on_validation_batch_start(self, trainer: Trainer, pl_module: JSON2Vec, batch, batch_idx, dataloader_idx=0):
        self.start(trainer, pl_module, batch, batch_idx, strata=Strata.validate)

    def on_test_batch_start(self, trainer: Trainer, pl_module: JSON2Vec, batch, batch_idx, dataloader_idx=0):
        self.start(trainer, pl_module, batch, batch_idx, strata=Strata.test)

    def on_train_batch_end(self, trainer: Trainer, pl_module: JSON2Vec, outputs, batch, batch_idx):
        self.end(trainer, pl_module, outputs, batch, batch_idx, strata=Strata.train)

    def on_validation_batch_end(self, trainer: Trainer, pl_module: JSON2Vec, outputs, batch, batch_idx, dataloader_idx=0):
        self.end(trainer, pl_module, outputs, batch, batch_idx, strata=Strata.validate)

    def on_test_batch_end(self, trainer: Trainer, pl_module: JSON2Vec, outputs, batch, batch_idx, dataloader_idx=0):
        self.end(trainer, pl_module, outputs, batch, batch_idx, strata=Strata.test)
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import enum
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
from lightning.pytorch.loggers import Logger
|
|
7
|
+
from loguru import logger
|
|
8
|
+
|
|
9
|
+
from json2vec.structs.environment import TrackingEnvironment
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class LoggingFramework(enum.StrEnum):
|
|
13
|
+
wandb = "wandb"
|
|
14
|
+
neptune = "neptune"
|
|
15
|
+
comet = "comet"
|
|
16
|
+
mlflow = "mlflow"
|
|
17
|
+
tensorboard = "tensorboard"
|
|
18
|
+
csv = "csv"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class LoggerFactory:
    """Create the Lightning experiment logger configured through ``TrackingEnvironment``.

    Backend selection happens in ``_resolve_framework``:
    - An explicit ``logger`` setting (``JSON2VEC_LOGGER``) wins.
    - The values ``none``/``false``/``off``/``disabled`` turn logging off.
    - Otherwise the first backend in ``AUTO_DETECTION_ORDER`` with a non-``None``
      ``AUTO_LOGGER_FIELDS`` setting is chosen.

    ``create`` returns ``False`` instead of a logger when nothing is selected or construction fails.
    """

    # Order in which backends are probed when no explicit override is configured.
    AUTO_DETECTION_ORDER: tuple[LoggingFramework, ...] = (
        LoggingFramework.wandb,
        LoggingFramework.neptune,
        LoggingFramework.comet,
        LoggingFramework.mlflow,
        LoggingFramework.tensorboard,
        LoggingFramework.csv,
    )
    # TrackingEnvironment attribute names whose presence (non-None) selects each backend.
    AUTO_LOGGER_FIELDS: dict[LoggingFramework, tuple[str, ...]] = {
        LoggingFramework.wandb: ("wandb_api_key",),
        LoggingFramework.neptune: ("neptune_api_token",),
        LoggingFramework.comet: ("comet_api_key",),
        LoggingFramework.mlflow: ("mlflow_tracking_uri",),
        LoggingFramework.tensorboard: ("tensorboard_log_dir",),
        LoggingFramework.csv: ("csv_log_dir",),
    }

    @staticmethod
    def _warn_notes_failure(backend: LoggingFramework) -> None:
        """Warn that run notes could not be attached, including the active exception's traceback.

        Must be called from inside an ``except`` block. Attaching notes is best-effort, so the
        failure is reported but not re-raised. ``depth=1`` attributes the record to the caller.
        """
        logger.bind(component="tracking", backend=backend.value).opt(exception=True, depth=1).warning(
            "failed to attach run notes"
        )

    @staticmethod
    def wandb(project: str, run: str, notes: str) -> Logger:
        """Build a ``WandbLogger``; non-empty ``notes`` are set on the run (best-effort)."""
        from lightning.pytorch.loggers import WandbLogger

        tracker = WandbLogger(project=project, name=run)
        if notes:
            try:
                tracker.experiment.notes = notes
            except Exception:
                LoggerFactory._warn_notes_failure(LoggingFramework.wandb)
        return tracker

    @staticmethod
    def neptune(project: str, run: str, notes: str) -> Logger:
        """Build a ``NeptuneLogger``; non-empty ``notes`` go to ``sys/notes`` (best-effort)."""
        from lightning.pytorch.loggers import NeptuneLogger

        tracker = NeptuneLogger(project=project, name=run)
        if notes:
            try:
                tracker.experiment["sys/notes"] = notes
            except Exception:
                LoggerFactory._warn_notes_failure(LoggingFramework.neptune)
        return tracker

    @staticmethod
    def comet(project: str, run: str, notes: str) -> Logger:
        """Build a ``CometLogger``; non-empty ``notes`` are logged as the ``notes`` "other" value (best-effort)."""
        from lightning.pytorch.loggers import CometLogger

        tracker = CometLogger(project_name=project, experiment_name=run)
        if notes:
            try:
                tracker.experiment.log_other("notes", notes)
            except Exception:
                LoggerFactory._warn_notes_failure(LoggingFramework.comet)
        return tracker

    @staticmethod
    def mlflow(project: str, run: str, notes: str) -> Logger:
        """Build an ``MLFlowLogger``; non-empty ``notes`` are stored as the ``notes`` run tag.

        The tracking URI comes from ``TrackingEnvironment``, which strips it and matches the
        variable name case-insensitively. That is the same value auto-detection used to pick
        this backend, rather than ``MLFlowLogger``'s own default. When no URI is configured,
        the argument is omitted so that default still applies.
        """
        from lightning.pytorch.loggers import MLFlowLogger

        tags = {"notes": notes} if notes else None
        tracking_uri = TrackingEnvironment.from_env().mlflow_tracking_uri
        options = {"tracking_uri": tracking_uri} if tracking_uri else {}
        return MLFlowLogger(experiment_name=project, run_name=run, tags=tags, **options)

    @staticmethod
    def tensorboard(project: str, run: str, _: str) -> Logger:
        """Build a ``TensorBoardLogger`` under the resolved TensorBoard log dir; notes are ignored."""
        from lightning.pytorch.loggers import TensorBoardLogger

        save_dir = TrackingEnvironment.from_env().resolved_tensorboard_log_dir
        return TensorBoardLogger(save_dir=save_dir, name=project, version=run)

    @staticmethod
    def csv(project: str, run: str, _: str) -> Logger:
        """Build a ``CSVLogger`` under the resolved CSV log dir; notes are ignored."""
        from lightning.pytorch.loggers import CSVLogger

        save_dir = TrackingEnvironment.from_env().resolved_csv_log_dir
        return CSVLogger(save_dir=save_dir, name=project, version=run)

    @staticmethod
    def _builders() -> dict[LoggingFramework, Callable[[str, str, str], Logger]]:
        """Map each backend to its ``(project, run, notes) -> Logger`` builder."""
        return {
            LoggingFramework.wandb: LoggerFactory.wandb,
            LoggingFramework.neptune: LoggerFactory.neptune,
            LoggingFramework.comet: LoggerFactory.comet,
            LoggingFramework.mlflow: LoggerFactory.mlflow,
            LoggingFramework.tensorboard: LoggerFactory.tensorboard,
            LoggingFramework.csv: LoggerFactory.csv,
        }

    @staticmethod
    def _resolve_framework() -> LoggingFramework | None:
        """Pick the backend to use, or ``None`` when logging is disabled or unresolvable.

        An explicit override that is not a known backend logs a warning and disables logging;
        it does not fall back to auto-detection.
        """
        settings = TrackingEnvironment.from_env()
        forced = settings.logger
        if forced is not None:
            forced = forced.lower()
            if forced in {"none", "false", "off", "disabled"}:
                return None

            try:
                return LoggingFramework(forced)
            except ValueError:
                logger.bind(component="tracking", backend=forced).warning("unsupported logger backend override")
                return None

        for backend in LoggerFactory.AUTO_DETECTION_ORDER:
            if any(getattr(settings, field) is not None for field in LoggerFactory.AUTO_LOGGER_FIELDS[backend]):
                return backend

        return None

    @staticmethod
    def create(project: str, run: str, notes: str) -> Logger | bool:
        """Build the configured logger for ``project``/``run``.

        Returns:
            The logger instance, or ``False`` when no backend is selected or its construction
            raised (the exception is logged with traceback).
        """
        backend = LoggerFactory._resolve_framework()
        if backend is None:
            return False

        builder = LoggerFactory._builders().get(backend)
        if builder is None:
            logger.bind(component="tracking", backend=backend.value).warning("unsupported logger backend")
            return False

        try:
            tracker = builder(project, run, notes)
        except Exception:
            logger.bind(component="tracking", backend=backend.value).exception("failed to initialize trainer logger")
            return False

        logger.bind(component="tracking", backend=backend.value, project=project, run=run).info("enabled trainer logger")
        return tracker
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import ast
|
|
4
|
+
import enum
|
|
5
|
+
import inspect
|
|
6
|
+
import textwrap
|
|
7
|
+
from functools import cache
|
|
8
|
+
from typing import Any, Callable
|
|
9
|
+
|
|
10
|
+
import pluggy
|
|
11
|
+
import pydantic
|
|
12
|
+
|
|
13
|
+
from json2vec.processors.spec import PluginSpec
|
|
14
|
+
|
|
15
|
+
# Module-wide plugin manager for the "processors" project; hook specifications are
# declared by PluginSpec and registered once at import time.
pm: pluggy.PluginManager = pluggy.PluginManager("processors")
pm.add_hookspecs(PluginSpec)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ProcessorMode(enum.StrEnum):
    """How a registered processor produces its output; chosen by ``register``."""

    # Selected when ``is_yielding_processor(func)`` is true.
    yielding = "yield"
    # Selected for every other processor.
    returning = "return"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def has_yield_expression(node: ast.AST, root: bool = False) -> bool:
    """Report whether a ``yield`` / ``yield from`` appears beneath ``node``.

    Only descendants are inspected, never ``node`` itself. Nested ``def``, ``async def`` and
    ``lambda`` scopes are normally skipped. When ``root`` is True, nested scopes that are
    *direct* children of ``node`` are searched one level deep (their own nested scopes are
    still skipped).

    NOTE(review): with ``root=True`` a function whose body merely defines an inner generator
    is reported as yielding even if the outer function returns something else — confirm
    that this is the intended classification.
    """
    nested_scopes = (ast.FunctionDef, ast.AsyncFunctionDef, ast.Lambda)

    for child in ast.iter_child_nodes(node):
        if isinstance(child, (ast.Yield, ast.YieldFrom)):
            return True

        if isinstance(child, nested_scopes):
            # Only the root-level call peeks into directly nested scopes.
            found = root and has_yield_expression(child)
        else:
            found = has_yield_expression(child)

        if found:
            return True

    return False
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def is_yielding_processor(func: Callable[..., Any]) -> bool:
    """Decide whether ``func`` should be registered as a yielding processor.

    The source of ``func`` is parsed and the top-level ``def``/``async def`` named
    ``func.__name__`` is handed to ``has_yield_expression(target, root=True)``.

    Falls back to ``inspect.isgeneratorfunction(func)`` when:
    - the source is unavailable (``OSError``/``TypeError`` from ``inspect.getsource``);
    - the retrieved source does not parse on its own, e.g. a lambda cut out of a larger
      multi-line expression;
    - no matching top-level function definition is found.
    """
    try:
        source: str = textwrap.dedent(inspect.getsource(func))
    except (OSError, TypeError):
        return inspect.isgeneratorfunction(func)

    try:
        module: ast.Module = ast.parse(source)
    except (SyntaxError, ValueError):
        # getsource can return a fragment (e.g. a lambda's line) that is not valid Python by
        # itself; classify from the runtime object instead of failing registration.
        return inspect.isgeneratorfunction(func)

    target = next(
        (
            node
            for node in module.body
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == func.__name__
        ),
        None,
    )
    if target is None:
        return inspect.isgeneratorfunction(func)

    return has_yield_expression(target, root=True)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Processor(pydantic.BaseModel):
    """Immutable registry entry pairing a processor callable with its name and mode.

    Calling an instance passes ``observation`` positionally and forwards only those keyword
    arguments that the wrapped callable's signature can accept.
    """

    model_config = pydantic.ConfigDict(frozen=True, arbitrary_types_allowed=True)

    name: str
    func: Callable[..., Any]
    mode: ProcessorMode

    def __call__(self, observation: dict, **kwargs) -> Any:
        supported = _filter_supported_kwargs(self.func, kwargs)
        return self.func(observation, **supported)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@cache
def _accepted_kwargs(func: Callable[..., Any]) -> tuple[bool, frozenset[str]]:
    """Describe which keyword arguments ``func`` can be called with.

    Returns:
        ``(accepts_variadic_kwargs, accepted)``:
        - ``accepts_variadic_kwargs`` is True when ``func`` declares a ``**kwargs`` parameter.
        - ``accepted`` holds the parameter names that may legally be passed by keyword, i.e.
          positional-or-keyword and keyword-only parameters.

    Positional-only parameters and the ``*args`` / ``**kwargs`` collector names are excluded.
    Passing them by keyword raises ``TypeError``, or for ``*args``, reports an unexpected keyword.

    Results are memoized per callable via ``functools.cache``, so ``func`` must be hashable.
    """
    parameters = inspect.signature(func).parameters.values()
    keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)

    accepts_variadic_kwargs = any(parameter.kind == inspect.Parameter.VAR_KEYWORD for parameter in parameters)
    accepted = frozenset(parameter.name for parameter in parameters if parameter.kind in keyword_kinds)
    return accepts_variadic_kwargs, accepted
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def _filter_supported_kwargs(func: Callable[..., Any], kwargs: dict[str, Any]) -> dict[str, Any]:
    """Return the part of ``kwargs`` that ``func`` can receive as keyword arguments.

    If ``func`` declares ``**kwargs``, the input mapping itself (not a copy) is returned.
    Otherwise a new dict restricted to the names reported by ``_accepted_kwargs`` is built.
    """
    takes_anything, names = _accepted_kwargs(func)
    if not takes_anything:
        return {key: value for key, value in kwargs.items() if key in names}
    return kwargs
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
# Global processor registry keyed by function name; populated by ``register``.
PROCESSORS: dict[str, Processor] = {}
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def register(func: Callable[..., Any]) -> Callable[..., Any]:
    """Add ``func`` to ``PROCESSORS`` under its ``__name__`` and return it unchanged.

    Usable as a decorator. The stored mode depends on ``is_yielding_processor(func)``:
    ``ProcessorMode.yielding`` when it is true, ``ProcessorMode.returning`` otherwise.

    Raises:
        ValueError: if a processor with the same name is already registered.
    """
    name = func.__name__
    if name in PROCESSORS:
        raise ValueError(f"Processor '{name}' is already registered.")

    if is_yielding_processor(func):
        detected = ProcessorMode.yielding
    else:
        detected = ProcessorMode.returning

    PROCESSORS[name] = Processor(name=name, func=func, mode=detected)
    return func
|
|
File without changes
|
|
File without changes
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Tokens(enum.IntEnum):
    """Integer codes for token states: valued, null, padded, masked, pruned, other.

    NOTE(review): the meaning of each state is inferred from member names only — confirm
    against the code that emits/consumes these codes. The integers may be stored in tensors
    or checkpoints, so do not renumber them without checking consumers.
    """

    valued = 0
    null = 1
    padded = 2
    masked = 3
    pruned = 4
    other = 5
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class Stage(enum.StrEnum):
|
|
14
|
+
fit = "fit"
|
|
15
|
+
validate = "validate"
|
|
16
|
+
test = "test"
|
|
17
|
+
predict = "predict"
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Strata(enum.StrEnum):
|
|
21
|
+
train = "train"
|
|
22
|
+
validate = "validate"
|
|
23
|
+
test = "test"
|
|
24
|
+
predict = "predict"
|
|
25
|
+
|
|
26
|
+
@classmethod
|
|
27
|
+
def from_stage(cls, stage: Stage | str) -> list["Strata"]:
|
|
28
|
+
match stage:
|
|
29
|
+
case Stage.fit:
|
|
30
|
+
return [cls.train, cls.validate]
|
|
31
|
+
case Stage.validate:
|
|
32
|
+
return [cls.validate]
|
|
33
|
+
case Stage.test:
|
|
34
|
+
return [cls.test]
|
|
35
|
+
case Stage.predict:
|
|
36
|
+
return [cls.predict]
|
|
37
|
+
case _:
|
|
38
|
+
raise ValueError(f"Unknown stage: {stage}")
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class Suffix(enum.StrEnum):
|
|
42
|
+
feather = "feather"
|
|
43
|
+
parquet = "parquet"
|
|
44
|
+
ndjson = "ndjson"
|
|
45
|
+
avro = "avro"
|
|
46
|
+
csv = "csv"
|
|
47
|
+
orc = "orc"
|
|
48
|
+
json = "json"
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class TensorKey(enum.StrEnum):
|
|
52
|
+
value = "value"
|
|
53
|
+
content = "content"
|
|
54
|
+
state = "state"
|
|
55
|
+
intervals = "intervals"
|
|
56
|
+
probability = "probability"
|
|
57
|
+
topk = "topk"
|
|
58
|
+
embedding = "embedding"
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class Metric(enum.StrEnum):
|
|
62
|
+
accuracy = "accuracy"
|
|
63
|
+
precision = "precision"
|
|
64
|
+
recall = "recall"
|
|
65
|
+
loss = "loss"
|
|
66
|
+
sigma = "sigma"
|
|
67
|
+
throughput = "throughput"
|
|
68
|
+
mae = "mae"
|
|
69
|
+
rmse = "rmse"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class ShardingStrategy(enum.StrEnum):
|
|
73
|
+
file = "file"
|
|
74
|
+
chunk = "chunk"
|
|
75
|
+
record = "record"
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class Component(enum.StrEnum):
    """Component names; each value is spelled exactly like its member name.

    NOTE(review): casing is mixed ("Request" vs "loss") and the values are presumably used as
    identifiers or keys elsewhere — keep them verbatim unless all consumers are updated.
    """

    Request = "Request"
    Embedder = "Embedder"
    Decoder = "Decoder"
    TensorField = "TensorField"
    loss = "loss"
    write = "write"
|
|
@@ -0,0 +1,138 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Literal, Self
|
|
5
|
+
from urllib.parse import urlparse
|
|
6
|
+
|
|
7
|
+
from pydantic import AliasChoices, Field, ValidationInfo, field_validator
|
|
8
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
9
|
+
|
|
10
|
+
from json2vec.structs.enums import ShardingStrategy
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DataLoaderEnvironment(BaseSettings):
    """DataLoader settings read from the process environment.

    Each field can be supplied through any of the variable names listed in its
    ``AliasChoices``. Names are matched case-insensitively and unrelated variables are ignored.
    """

    model_config = SettingsConfigDict(case_sensitive=False, extra="ignore")

    # Worker process count; None leaves the choice to the caller.
    num_workers: int | None = Field(
        validation_alias=AliasChoices("JSON2VEC_NUM_WORKERS", "NUM_WORKERS"),
        default=None,
        ge=0,
    )
    persistent_workers: bool = Field(
        validation_alias=AliasChoices("JSON2VEC_PERSISTENT_WORKERS", "PERSISTENT_WORKERS"),
        default=True,
    )
    pin_memory: bool = Field(
        validation_alias=AliasChoices("JSON2VEC_PIN_MEMORY", "PIN_MEMORY"),
        default=True,
    )
    sharding: ShardingStrategy = Field(
        validation_alias=AliasChoices("JSON2VEC_SHARDING", "JSON2VEC_SHARDING_STRATEGY", "SHARDING_STRATEGY"),
        default=ShardingStrategy.file,
    )
    chunk_batch_size: int = Field(
        validation_alias=AliasChoices("JSON2VEC_CHUNK_BATCH_SIZE", "JSON2VEC_PYARROW_BATCH_SIZE", "CHUNK_BATCH_SIZE"),
        default=4096,
        ge=1,
    )

    @field_validator("sharding", mode="before")
    @classmethod
    def normalize_sharding(cls, value: ShardingStrategy | str) -> ShardingStrategy | str:
        """Trim and lower-case raw string input so e.g. ``" Chunk "`` parses as ``chunk``."""
        return value.strip().lower() if isinstance(value, str) else value

    @classmethod
    def from_env(cls) -> Self:
        """Construct an instance populated from the current environment."""
        return cls()
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class TrackingEnvironment(BaseSettings):
    """Experiment-tracking settings read from the process environment.

    Every string value is trimmed, and blank values become ``None`` (see
    ``strip_string_values``). Variable names are matched case-insensitively.
    """

    model_config = SettingsConfigDict(case_sensitive=False, extra="ignore")

    # Explicit backend override (e.g. "wandb", "csv", "none").
    logger: str | None = Field(validation_alias=AliasChoices("JSON2VEC_LOGGER"), default=None)
    wandb_api_key: str | None = Field(validation_alias=AliasChoices("WANDB_API_KEY"), default=None)
    neptune_api_token: str | None = Field(validation_alias=AliasChoices("NEPTUNE_API_TOKEN"), default=None)
    comet_api_key: str | None = Field(validation_alias=AliasChoices("COMET_API_KEY"), default=None)
    mlflow_tracking_uri: str | None = Field(validation_alias=AliasChoices("MLFLOW_TRACKING_URI"), default=None)
    tensorboard_log_dir: str | None = Field(
        validation_alias=AliasChoices("JSON2VEC_TENSORBOARD_LOG_DIR", "TENSORBOARD_LOG_DIR"),
        default=None,
    )
    csv_log_dir: str | None = Field(
        validation_alias=AliasChoices("JSON2VEC_CSV_LOG_DIR", "CSV_LOG_DIR"),
        default=None,
    )

    @field_validator("*", mode="before")
    @classmethod
    def strip_string_values(cls, value):
        """Trim whitespace from string inputs and turn blank strings into ``None``."""
        if not isinstance(value, str):
            return value
        return value.strip() or None

    @property
    def resolved_tensorboard_log_dir(self) -> str:
        """TensorBoard save directory, defaulting to ``logs/tensorboard`` when unset."""
        configured = self.tensorboard_log_dir
        return configured if configured else "logs/tensorboard"

    @property
    def resolved_csv_log_dir(self) -> str:
        """CSV log directory, defaulting to ``logs/csv`` when unset."""
        configured = self.csv_log_dir
        return configured if configured else "logs/csv"

    @classmethod
    def from_env(cls) -> Self:
        """Construct an instance populated from the current environment."""
        return cls()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
class DeploymentEnvironment(BaseSettings):
    """Inference-deployment settings read from the process environment.

    Each field accepts the variable names in its ``AliasChoices``, matched case-insensitively.
    ``checkpoint`` and ``accelerator`` are trimmed and must not be blank.
    """

    model_config = SettingsConfigDict(extra="ignore", case_sensitive=False)

    checkpoint: str = Field(
        default="model.ckpt",
        validation_alias=AliasChoices("JSON2VEC_CHECKPOINT", "CHECKPOINT"),
    )
    max_batch_size: int = Field(
        default=128,
        ge=1,
        validation_alias=AliasChoices("JSON2VEC_MAX_BATCH_SIZE", "MAX_BATCH_SIZE"),
    )
    batch_timeout: float = Field(
        default=0.0,
        ge=0.0,
        validation_alias=AliasChoices("JSON2VEC_BATCH_TIMEOUT", "BATCH_TIMEOUT"),
    )
    workers_per_device: int = Field(
        default=1,
        ge=1,
        validation_alias=AliasChoices("JSON2VEC_WORKERS_PER_DEVICE", "JSON2VEC_N_WORKERS", "N_WORKERS"),
    )
    accelerator: Literal["auto", "cpu", "cuda", "mps"] = Field(
        default="auto",
        validation_alias=AliasChoices("JSON2VEC_ACCELERATOR", "ACCELERATOR"),
    )
    track_requests: bool = Field(
        default=False,
        validation_alias=AliasChoices("JSON2VEC_TRACK_REQUESTS", "TRACK_REQUESTS"),
    )

    @field_validator("checkpoint", "accelerator", mode="before")
    @classmethod
    def strip_required_strings(cls, value: str | None, info: ValidationInfo) -> str | None:
        """Trim ``checkpoint``/``accelerator`` and reject blank values.

        ``accelerator`` is also lower-cased, so ``CUDA`` or ``Cpu`` match the allowed literals.
        This matches how ``DataLoaderEnvironment`` normalizes ``sharding``. The checkpoint
        path keeps its original case.

        Raises:
            ValueError: if the value is empty or whitespace-only.
        """
        if not isinstance(value, str):
            return value

        stripped = value.strip()
        if stripped == "":
            raise ValueError(f"{info.field_name} must not be blank")

        if info.field_name == "accelerator":
            return stripped.lower()
        return stripped

    @classmethod
    def from_env(cls) -> Self:
        """Construct an instance populated from the current environment."""
        return cls()
|