jerry-thomas 0.3.0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datapipeline/analysis/vector/collector.py +120 -17
- datapipeline/analysis/vector/matrix.py +33 -8
- datapipeline/analysis/vector/report.py +162 -32
- datapipeline/build/tasks/__init__.py +11 -0
- datapipeline/build/tasks/config.py +74 -0
- datapipeline/build/tasks/metadata.py +170 -0
- datapipeline/build/tasks/scaler.py +73 -0
- datapipeline/build/tasks/schema.py +60 -0
- datapipeline/build/tasks/utils.py +169 -0
- datapipeline/cli/app.py +304 -127
- datapipeline/cli/commands/build.py +240 -16
- datapipeline/cli/commands/contract.py +367 -0
- datapipeline/cli/commands/domain.py +8 -3
- datapipeline/cli/commands/inspect.py +401 -149
- datapipeline/cli/commands/list_.py +30 -7
- datapipeline/cli/commands/plugin.py +5 -1
- datapipeline/cli/commands/run.py +227 -241
- datapipeline/cli/commands/run_config.py +101 -0
- datapipeline/cli/commands/serve_pipeline.py +156 -0
- datapipeline/cli/commands/source.py +44 -8
- datapipeline/cli/visuals/__init__.py +4 -2
- datapipeline/cli/visuals/common.py +239 -0
- datapipeline/cli/visuals/labels.py +15 -15
- datapipeline/cli/visuals/runner.py +66 -0
- datapipeline/cli/visuals/sections.py +20 -0
- datapipeline/cli/visuals/sources.py +132 -119
- datapipeline/cli/visuals/sources_basic.py +260 -0
- datapipeline/cli/visuals/sources_off.py +76 -0
- datapipeline/cli/visuals/sources_rich.py +414 -0
- datapipeline/config/catalog.py +37 -3
- datapipeline/config/context.py +214 -0
- datapipeline/config/dataset/loader.py +21 -4
- datapipeline/config/dataset/normalize.py +4 -4
- datapipeline/config/metadata.py +43 -0
- datapipeline/config/postprocess.py +2 -2
- datapipeline/config/project.py +3 -2
- datapipeline/config/resolution.py +129 -0
- datapipeline/config/tasks.py +309 -0
- datapipeline/config/workspace.py +155 -0
- datapipeline/domain/__init__.py +12 -0
- datapipeline/domain/record.py +11 -0
- datapipeline/domain/sample.py +54 -0
- datapipeline/integrations/ml/adapter.py +34 -20
- datapipeline/integrations/ml/pandas_support.py +0 -2
- datapipeline/integrations/ml/rows.py +1 -6
- datapipeline/integrations/ml/torch_support.py +1 -3
- datapipeline/io/factory.py +112 -0
- datapipeline/io/output.py +132 -0
- datapipeline/io/protocols.py +21 -0
- datapipeline/io/serializers.py +219 -0
- datapipeline/io/sinks/__init__.py +23 -0
- datapipeline/io/sinks/base.py +2 -0
- datapipeline/io/sinks/files.py +79 -0
- datapipeline/io/sinks/rich.py +57 -0
- datapipeline/io/sinks/stdout.py +18 -0
- datapipeline/io/writers/__init__.py +14 -0
- datapipeline/io/writers/base.py +28 -0
- datapipeline/io/writers/csv_writer.py +25 -0
- datapipeline/io/writers/jsonl.py +52 -0
- datapipeline/io/writers/pickle_writer.py +30 -0
- datapipeline/pipeline/artifacts.py +58 -0
- datapipeline/pipeline/context.py +66 -7
- datapipeline/pipeline/observability.py +65 -0
- datapipeline/pipeline/pipelines.py +65 -13
- datapipeline/pipeline/split.py +11 -10
- datapipeline/pipeline/stages.py +127 -16
- datapipeline/pipeline/utils/keygen.py +20 -7
- datapipeline/pipeline/utils/memory_sort.py +22 -10
- datapipeline/pipeline/utils/transform_utils.py +22 -0
- datapipeline/runtime.py +5 -2
- datapipeline/services/artifacts.py +12 -6
- datapipeline/services/bootstrap/config.py +25 -0
- datapipeline/services/bootstrap/core.py +52 -37
- datapipeline/services/constants.py +6 -5
- datapipeline/services/factories.py +123 -1
- datapipeline/services/project_paths.py +43 -16
- datapipeline/services/runs.py +208 -0
- datapipeline/services/scaffold/domain.py +3 -2
- datapipeline/services/scaffold/filter.py +3 -2
- datapipeline/services/scaffold/mappers.py +9 -6
- datapipeline/services/scaffold/plugin.py +54 -10
- datapipeline/services/scaffold/source.py +93 -56
- datapipeline/sources/{composed_loader.py → data_loader.py} +9 -9
- datapipeline/sources/decoders.py +83 -18
- datapipeline/sources/factory.py +26 -16
- datapipeline/sources/models/__init__.py +2 -2
- datapipeline/sources/models/generator.py +0 -7
- datapipeline/sources/models/loader.py +3 -3
- datapipeline/sources/models/parsing_error.py +24 -0
- datapipeline/sources/models/source.py +6 -6
- datapipeline/sources/synthetic/time/loader.py +14 -2
- datapipeline/sources/transports.py +74 -37
- datapipeline/templates/plugin_skeleton/README.md +76 -30
- datapipeline/templates/plugin_skeleton/example/contracts/time.ticks.hour_sin.yaml +31 -0
- datapipeline/templates/plugin_skeleton/example/contracts/time.ticks.linear.yaml +30 -0
- datapipeline/templates/plugin_skeleton/example/dataset.yaml +18 -0
- datapipeline/templates/plugin_skeleton/example/postprocess.yaml +29 -0
- datapipeline/templates/plugin_skeleton/{config/datasets/default → example}/project.yaml +11 -8
- datapipeline/templates/plugin_skeleton/example/sources/synthetic.ticks.yaml +12 -0
- datapipeline/templates/plugin_skeleton/example/tasks/metadata.yaml +3 -0
- datapipeline/templates/plugin_skeleton/example/tasks/scaler.yaml +9 -0
- datapipeline/templates/plugin_skeleton/example/tasks/schema.yaml +2 -0
- datapipeline/templates/plugin_skeleton/example/tasks/serve.test.yaml +4 -0
- datapipeline/templates/plugin_skeleton/example/tasks/serve.train.yaml +28 -0
- datapipeline/templates/plugin_skeleton/example/tasks/serve.val.yaml +4 -0
- datapipeline/templates/plugin_skeleton/jerry.yaml +34 -0
- datapipeline/templates/plugin_skeleton/your-dataset/contracts/time.ticks.hour_sin.yaml +31 -0
- datapipeline/templates/plugin_skeleton/your-dataset/contracts/time.ticks.linear.yaml +30 -0
- datapipeline/templates/plugin_skeleton/your-dataset/dataset.yaml +18 -0
- datapipeline/templates/plugin_skeleton/your-dataset/postprocess.yaml +29 -0
- datapipeline/templates/plugin_skeleton/your-dataset/project.yaml +22 -0
- datapipeline/templates/plugin_skeleton/your-dataset/sources/synthetic.ticks.yaml +12 -0
- datapipeline/templates/plugin_skeleton/your-dataset/tasks/metadata.yaml +3 -0
- datapipeline/templates/plugin_skeleton/your-dataset/tasks/scaler.yaml +9 -0
- datapipeline/templates/plugin_skeleton/your-dataset/tasks/schema.yaml +2 -0
- datapipeline/templates/plugin_skeleton/your-dataset/tasks/serve.test.yaml +4 -0
- datapipeline/templates/plugin_skeleton/your-dataset/tasks/serve.train.yaml +28 -0
- datapipeline/templates/plugin_skeleton/your-dataset/tasks/serve.val.yaml +4 -0
- datapipeline/templates/stubs/dto.py.j2 +2 -0
- datapipeline/templates/stubs/mapper.py.j2 +5 -4
- datapipeline/templates/stubs/parser.py.j2 +2 -0
- datapipeline/templates/stubs/record.py.j2 +2 -0
- datapipeline/templates/stubs/source.yaml.j2 +2 -3
- datapipeline/transforms/debug/lint.py +26 -41
- datapipeline/transforms/feature/scaler.py +89 -13
- datapipeline/transforms/record/floor_time.py +4 -4
- datapipeline/transforms/sequence.py +2 -35
- datapipeline/transforms/stream/dedupe.py +24 -0
- datapipeline/transforms/stream/ensure_ticks.py +7 -6
- datapipeline/transforms/vector/__init__.py +5 -0
- datapipeline/transforms/vector/common.py +98 -0
- datapipeline/transforms/vector/drop/__init__.py +4 -0
- datapipeline/transforms/vector/drop/horizontal.py +79 -0
- datapipeline/transforms/vector/drop/orchestrator.py +59 -0
- datapipeline/transforms/vector/drop/vertical.py +182 -0
- datapipeline/transforms/vector/ensure_schema.py +184 -0
- datapipeline/transforms/vector/fill.py +87 -0
- datapipeline/transforms/vector/replace.py +62 -0
- datapipeline/utils/load.py +24 -3
- datapipeline/utils/rich_compat.py +38 -0
- datapipeline/utils/window.py +76 -0
- jerry_thomas-1.0.1.dist-info/METADATA +825 -0
- jerry_thomas-1.0.1.dist-info/RECORD +199 -0
- {jerry_thomas-0.3.0.dist-info → jerry_thomas-1.0.1.dist-info}/entry_points.txt +9 -8
- datapipeline/build/tasks.py +0 -186
- datapipeline/cli/commands/link.py +0 -128
- datapipeline/cli/commands/writers.py +0 -138
- datapipeline/config/build.py +0 -64
- datapipeline/config/run.py +0 -116
- datapipeline/templates/plugin_skeleton/config/contracts/time_hour_sin.synthetic.yaml +0 -24
- datapipeline/templates/plugin_skeleton/config/contracts/time_linear.synthetic.yaml +0 -23
- datapipeline/templates/plugin_skeleton/config/datasets/default/build.yaml +0 -9
- datapipeline/templates/plugin_skeleton/config/datasets/default/dataset.yaml +0 -14
- datapipeline/templates/plugin_skeleton/config/datasets/default/postprocess.yaml +0 -13
- datapipeline/templates/plugin_skeleton/config/datasets/default/runs/run_test.yaml +0 -10
- datapipeline/templates/plugin_skeleton/config/datasets/default/runs/run_train.yaml +0 -10
- datapipeline/templates/plugin_skeleton/config/datasets/default/runs/run_val.yaml +0 -10
- datapipeline/templates/plugin_skeleton/config/sources/time_ticks.yaml +0 -11
- datapipeline/transforms/vector.py +0 -210
- jerry_thomas-0.3.0.dist-info/METADATA +0 -502
- jerry_thomas-0.3.0.dist-info/RECORD +0 -139
- {jerry_thomas-0.3.0.dist-info → jerry_thomas-1.0.1.dist-info}/WHEEL +0 -0
- {jerry_thomas-0.3.0.dist-info → jerry_thomas-1.0.1.dist-info}/licenses/LICENSE +0 -0
- {jerry_thomas-0.3.0.dist-info → jerry_thomas-1.0.1.dist-info}/top_level.txt +0 -0
datapipeline/config/dataset/loader.py
CHANGED

@@ -1,12 +1,30 @@
 from typing import Literal
-
-from datapipeline.
+
+from datapipeline.config.dataset.dataset import (
+    RecordDatasetConfig,
+    FeatureDatasetConfig,
+)
+from datapipeline.services.bootstrap import _load_by_key, _globals, _interpolate
 
 Stage = Literal["records", "features", "vectors"]
 
 
+def _normalize_dataset_doc(doc):
+    if not isinstance(doc, dict):
+        return doc
+    normalized = dict(doc)
+    for key in ("features", "targets"):
+        if normalized.get(key) is None:
+            normalized[key] = []
+    return normalized
+
+
 def load_dataset(project_yaml, stage: Stage):
-
+    raw = _load_by_key(project_yaml, "dataset")
+    vars_ = _globals(project_yaml)
+    if vars_:
+        raw = _interpolate(raw, vars_)
+    ds_doc = _normalize_dataset_doc(raw)
 
     if stage == "records":
         return RecordDatasetConfig.model_validate(ds_doc)
@@ -16,4 +34,3 @@ def load_dataset(project_yaml, stage: Stage):
         return FeatureDatasetConfig.model_validate(ds_doc)
     else:
         raise ValueError(f"Unknown stage: {stage}")
-
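For orientation, a minimal sketch of calling the reworked loader; the stage values come from the Stage literal in the hunk above, and the project.yaml path is only illustrative (it mirrors the your-dataset skeleton shipped in this release):

from pathlib import Path
from datapipeline.config.dataset.loader import load_dataset

# stage is one of "records", "features", "vectors"
records_cfg = load_dataset(Path("your-dataset/project.yaml"), stage="records")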
datapipeline/transforms/record/floor_time.py
CHANGED

@@ -2,15 +2,15 @@ from datetime import datetime
 import re
 
 
-def
-    """Floor a timestamp to the nearest
+def floor_time_to_bucket(ts: datetime, bucket: str) -> datetime:
+    """Floor a timestamp to the nearest bucket boundary.
 
     Supports patterns like '10m', '10min', '1h', '2h'.
     Minutes may be specified as 'm' or 'min'.
     """
-    m = re.fullmatch(r"^(\d+)(m|min|h)$",
+    m = re.fullmatch(r"^(\d+)(m|min|h)$", bucket)
     if not m:
-        raise ValueError(f"Unsupported
+        raise ValueError(f"Unsupported cadence: {bucket}")
     n = int(m.group(1))
     unit = m.group(2)
     if n <= 0:
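A short usage sketch of the renamed helper, assuming the usual flooring semantics its docstring describes (the remainder of the function body is not shown in this hunk):

from datetime import datetime
from datapipeline.transforms.record.floor_time import floor_time_to_bucket

ts = datetime(2024, 1, 1, 12, 47)
floor_time_to_bucket(ts, "10m")  # expected 2024-01-01 12:40 per the docstring
floor_time_to_bucket(ts, "2h")   # expected 2024-01-01 12:00
floor_time_to_bucket(ts, "5s")   # raises ValueError: Unsupported cadence: 5s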
datapipeline/config/metadata.py

@@ -0,0 +1,43 @@
+from __future__ import annotations
+
+from datetime import datetime
+from typing import Any, Dict, List, Optional
+
+from pydantic import BaseModel, ConfigDict, Field
+
+
+# Shared keys for vector metadata counts
+FEATURE_VECTORS_COUNT_KEY = "feature_vectors"
+TARGET_VECTORS_COUNT_KEY = "target_vectors"
+
+
+class Window(BaseModel):
+    """Typed representation of dataset window bounds."""
+
+    start: Optional[datetime] = None
+    end: Optional[datetime] = None
+    mode: Optional[str] = None
+    size: Optional[int] = Field(
+        default=None,
+        description="Count of cadence buckets from start to end (inclusive) when known.",
+    )
+
+
+class VectorMetadata(BaseModel):
+    """Lightweight typed model for metadata.json.
+
+    Only window/counts/entries are modeled explicitly; all other fields are
+    accepted via extra='allow' for forwards-compatibility.
+    """
+
+    model_config = ConfigDict(extra="allow")
+
+    schema_version: int = 1
+    generated_at: Optional[datetime] = None
+    window: Optional[Window] = None
+    meta: Dict[str, Any] | None = None
+    features: List[Dict[str, Any]] = Field(default_factory=list)
+    targets: List[Dict[str, Any]] = Field(default_factory=list)
+    counts: Dict[str, int] = Field(default_factory=dict)
+
+    # Window is the single source of truth; no legacy fallbacks.
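A minimal sketch of loading the generated artifact into the new typed model; the artifacts/metadata.json path is illustrative (it only mirrors the metadata task's default output name):

from pathlib import Path
from datapipeline.config.metadata import VectorMetadata, FEATURE_VECTORS_COUNT_KEY

raw = Path("artifacts/metadata.json").read_text()
meta = VectorMetadata.model_validate_json(raw)
print(meta.window, meta.counts.get(FEATURE_VECTORS_COUNT_KEY))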
datapipeline/config/postprocess.py
CHANGED

@@ -1,14 +1,14 @@
 from typing import Any, List
+
 from pydantic import RootModel, model_validator
 
 
 class PostprocessConfig(RootModel[List[Any]]):
-    """Schema for
+    """Schema for postprocess.yaml (list of transforms)."""
 
     @model_validator(mode="before")
     @classmethod
     def allow_empty(cls, value: Any) -> Any:
-        """Coerce missing or empty mappings into an empty list."""
         if value in (None, {}):
             return []
         return value
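A quick sketch of the coercion this model performs; the transform entry in the last line is purely illustrative:

from datapipeline.config.postprocess import PostprocessConfig

PostprocessConfig.model_validate(None).root  # -> []
PostprocessConfig.model_validate({}).root    # -> []
PostprocessConfig.model_validate([{"scale": {}}]).root  # -> [{'scale': {}}]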
datapipeline/config/project.py
CHANGED
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from datetime import datetime
 from typing import Optional
 from pydantic import BaseModel, Field, ConfigDict
@@ -11,8 +13,7 @@ class ProjectPaths(BaseModel):
     dataset: str
     postprocess: str
     artifacts: str
-
-    run: str | None = None
+    tasks: str | None = None
 
 
 class ProjectGlobals(BaseModel):
datapipeline/config/resolution.py

@@ -0,0 +1,129 @@
+from __future__ import annotations
+
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Optional
+
+from datapipeline.config.tasks import ServeOutputConfig
+from datapipeline.config.workspace import WorkspaceContext
+
+
+def cascade(*values, fallback=None):
+    """Return the first non-None value from a list, or fallback."""
+    for value in values:
+        if value is not None:
+            return value
+    return fallback
+
+
+def _normalize_lower(value: Any) -> Optional[str]:
+    if value is None:
+        return None
+    text = str(value).strip()
+    return text.lower() if text else None
+
+
+def _normalize_upper(value: Any) -> Optional[str]:
+    if value is None:
+        return None
+    if isinstance(value, int):
+        return logging.getLevelName(value).upper()
+    text = str(value).strip()
+    return text.upper() if text else None
+
+
+def _level_value(value: Any) -> Optional[int]:
+    name = _normalize_upper(value)
+    return logging._nameToLevel.get(name) if name else None
+
+
+@dataclass(frozen=True)
+class VisualSettings:
+    visuals: str
+    progress: str
+
+
+def resolve_visuals(
+    *,
+    cli_visuals: str | None,
+    config_visuals: str | None,
+    workspace_visuals: str | None,
+    cli_progress: str | None,
+    config_progress: str | None,
+    workspace_progress: str | None,
+    default_visuals: str = "auto",
+    default_progress: str = "auto",
+) -> VisualSettings:
+    visuals = cascade(
+        _normalize_lower(cli_visuals),
+        _normalize_lower(config_visuals),
+        _normalize_lower(workspace_visuals),
+        default_visuals,
+    ) or default_visuals
+    progress = cascade(
+        _normalize_lower(cli_progress),
+        _normalize_lower(config_progress),
+        _normalize_lower(workspace_progress),
+        default_progress,
+    ) or default_progress
+    return VisualSettings(visuals=visuals, progress=progress)
+
+
+@dataclass(frozen=True)
+class LogLevelDecision:
+    name: str
+    value: int
+
+
+def resolve_log_level(
+    *levels: Any,
+    fallback: str = "INFO",
+) -> LogLevelDecision:
+    name = None
+    for level in levels:
+        normalized = _normalize_upper(level)
+        if normalized:
+            name = normalized
+            break
+    if not name:
+        name = _normalize_upper(fallback) or "INFO"
+    value = logging._nameToLevel.get(name, logging.INFO)
+    return LogLevelDecision(name=name, value=value)
+
+
+def minimum_level(*levels: Any, start: int | None = None) -> int | None:
+    """Return the lowest numeric logging level among the provided values."""
+    current = start
+    for level in levels:
+        value = _level_value(level)
+        if value is None:
+            continue
+        if current is None or value < current:
+            current = value
+    return current
+
+
+def workspace_output_defaults(
+    workspace: WorkspaceContext | None,
+) -> ServeOutputConfig | None:
+    if workspace is None:
+        return None
+    serve_defaults = getattr(workspace.config, "serve", None)
+    if not serve_defaults or not serve_defaults.output:
+        return None
+    od = serve_defaults.output
+    output_dir = None
+    if od.directory:
+        candidate = Path(od.directory)
+        output_dir = (
+            candidate
+            if candidate.is_absolute()
+            else (workspace.root / candidate).resolve()
+        )
+    return ServeOutputConfig(
+        transport=od.transport,
+        format=od.format,
+        payload=od.payload,
+        directory=output_dir,
+    )
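A brief sketch of the precedence these resolvers implement (CLI over config over workspace defaults); the values passed are illustrative:

import logging
from datapipeline.config.resolution import resolve_visuals, resolve_log_level

settings = resolve_visuals(
    cli_visuals="RICH",        # CLI wins over config and workspace
    config_visuals="off",
    workspace_visuals=None,
    cli_progress=None,
    config_progress=None,
    workspace_progress="bars",
)
# settings.visuals == "rich", settings.progress == "bars"

decision = resolve_log_level(None, "debug", fallback="INFO")
# decision.name == "DEBUG", decision.value == logging.DEBUG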
datapipeline/config/tasks.py

@@ -0,0 +1,309 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Annotated, Iterable, List, Literal, Sequence
+
+from pydantic import BaseModel, Field, field_validator, model_validator
+from pydantic.type_adapter import TypeAdapter
+
+from datapipeline.services.project_paths import tasks_dir
+from datapipeline.utils.load import load_yaml
+
+VALID_LOG_LEVELS = ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG")
+VALID_VISUAL_PROVIDERS = ("AUTO", "TQDM", "RICH", "OFF")
+VALID_PROGRESS_STYLES = ("AUTO", "SPINNER", "BARS", "OFF")
+
+Transport = Literal["fs", "stdout"]
+Format = Literal["csv", "json", "json-lines", "print", "pickle"]
+PayloadMode = Literal["sample", "vector"]
+
+
+class TaskBase(BaseModel):
+    version: int = Field(default=1)
+    kind: str
+    name: str | None = Field(default=None, description="Optional task identifier.")
+    enabled: bool = Field(default=True, description="Disable to skip execution.")
+    depends_on: list[str] = Field(default_factory=list)
+    source_path: Path | None = Field(default=None, exclude=True)
+
+    def effective_name(self) -> str:
+        return self.name or (self.source_path.stem if self.source_path else self.kind)
+
+
+class ArtifactTask(TaskBase):
+    output: str = Field(
+        ...,
+        description="Artifact path relative to project.paths.artifacts.",
+    )
+
+
+class ScalerTask(ArtifactTask):
+    kind: Literal["scaler"]
+    output: str = Field(default="scaler.pkl")
+    split_label: str = Field(
+        default="train",
+        description="Split label to use when fitting scaler statistics.",
+    )
+
+
+class SchemaTask(ArtifactTask):
+    kind: Literal["schema"]
+    output: str = Field(default="schema.json")
+    cadence_strategy: Literal["max"] = Field(
+        default="max",
+        description="Strategy for selecting cadence targets (currently only 'max').",
+    )
+
+
+class MetadataTask(ArtifactTask):
+    kind: Literal["metadata"]
+    output: str = Field(default="metadata.json")
+    enabled: bool = Field(
+        default=True,
+        description="Disable to skip generating the vector metadata artifact.",
+    )
+    cadence_strategy: Literal["max"] = Field(
+        default="max",
+        description="Strategy for selecting cadence targets.",
+    )
+    window_mode: Literal["union", "intersection", "strict", "relaxed"] = Field(
+        default="intersection",
+        description="Window mode: union (base union), intersection (base intersection), strict (partition intersection), relaxed (partition union).",
+    )
+
+
+class RuntimeTask(TaskBase):
+    """Base class for runtime-oriented tasks (serve/evaluate/etc.)."""
+
+
+class ServeOutputConfig(BaseModel):
+    transport: Transport = Field(..., description="fs | stdout")
+    format: Format = Field(..., description="csv | json | json-lines | print | pickle")
+    payload: PayloadMode = Field(
+        default="sample",
+        description="sample (key + metadata) or vector payload (features [+targets]).",
+    )
+    directory: Path | None = Field(
+        default=None,
+        description="Directory for fs outputs.",
+    )
+    filename: str | None = Field(
+        default=None,
+        description="Filename stem (format controls extension) for fs outputs.",
+    )
+
+    @field_validator("filename", mode="before")
+    @classmethod
+    def _normalize_filename(cls, value):
+        if value is None:
+            return None
+        text = str(value).strip()
+        if not text:
+            return None
+        if any(sep in text for sep in ("/", "\\")):
+            raise ValueError("filename must not contain path separators")
+        if "." in Path(text).name:
+            raise ValueError("filename must not include an extension")
+        return text
+
+    @model_validator(mode="after")
+    def _validate(self):
+        if self.transport == "stdout":
+            if self.directory is not None:
+                raise ValueError("stdout cannot define a directory")
+            if self.filename is not None:
+                raise ValueError("stdout outputs do not support filenames")
+            if self.format not in {"print", "json-lines", "json"}:
+                raise ValueError(
+                    "stdout output supports 'print', 'json-lines', or 'json' formats"
+                )
+            return self
+
+        if self.format == "print":
+            raise ValueError("fs transport cannot use 'print' format")
+        if self.directory is None:
+            raise ValueError("fs outputs require a directory")
+        return self
+
+    @field_validator("payload", mode="before")
+    @classmethod
+    def _normalize_payload(cls, value):
+        if value is None:
+            return "sample"
+        name = str(value).lower()
+        if name not in {"sample", "vector"}:
+            raise ValueError("payload must be 'sample' or 'vector'")
+        return name
+
+
+class ServeTask(RuntimeTask):
+    kind: Literal["serve"]
+    output: ServeOutputConfig | None = None
+    keep: str | None = Field(
+        default=None,
+        description="Active split label to serve.",
+        min_length=1,
+    )
+    limit: int | None = Field(
+        default=None,
+        description="Default max number of vectors to emit.",
+        ge=1,
+    )
+    stage: int | None = Field(
+        default=None,
+        description="Default pipeline stage preview (0-7).",
+        ge=0,
+        le=7,
+    )
+    throttle_ms: float | None = Field(
+        default=None,
+        description="Milliseconds to sleep between emitted vectors.",
+        ge=0.0,
+    )
+    log_level: str | None = Field(
+        default="INFO",
+        description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL).",
+    )
+    visuals: str | None = Field(
+        default="AUTO",
+        description="Visuals provider: AUTO, TQDM, RICH, or OFF.",
+    )
+    progress: str | None = Field(
+        default="AUTO",
+        description="Progress style: AUTO, SPINNER, BARS, or OFF.",
+    )
+
+    @field_validator("log_level")
+    @classmethod
+    def _validate_log_level(cls, value: str | None) -> str | None:
+        if value is None:
+            return None
+        name = str(value).upper()
+        if name not in VALID_LOG_LEVELS:
+            raise ValueError(
+                f"log_level must be one of {', '.join(VALID_LOG_LEVELS)}, got {value!r}"
+            )
+        return name
+
+    @field_validator("visuals", mode="before")
+    @classmethod
+    def _validate_visuals_run(cls, value):
+        if value is None:
+            return None
+        if isinstance(value, bool):
+            return "OFF" if value is False else "AUTO"
+        name = str(value).upper()
+        if name not in VALID_VISUAL_PROVIDERS:
+            raise ValueError(
+                f"visuals must be one of {', '.join(VALID_VISUAL_PROVIDERS)}, got {value!r}"
+            )
+        return name
+
+    @field_validator("progress", mode="before")
+    @classmethod
+    def _validate_progress_run(cls, value):
+        if value is None:
+            return None
+        name = str(value).upper()
+        if name not in VALID_PROGRESS_STYLES:
+            raise ValueError(
+                f"progress must be one of {', '.join(VALID_PROGRESS_STYLES)}, got {value!r}"
+            )
+        return name
+
+    @field_validator("name")
+    @classmethod
+    def _validate_name(cls, value: str | None) -> str | None:
+        if value is None:
+            return None
+        text = str(value).strip()
+        if not text:
+            raise ValueError("task name cannot be empty")
+        return text
+
+
+TaskModel = Annotated[
+    ScalerTask | SchemaTask | MetadataTask | ServeTask,
+    Field(discriminator="kind"),
+]
+
+TASK_ADAPTER = TypeAdapter(TaskModel)
+
+
+def _task_files(root: Path) -> Sequence[Path]:
+    if not root.exists():
+        return []
+    if root.is_file():
+        return [root]
+    return sorted(
+        p for p in root.rglob("*.y*ml") if p.is_file()
+    )
+
+
+def _load_task_docs(path: Path) -> list[TaskBase]:
+    doc = load_yaml(path)
+    if isinstance(doc, list):
+        entries = doc
+    else:
+        entries = [doc]
+    tasks: list[TaskBase] = []
+    for entry in entries:
+        if not isinstance(entry, dict):
+            raise TypeError(f"{path} must define mapping tasks.")
+        task = TASK_ADAPTER.validate_python(entry)
+        task.source_path = path
+        if task.name is None:
+            task.name = path.stem
+        tasks.append(task)
+    return tasks
+
+
+def load_all_tasks(project_yaml: Path) -> list[TaskBase]:
+    root = tasks_dir(project_yaml)
+    tasks: list[TaskBase] = []
+    for path in _task_files(root):
+        tasks.extend(_load_task_docs(path))
+    return tasks
+
+
+def artifact_tasks(project_yaml: Path) -> list[ArtifactTask]:
+    tasks = [
+        task
+        for task in load_all_tasks(project_yaml)
+        if isinstance(task, ArtifactTask)
+    ]
+    kinds = {task.kind for task in tasks}
+    if "schema" not in kinds:
+        tasks.append(SchemaTask(kind="schema"))
+    if "scaler" not in kinds:
+        tasks.append(ScalerTask(kind="scaler"))
+    if "metadata" not in kinds:
+        tasks.append(MetadataTask(kind="metadata"))
+    return tasks
+
+
+def command_tasks(project_yaml: Path, kind: str | None = None) -> list[TaskBase]:
+    tasks = [
+        task
+        for task in load_all_tasks(project_yaml)
+        if not isinstance(task, ArtifactTask)
+    ]
+    if kind is None:
+        return tasks
+    return [task for task in tasks if task.kind == kind]
+
+
+def serve_tasks(project_yaml: Path) -> list[ServeTask]:
+    """Load all serve tasks regardless of enabled state."""
+    return [
+        task
+        for task in command_tasks(project_yaml, kind="serve")
+        if isinstance(task, ServeTask)
+    ]
+
+
+def default_serve_task(project_yaml: Path) -> ServeTask | None:
+    for task in serve_tasks(project_yaml):
+        if task.enabled:
+            return task
+    return None
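To illustrate how a tasks/serve.*.yaml document flows through the discriminated union above, a hedged sketch: the mapping values are made up, only the keys mirror the ServeTask and ServeOutputConfig fields shown in the hunk.

from datapipeline.config.tasks import TASK_ADAPTER, ServeTask

doc = {
    "kind": "serve",
    "keep": "train",
    "limit": 100,
    "output": {"transport": "stdout", "format": "json-lines"},
}
task = TASK_ADAPTER.validate_python(doc)
assert isinstance(task, ServeTask)
assert task.output.payload == "sample"  # payload defaults to "sample"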