jerry-thomas 0.3.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. datapipeline/analysis/vector/collector.py +120 -17
  2. datapipeline/analysis/vector/matrix.py +33 -8
  3. datapipeline/analysis/vector/report.py +162 -32
  4. datapipeline/build/tasks/__init__.py +11 -0
  5. datapipeline/build/tasks/config.py +74 -0
  6. datapipeline/build/tasks/metadata.py +170 -0
  7. datapipeline/build/tasks/scaler.py +73 -0
  8. datapipeline/build/tasks/schema.py +60 -0
  9. datapipeline/build/tasks/utils.py +169 -0
  10. datapipeline/cli/app.py +304 -127
  11. datapipeline/cli/commands/build.py +240 -16
  12. datapipeline/cli/commands/contract.py +367 -0
  13. datapipeline/cli/commands/domain.py +8 -3
  14. datapipeline/cli/commands/inspect.py +401 -149
  15. datapipeline/cli/commands/list_.py +30 -7
  16. datapipeline/cli/commands/plugin.py +1 -1
  17. datapipeline/cli/commands/run.py +227 -241
  18. datapipeline/cli/commands/run_config.py +101 -0
  19. datapipeline/cli/commands/serve_pipeline.py +156 -0
  20. datapipeline/cli/commands/source.py +44 -8
  21. datapipeline/cli/visuals/__init__.py +4 -2
  22. datapipeline/cli/visuals/common.py +239 -0
  23. datapipeline/cli/visuals/labels.py +15 -15
  24. datapipeline/cli/visuals/runner.py +66 -0
  25. datapipeline/cli/visuals/sections.py +20 -0
  26. datapipeline/cli/visuals/sources.py +132 -119
  27. datapipeline/cli/visuals/sources_basic.py +260 -0
  28. datapipeline/cli/visuals/sources_off.py +76 -0
  29. datapipeline/cli/visuals/sources_rich.py +414 -0
  30. datapipeline/config/catalog.py +37 -3
  31. datapipeline/config/context.py +214 -0
  32. datapipeline/config/dataset/loader.py +21 -4
  33. datapipeline/config/dataset/normalize.py +4 -4
  34. datapipeline/config/metadata.py +43 -0
  35. datapipeline/config/postprocess.py +2 -2
  36. datapipeline/config/project.py +3 -2
  37. datapipeline/config/resolution.py +129 -0
  38. datapipeline/config/tasks.py +309 -0
  39. datapipeline/config/workspace.py +155 -0
  40. datapipeline/domain/__init__.py +12 -0
  41. datapipeline/domain/record.py +11 -0
  42. datapipeline/domain/sample.py +54 -0
  43. datapipeline/integrations/ml/adapter.py +34 -20
  44. datapipeline/integrations/ml/pandas_support.py +0 -2
  45. datapipeline/integrations/ml/rows.py +1 -6
  46. datapipeline/integrations/ml/torch_support.py +1 -3
  47. datapipeline/io/factory.py +112 -0
  48. datapipeline/io/output.py +132 -0
  49. datapipeline/io/protocols.py +21 -0
  50. datapipeline/io/serializers.py +219 -0
  51. datapipeline/io/sinks/__init__.py +23 -0
  52. datapipeline/io/sinks/base.py +2 -0
  53. datapipeline/io/sinks/files.py +79 -0
  54. datapipeline/io/sinks/rich.py +57 -0
  55. datapipeline/io/sinks/stdout.py +18 -0
  56. datapipeline/io/writers/__init__.py +14 -0
  57. datapipeline/io/writers/base.py +28 -0
  58. datapipeline/io/writers/csv_writer.py +25 -0
  59. datapipeline/io/writers/jsonl.py +52 -0
  60. datapipeline/io/writers/pickle_writer.py +30 -0
  61. datapipeline/pipeline/artifacts.py +58 -0
  62. datapipeline/pipeline/context.py +66 -7
  63. datapipeline/pipeline/observability.py +65 -0
  64. datapipeline/pipeline/pipelines.py +65 -13
  65. datapipeline/pipeline/split.py +11 -10
  66. datapipeline/pipeline/stages.py +127 -16
  67. datapipeline/pipeline/utils/keygen.py +20 -7
  68. datapipeline/pipeline/utils/memory_sort.py +22 -10
  69. datapipeline/pipeline/utils/transform_utils.py +22 -0
  70. datapipeline/runtime.py +5 -2
  71. datapipeline/services/artifacts.py +12 -6
  72. datapipeline/services/bootstrap/config.py +25 -0
  73. datapipeline/services/bootstrap/core.py +52 -37
  74. datapipeline/services/constants.py +6 -5
  75. datapipeline/services/factories.py +123 -1
  76. datapipeline/services/project_paths.py +43 -16
  77. datapipeline/services/runs.py +208 -0
  78. datapipeline/services/scaffold/domain.py +3 -2
  79. datapipeline/services/scaffold/filter.py +3 -2
  80. datapipeline/services/scaffold/mappers.py +9 -6
  81. datapipeline/services/scaffold/plugin.py +3 -3
  82. datapipeline/services/scaffold/source.py +93 -56
  83. datapipeline/sources/{composed_loader.py → data_loader.py} +9 -9
  84. datapipeline/sources/decoders.py +83 -18
  85. datapipeline/sources/factory.py +26 -16
  86. datapipeline/sources/models/__init__.py +2 -2
  87. datapipeline/sources/models/generator.py +0 -7
  88. datapipeline/sources/models/loader.py +3 -3
  89. datapipeline/sources/models/parsing_error.py +24 -0
  90. datapipeline/sources/models/source.py +6 -6
  91. datapipeline/sources/synthetic/time/loader.py +14 -2
  92. datapipeline/sources/transports.py +74 -37
  93. datapipeline/templates/plugin_skeleton/README.md +74 -30
  94. datapipeline/templates/plugin_skeleton/example/contracts/time.ticks.hour_sin.yaml +31 -0
  95. datapipeline/templates/plugin_skeleton/example/contracts/time.ticks.linear.yaml +30 -0
  96. datapipeline/templates/plugin_skeleton/example/dataset.yaml +18 -0
  97. datapipeline/templates/plugin_skeleton/example/postprocess.yaml +29 -0
  98. datapipeline/templates/plugin_skeleton/{config/datasets/default → example}/project.yaml +11 -8
  99. datapipeline/templates/plugin_skeleton/example/sources/synthetic.ticks.yaml +12 -0
  100. datapipeline/templates/plugin_skeleton/example/tasks/metadata.yaml +3 -0
  101. datapipeline/templates/plugin_skeleton/example/tasks/scaler.yaml +9 -0
  102. datapipeline/templates/plugin_skeleton/example/tasks/schema.yaml +2 -0
  103. datapipeline/templates/plugin_skeleton/example/tasks/serve.test.yaml +4 -0
  104. datapipeline/templates/plugin_skeleton/example/tasks/serve.train.yaml +28 -0
  105. datapipeline/templates/plugin_skeleton/example/tasks/serve.val.yaml +4 -0
  106. datapipeline/templates/plugin_skeleton/jerry.yaml +28 -0
  107. datapipeline/templates/plugin_skeleton/your-dataset/contracts/time.ticks.hour_sin.yaml +31 -0
  108. datapipeline/templates/plugin_skeleton/your-dataset/contracts/time.ticks.linear.yaml +30 -0
  109. datapipeline/templates/plugin_skeleton/your-dataset/dataset.yaml +18 -0
  110. datapipeline/templates/plugin_skeleton/your-dataset/postprocess.yaml +29 -0
  111. datapipeline/templates/plugin_skeleton/your-dataset/project.yaml +22 -0
  112. datapipeline/templates/plugin_skeleton/your-dataset/sources/synthetic.ticks.yaml +12 -0
  113. datapipeline/templates/plugin_skeleton/your-dataset/tasks/metadata.yaml +3 -0
  114. datapipeline/templates/plugin_skeleton/your-dataset/tasks/scaler.yaml +9 -0
  115. datapipeline/templates/plugin_skeleton/your-dataset/tasks/schema.yaml +2 -0
  116. datapipeline/templates/plugin_skeleton/your-dataset/tasks/serve.test.yaml +4 -0
  117. datapipeline/templates/plugin_skeleton/your-dataset/tasks/serve.train.yaml +28 -0
  118. datapipeline/templates/plugin_skeleton/your-dataset/tasks/serve.val.yaml +4 -0
  119. datapipeline/templates/stubs/dto.py.j2 +2 -0
  120. datapipeline/templates/stubs/mapper.py.j2 +5 -4
  121. datapipeline/templates/stubs/parser.py.j2 +2 -0
  122. datapipeline/templates/stubs/record.py.j2 +2 -0
  123. datapipeline/templates/stubs/source.yaml.j2 +2 -3
  124. datapipeline/transforms/debug/lint.py +26 -41
  125. datapipeline/transforms/feature/scaler.py +89 -13
  126. datapipeline/transforms/record/floor_time.py +4 -4
  127. datapipeline/transforms/sequence.py +2 -35
  128. datapipeline/transforms/stream/dedupe.py +24 -0
  129. datapipeline/transforms/stream/ensure_ticks.py +7 -6
  130. datapipeline/transforms/vector/__init__.py +5 -0
  131. datapipeline/transforms/vector/common.py +98 -0
  132. datapipeline/transforms/vector/drop/__init__.py +4 -0
  133. datapipeline/transforms/vector/drop/horizontal.py +79 -0
  134. datapipeline/transforms/vector/drop/orchestrator.py +59 -0
  135. datapipeline/transforms/vector/drop/vertical.py +182 -0
  136. datapipeline/transforms/vector/ensure_schema.py +184 -0
  137. datapipeline/transforms/vector/fill.py +87 -0
  138. datapipeline/transforms/vector/replace.py +62 -0
  139. datapipeline/utils/load.py +24 -3
  140. datapipeline/utils/rich_compat.py +38 -0
  141. datapipeline/utils/window.py +76 -0
  142. jerry_thomas-1.0.0.dist-info/METADATA +825 -0
  143. jerry_thomas-1.0.0.dist-info/RECORD +199 -0
  144. {jerry_thomas-0.3.0.dist-info → jerry_thomas-1.0.0.dist-info}/entry_points.txt +9 -8
  145. datapipeline/build/tasks.py +0 -186
  146. datapipeline/cli/commands/link.py +0 -128
  147. datapipeline/cli/commands/writers.py +0 -138
  148. datapipeline/config/build.py +0 -64
  149. datapipeline/config/run.py +0 -116
  150. datapipeline/templates/plugin_skeleton/config/contracts/time_hour_sin.synthetic.yaml +0 -24
  151. datapipeline/templates/plugin_skeleton/config/contracts/time_linear.synthetic.yaml +0 -23
  152. datapipeline/templates/plugin_skeleton/config/datasets/default/build.yaml +0 -9
  153. datapipeline/templates/plugin_skeleton/config/datasets/default/dataset.yaml +0 -14
  154. datapipeline/templates/plugin_skeleton/config/datasets/default/postprocess.yaml +0 -13
  155. datapipeline/templates/plugin_skeleton/config/datasets/default/runs/run_test.yaml +0 -10
  156. datapipeline/templates/plugin_skeleton/config/datasets/default/runs/run_train.yaml +0 -10
  157. datapipeline/templates/plugin_skeleton/config/datasets/default/runs/run_val.yaml +0 -10
  158. datapipeline/templates/plugin_skeleton/config/sources/time_ticks.yaml +0 -11
  159. datapipeline/transforms/vector.py +0 -210
  160. jerry_thomas-0.3.0.dist-info/METADATA +0 -502
  161. jerry_thomas-0.3.0.dist-info/RECORD +0 -139
  162. {jerry_thomas-0.3.0.dist-info → jerry_thomas-1.0.0.dist-info}/WHEEL +0 -0
  163. {jerry_thomas-0.3.0.dist-info → jerry_thomas-1.0.0.dist-info}/licenses/LICENSE +0 -0
  164. {jerry_thomas-0.3.0.dist-info → jerry_thomas-1.0.0.dist-info}/top_level.txt +0 -0
datapipeline/config/dataset/loader.py

@@ -1,12 +1,30 @@
  from typing import Literal
- from datapipeline.config.dataset.dataset import RecordDatasetConfig, FeatureDatasetConfig
- from datapipeline.services.bootstrap import _load_by_key
+
+ from datapipeline.config.dataset.dataset import (
+     RecordDatasetConfig,
+     FeatureDatasetConfig,
+ )
+ from datapipeline.services.bootstrap import _load_by_key, _globals, _interpolate
 
  Stage = Literal["records", "features", "vectors"]
 
 
+ def _normalize_dataset_doc(doc):
+     if not isinstance(doc, dict):
+         return doc
+     normalized = dict(doc)
+     for key in ("features", "targets"):
+         if normalized.get(key) is None:
+             normalized[key] = []
+     return normalized
+
+
  def load_dataset(project_yaml, stage: Stage):
-     ds_doc = _load_by_key(project_yaml, "dataset")
+     raw = _load_by_key(project_yaml, "dataset")
+     vars_ = _globals(project_yaml)
+     if vars_:
+         raw = _interpolate(raw, vars_)
+     ds_doc = _normalize_dataset_doc(raw)
 
      if stage == "records":
          return RecordDatasetConfig.model_validate(ds_doc)
@@ -16,4 +34,3 @@ def load_dataset(project_yaml, stage: Stage):
          return FeatureDatasetConfig.model_validate(ds_doc)
      else:
          raise ValueError(f"Unknown stage: {stage}")
-
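In short, load_dataset now interpolates project globals into the raw dataset document and coerces missing features/targets sections to empty lists before validation. A minimal sketch of the normalization step (input values invented for illustration):

doc = {"name": "example", "features": None, "targets": None}
_normalize_dataset_doc(doc)
# -> {"name": "example", "features": [], "targets": []}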
datapipeline/transforms/record/floor_time.py

@@ -2,15 +2,15 @@ from datetime import datetime
  import re
 
 
- def floor_time_to_resolution(ts: datetime, resolution: str) -> datetime:
-     """Floor a timestamp to the nearest resolution bucket.
+ def floor_time_to_bucket(ts: datetime, bucket: str) -> datetime:
+     """Floor a timestamp to the nearest bucket boundary.
 
      Supports patterns like '10m', '10min', '1h', '2h'.
      Minutes may be specified as 'm' or 'min'.
      """
-     m = re.fullmatch(r"^(\d+)(m|min|h)$", resolution)
+     m = re.fullmatch(r"^(\d+)(m|min|h)$", bucket)
      if not m:
-         raise ValueError(f"Unsupported granularity: {resolution}")
+         raise ValueError(f"Unsupported cadence: {bucket}")
      n = int(m.group(1))
      unit = m.group(2)
      if n <= 0:
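Only the signature and the error wording change here; the flooring behaviour described in the docstring stays the same. Assuming the unchanged remainder of the function floors as documented, expected usage would be:

from datetime import datetime

floor_time_to_bucket(datetime(2024, 1, 1, 12, 37), "10m")  # expected: 2024-01-01 12:30
floor_time_to_bucket(datetime(2024, 1, 1, 12, 37), "1h")   # expected: 2024-01-01 12:00
floor_time_to_bucket(datetime(2024, 1, 1, 12, 37), "2d")   # raises ValueError: Unsupported cadence: 2d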
datapipeline/config/metadata.py

@@ -0,0 +1,43 @@
+ from __future__ import annotations
+
+ from datetime import datetime
+ from typing import Any, Dict, List, Optional
+
+ from pydantic import BaseModel, ConfigDict, Field
+
+
+ # Shared keys for vector metadata counts
+ FEATURE_VECTORS_COUNT_KEY = "feature_vectors"
+ TARGET_VECTORS_COUNT_KEY = "target_vectors"
+
+
+ class Window(BaseModel):
+     """Typed representation of dataset window bounds."""
+
+     start: Optional[datetime] = None
+     end: Optional[datetime] = None
+     mode: Optional[str] = None
+     size: Optional[int] = Field(
+         default=None,
+         description="Count of cadence buckets from start to end (inclusive) when known.",
+     )
+
+
+ class VectorMetadata(BaseModel):
+     """Lightweight typed model for metadata.json.
+
+     Only window/counts/entries are modeled explicitly; all other fields are
+     accepted via extra='allow' for forwards-compatibility.
+     """
+
+     model_config = ConfigDict(extra="allow")
+
+     schema_version: int = 1
+     generated_at: Optional[datetime] = None
+     window: Optional[Window] = None
+     meta: Dict[str, Any] | None = None
+     features: List[Dict[str, Any]] = Field(default_factory=list)
+     targets: List[Dict[str, Any]] = Field(default_factory=list)
+     counts: Dict[str, int] = Field(default_factory=dict)
+
+     # Window is the single source of truth; no legacy fallbacks.
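The model tolerates unknown keys (extra='allow'), so a metadata.json written by a newer build still parses. A small, invented payload to illustrate:

meta = VectorMetadata.model_validate({
    "schema_version": 1,
    "window": {"start": "2024-01-01T00:00:00", "end": "2024-01-02T00:00:00", "mode": "intersection"},
    "counts": {FEATURE_VECTORS_COUNT_KEY: 24, TARGET_VECTORS_COUNT_KEY: 24},
    "pipeline_version": "1.0.0",  # unknown key, kept via extra="allow"
})
meta.window.start  # datetime(2024, 1, 1, 0, 0)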
datapipeline/config/postprocess.py

@@ -1,14 +1,14 @@
  from typing import Any, List
+
  from pydantic import RootModel, model_validator
 
 
  class PostprocessConfig(RootModel[List[Any]]):
-     """Schema for optional postprocess.yaml."""
+     """Schema for postprocess.yaml (list of transforms)."""
 
      @model_validator(mode="before")
      @classmethod
      def allow_empty(cls, value: Any) -> Any:
-         """Coerce missing or empty mappings into an empty list."""
          if value in (None, {}):
              return []
          return value
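The allow_empty hook means an empty or missing postprocess.yaml still validates to an empty transform list, for example:

PostprocessConfig.model_validate(None).root  # -> []
PostprocessConfig.model_validate({}).root    # -> []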
datapipeline/config/project.py

@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  from datetime import datetime
  from typing import Optional
  from pydantic import BaseModel, Field, ConfigDict
@@ -11,8 +13,7 @@ class ProjectPaths(BaseModel):
      dataset: str
      postprocess: str
      artifacts: str
-     build: str | None = None
-     run: str | None = None
+     tasks: str | None = None
 
 
  class ProjectGlobals(BaseModel):
datapipeline/config/resolution.py

@@ -0,0 +1,129 @@
+ from __future__ import annotations
+
+ import logging
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, Optional
+
+ from datapipeline.config.tasks import ServeOutputConfig
+ from datapipeline.config.workspace import WorkspaceContext
+
+
+ def cascade(*values, fallback=None):
+     """Return the first non-None value from a list, or fallback."""
+     for value in values:
+         if value is not None:
+             return value
+     return fallback
+
+
+ def _normalize_lower(value: Any) -> Optional[str]:
+     if value is None:
+         return None
+     text = str(value).strip()
+     return text.lower() if text else None
+
+
+ def _normalize_upper(value: Any) -> Optional[str]:
+     if value is None:
+         return None
+     if isinstance(value, int):
+         return logging.getLevelName(value).upper()
+     text = str(value).strip()
+     return text.upper() if text else None
+
+
+ def _level_value(value: Any) -> Optional[int]:
+     name = _normalize_upper(value)
+     return logging._nameToLevel.get(name) if name else None
+
+
+ @dataclass(frozen=True)
+ class VisualSettings:
+     visuals: str
+     progress: str
+
+
+ def resolve_visuals(
+     *,
+     cli_visuals: str | None,
+     config_visuals: str | None,
+     workspace_visuals: str | None,
+     cli_progress: str | None,
+     config_progress: str | None,
+     workspace_progress: str | None,
+     default_visuals: str = "auto",
+     default_progress: str = "auto",
+ ) -> VisualSettings:
+     visuals = cascade(
+         _normalize_lower(cli_visuals),
+         _normalize_lower(config_visuals),
+         _normalize_lower(workspace_visuals),
+         default_visuals,
+     ) or default_visuals
+     progress = cascade(
+         _normalize_lower(cli_progress),
+         _normalize_lower(config_progress),
+         _normalize_lower(workspace_progress),
+         default_progress,
+     ) or default_progress
+     return VisualSettings(visuals=visuals, progress=progress)
+
+
+ @dataclass(frozen=True)
+ class LogLevelDecision:
+     name: str
+     value: int
+
+
+ def resolve_log_level(
+     *levels: Any,
+     fallback: str = "INFO",
+ ) -> LogLevelDecision:
+     name = None
+     for level in levels:
+         normalized = _normalize_upper(level)
+         if normalized:
+             name = normalized
+             break
+     if not name:
+         name = _normalize_upper(fallback) or "INFO"
+     value = logging._nameToLevel.get(name, logging.INFO)
+     return LogLevelDecision(name=name, value=value)
+
+
+ def minimum_level(*levels: Any, start: int | None = None) -> int | None:
+     """Return the lowest numeric logging level among the provided values."""
+     current = start
+     for level in levels:
+         value = _level_value(level)
+         if value is None:
+             continue
+         if current is None or value < current:
+             current = value
+     return current
+
+
+ def workspace_output_defaults(
+     workspace: WorkspaceContext | None,
+ ) -> ServeOutputConfig | None:
+     if workspace is None:
+         return None
+     serve_defaults = getattr(workspace.config, "serve", None)
+     if not serve_defaults or not serve_defaults.output:
+         return None
+     od = serve_defaults.output
+     output_dir = None
+     if od.directory:
+         candidate = Path(od.directory)
+         output_dir = (
+             candidate
+             if candidate.is_absolute()
+             else (workspace.root / candidate).resolve()
+         )
+     return ServeOutputConfig(
+         transport=od.transport,
+         format=od.format,
+         payload=od.payload,
+         directory=output_dir,
+     )
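These helpers implement the CLI > config > workspace precedence for display and logging settings. A short usage sketch (argument values invented):

resolve_visuals(
    cli_visuals=None, config_visuals="RICH", workspace_visuals=None,
    cli_progress=None, config_progress=None, workspace_progress=None,
)
# -> VisualSettings(visuals="rich", progress="auto")

resolve_log_level(None, "debug", "WARNING")
# -> LogLevelDecision(name="DEBUG", value=10): the first non-empty level wins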
datapipeline/config/tasks.py

@@ -0,0 +1,309 @@
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import Annotated, Iterable, List, Literal, Sequence
+
+ from pydantic import BaseModel, Field, field_validator, model_validator
+ from pydantic.type_adapter import TypeAdapter
+
+ from datapipeline.services.project_paths import tasks_dir
+ from datapipeline.utils.load import load_yaml
+
+ VALID_LOG_LEVELS = ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG")
+ VALID_VISUAL_PROVIDERS = ("AUTO", "TQDM", "RICH", "OFF")
+ VALID_PROGRESS_STYLES = ("AUTO", "SPINNER", "BARS", "OFF")
+
+ Transport = Literal["fs", "stdout"]
+ Format = Literal["csv", "json", "json-lines", "print", "pickle"]
+ PayloadMode = Literal["sample", "vector"]
+
+
+ class TaskBase(BaseModel):
+     version: int = Field(default=1)
+     kind: str
+     name: str | None = Field(default=None, description="Optional task identifier.")
+     enabled: bool = Field(default=True, description="Disable to skip execution.")
+     depends_on: list[str] = Field(default_factory=list)
+     source_path: Path | None = Field(default=None, exclude=True)
+
+     def effective_name(self) -> str:
+         return self.name or (self.source_path.stem if self.source_path else self.kind)
+
+
+ class ArtifactTask(TaskBase):
+     output: str = Field(
+         ...,
+         description="Artifact path relative to project.paths.artifacts.",
+     )
+
+
+ class ScalerTask(ArtifactTask):
+     kind: Literal["scaler"]
+     output: str = Field(default="scaler.pkl")
+     split_label: str = Field(
+         default="train",
+         description="Split label to use when fitting scaler statistics.",
+     )
+
+
+ class SchemaTask(ArtifactTask):
+     kind: Literal["schema"]
+     output: str = Field(default="schema.json")
+     cadence_strategy: Literal["max"] = Field(
+         default="max",
+         description="Strategy for selecting cadence targets (currently only 'max').",
+     )
+
+
+ class MetadataTask(ArtifactTask):
+     kind: Literal["metadata"]
+     output: str = Field(default="metadata.json")
+     enabled: bool = Field(
+         default=True,
+         description="Disable to skip generating the vector metadata artifact.",
+     )
+     cadence_strategy: Literal["max"] = Field(
+         default="max",
+         description="Strategy for selecting cadence targets.",
+     )
+     window_mode: Literal["union", "intersection", "strict", "relaxed"] = Field(
+         default="intersection",
+         description="Window mode: union (base union), intersection (base intersection), strict (partition intersection), relaxed (partition union).",
+     )
+
+
+ class RuntimeTask(TaskBase):
+     """Base class for runtime-oriented tasks (serve/evaluate/etc.)."""
+
+
+ class ServeOutputConfig(BaseModel):
+     transport: Transport = Field(..., description="fs | stdout")
+     format: Format = Field(..., description="csv | json | json-lines | print | pickle")
+     payload: PayloadMode = Field(
+         default="sample",
+         description="sample (key + metadata) or vector payload (features [+targets]).",
+     )
+     directory: Path | None = Field(
+         default=None,
+         description="Directory for fs outputs.",
+     )
+     filename: str | None = Field(
+         default=None,
+         description="Filename stem (format controls extension) for fs outputs.",
+     )
+
+     @field_validator("filename", mode="before")
+     @classmethod
+     def _normalize_filename(cls, value):
+         if value is None:
+             return None
+         text = str(value).strip()
+         if not text:
+             return None
+         if any(sep in text for sep in ("/", "\\")):
+             raise ValueError("filename must not contain path separators")
+         if "." in Path(text).name:
+             raise ValueError("filename must not include an extension")
+         return text
+
+     @model_validator(mode="after")
+     def _validate(self):
+         if self.transport == "stdout":
+             if self.directory is not None:
+                 raise ValueError("stdout cannot define a directory")
+             if self.filename is not None:
+                 raise ValueError("stdout outputs do not support filenames")
+             if self.format not in {"print", "json-lines", "json"}:
+                 raise ValueError(
+                     "stdout output supports 'print', 'json-lines', or 'json' formats"
+                 )
+             return self
+
+         if self.format == "print":
+             raise ValueError("fs transport cannot use 'print' format")
+         if self.directory is None:
+             raise ValueError("fs outputs require a directory")
+         return self
+
+     @field_validator("payload", mode="before")
+     @classmethod
+     def _normalize_payload(cls, value):
+         if value is None:
+             return "sample"
+         name = str(value).lower()
+         if name not in {"sample", "vector"}:
+             raise ValueError("payload must be 'sample' or 'vector'")
+         return name
+
+
+ class ServeTask(RuntimeTask):
+     kind: Literal["serve"]
+     output: ServeOutputConfig | None = None
+     keep: str | None = Field(
+         default=None,
+         description="Active split label to serve.",
+         min_length=1,
+     )
+     limit: int | None = Field(
+         default=None,
+         description="Default max number of vectors to emit.",
+         ge=1,
+     )
+     stage: int | None = Field(
+         default=None,
+         description="Default pipeline stage preview (0-7).",
+         ge=0,
+         le=7,
+     )
+     throttle_ms: float | None = Field(
+         default=None,
+         description="Milliseconds to sleep between emitted vectors.",
+         ge=0.0,
+     )
+     log_level: str | None = Field(
+         default="INFO",
+         description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL).",
+     )
+     visuals: str | None = Field(
+         default="AUTO",
+         description="Visuals provider: AUTO, TQDM, RICH, or OFF.",
+     )
+     progress: str | None = Field(
+         default="AUTO",
+         description="Progress style: AUTO, SPINNER, BARS, or OFF.",
+     )
+
+     @field_validator("log_level")
+     @classmethod
+     def _validate_log_level(cls, value: str | None) -> str | None:
+         if value is None:
+             return None
+         name = str(value).upper()
+         if name not in VALID_LOG_LEVELS:
+             raise ValueError(
+                 f"log_level must be one of {', '.join(VALID_LOG_LEVELS)}, got {value!r}"
+             )
+         return name
+
+     @field_validator("visuals", mode="before")
+     @classmethod
+     def _validate_visuals_run(cls, value):
+         if value is None:
+             return None
+         if isinstance(value, bool):
+             return "OFF" if value is False else "AUTO"
+         name = str(value).upper()
+         if name not in VALID_VISUAL_PROVIDERS:
+             raise ValueError(
+                 f"visuals must be one of {', '.join(VALID_VISUAL_PROVIDERS)}, got {value!r}"
+             )
+         return name
+
+     @field_validator("progress", mode="before")
+     @classmethod
+     def _validate_progress_run(cls, value):
+         if value is None:
+             return None
+         name = str(value).upper()
+         if name not in VALID_PROGRESS_STYLES:
+             raise ValueError(
+                 f"progress must be one of {', '.join(VALID_PROGRESS_STYLES)}, got {value!r}"
+             )
+         return name
+
+     @field_validator("name")
+     @classmethod
+     def _validate_name(cls, value: str | None) -> str | None:
+         if value is None:
+             return None
+         text = str(value).strip()
+         if not text:
+             raise ValueError("task name cannot be empty")
+         return text
+
+
+ TaskModel = Annotated[
+     ScalerTask | SchemaTask | MetadataTask | ServeTask,
+     Field(discriminator="kind"),
+ ]
+
+ TASK_ADAPTER = TypeAdapter(TaskModel)
+
+
+ def _task_files(root: Path) -> Sequence[Path]:
+     if not root.exists():
+         return []
+     if root.is_file():
+         return [root]
+     return sorted(
+         p for p in root.rglob("*.y*ml") if p.is_file()
+     )
+
+
+ def _load_task_docs(path: Path) -> list[TaskBase]:
+     doc = load_yaml(path)
+     if isinstance(doc, list):
+         entries = doc
+     else:
+         entries = [doc]
+     tasks: list[TaskBase] = []
+     for entry in entries:
+         if not isinstance(entry, dict):
+             raise TypeError(f"{path} must define mapping tasks.")
+         task = TASK_ADAPTER.validate_python(entry)
+         task.source_path = path
+         if task.name is None:
+             task.name = path.stem
+         tasks.append(task)
+     return tasks
+
+
+ def load_all_tasks(project_yaml: Path) -> list[TaskBase]:
+     root = tasks_dir(project_yaml)
+     tasks: list[TaskBase] = []
+     for path in _task_files(root):
+         tasks.extend(_load_task_docs(path))
+     return tasks
+
+
+ def artifact_tasks(project_yaml: Path) -> list[ArtifactTask]:
+     tasks = [
+         task
+         for task in load_all_tasks(project_yaml)
+         if isinstance(task, ArtifactTask)
+     ]
+     kinds = {task.kind for task in tasks}
+     if "schema" not in kinds:
+         tasks.append(SchemaTask(kind="schema"))
+     if "scaler" not in kinds:
+         tasks.append(ScalerTask(kind="scaler"))
+     if "metadata" not in kinds:
+         tasks.append(MetadataTask(kind="metadata"))
+     return tasks
+
+
+ def command_tasks(project_yaml: Path, kind: str | None = None) -> list[TaskBase]:
+     tasks = [
+         task
+         for task in load_all_tasks(project_yaml)
+         if not isinstance(task, ArtifactTask)
+     ]
+     if kind is None:
+         return tasks
+     return [task for task in tasks if task.kind == kind]
+
+
+ def serve_tasks(project_yaml: Path) -> list[ServeTask]:
+     """Load all serve tasks regardless of enabled state."""
+     return [
+         task
+         for task in command_tasks(project_yaml, kind="serve")
+         if isinstance(task, ServeTask)
+     ]
+
+
+ def default_serve_task(project_yaml: Path) -> ServeTask | None:
+     for task in serve_tasks(project_yaml):
+         if task.enabled:
+             return task
+     return None
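Task documents under the project's tasks/ directory are parsed through TASK_ADAPTER, a discriminated union keyed on kind. A hypothetical serve task document, shown here as the Python mapping it loads into rather than the YAML on disk:

task = TASK_ADAPTER.validate_python({
    "kind": "serve",
    "keep": "train",
    "limit": 100,
    "output": {"transport": "stdout", "format": "json-lines"},
})
isinstance(task, ServeTask)  # True; stdout outputs need no directory or filename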