jerry-thomas 0.3.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (164)
  1. datapipeline/analysis/vector/collector.py +120 -17
  2. datapipeline/analysis/vector/matrix.py +33 -8
  3. datapipeline/analysis/vector/report.py +162 -32
  4. datapipeline/build/tasks/__init__.py +11 -0
  5. datapipeline/build/tasks/config.py +74 -0
  6. datapipeline/build/tasks/metadata.py +170 -0
  7. datapipeline/build/tasks/scaler.py +73 -0
  8. datapipeline/build/tasks/schema.py +60 -0
  9. datapipeline/build/tasks/utils.py +169 -0
  10. datapipeline/cli/app.py +304 -127
  11. datapipeline/cli/commands/build.py +240 -16
  12. datapipeline/cli/commands/contract.py +367 -0
  13. datapipeline/cli/commands/domain.py +8 -3
  14. datapipeline/cli/commands/inspect.py +401 -149
  15. datapipeline/cli/commands/list_.py +30 -7
  16. datapipeline/cli/commands/plugin.py +1 -1
  17. datapipeline/cli/commands/run.py +227 -241
  18. datapipeline/cli/commands/run_config.py +101 -0
  19. datapipeline/cli/commands/serve_pipeline.py +156 -0
  20. datapipeline/cli/commands/source.py +44 -8
  21. datapipeline/cli/visuals/__init__.py +4 -2
  22. datapipeline/cli/visuals/common.py +239 -0
  23. datapipeline/cli/visuals/labels.py +15 -15
  24. datapipeline/cli/visuals/runner.py +66 -0
  25. datapipeline/cli/visuals/sections.py +20 -0
  26. datapipeline/cli/visuals/sources.py +132 -119
  27. datapipeline/cli/visuals/sources_basic.py +260 -0
  28. datapipeline/cli/visuals/sources_off.py +76 -0
  29. datapipeline/cli/visuals/sources_rich.py +414 -0
  30. datapipeline/config/catalog.py +37 -3
  31. datapipeline/config/context.py +214 -0
  32. datapipeline/config/dataset/loader.py +21 -4
  33. datapipeline/config/dataset/normalize.py +4 -4
  34. datapipeline/config/metadata.py +43 -0
  35. datapipeline/config/postprocess.py +2 -2
  36. datapipeline/config/project.py +3 -2
  37. datapipeline/config/resolution.py +129 -0
  38. datapipeline/config/tasks.py +309 -0
  39. datapipeline/config/workspace.py +155 -0
  40. datapipeline/domain/__init__.py +12 -0
  41. datapipeline/domain/record.py +11 -0
  42. datapipeline/domain/sample.py +54 -0
  43. datapipeline/integrations/ml/adapter.py +34 -20
  44. datapipeline/integrations/ml/pandas_support.py +0 -2
  45. datapipeline/integrations/ml/rows.py +1 -6
  46. datapipeline/integrations/ml/torch_support.py +1 -3
  47. datapipeline/io/factory.py +112 -0
  48. datapipeline/io/output.py +132 -0
  49. datapipeline/io/protocols.py +21 -0
  50. datapipeline/io/serializers.py +219 -0
  51. datapipeline/io/sinks/__init__.py +23 -0
  52. datapipeline/io/sinks/base.py +2 -0
  53. datapipeline/io/sinks/files.py +79 -0
  54. datapipeline/io/sinks/rich.py +57 -0
  55. datapipeline/io/sinks/stdout.py +18 -0
  56. datapipeline/io/writers/__init__.py +14 -0
  57. datapipeline/io/writers/base.py +28 -0
  58. datapipeline/io/writers/csv_writer.py +25 -0
  59. datapipeline/io/writers/jsonl.py +52 -0
  60. datapipeline/io/writers/pickle_writer.py +30 -0
  61. datapipeline/pipeline/artifacts.py +58 -0
  62. datapipeline/pipeline/context.py +66 -7
  63. datapipeline/pipeline/observability.py +65 -0
  64. datapipeline/pipeline/pipelines.py +65 -13
  65. datapipeline/pipeline/split.py +11 -10
  66. datapipeline/pipeline/stages.py +127 -16
  67. datapipeline/pipeline/utils/keygen.py +20 -7
  68. datapipeline/pipeline/utils/memory_sort.py +22 -10
  69. datapipeline/pipeline/utils/transform_utils.py +22 -0
  70. datapipeline/runtime.py +5 -2
  71. datapipeline/services/artifacts.py +12 -6
  72. datapipeline/services/bootstrap/config.py +25 -0
  73. datapipeline/services/bootstrap/core.py +52 -37
  74. datapipeline/services/constants.py +6 -5
  75. datapipeline/services/factories.py +123 -1
  76. datapipeline/services/project_paths.py +43 -16
  77. datapipeline/services/runs.py +208 -0
  78. datapipeline/services/scaffold/domain.py +3 -2
  79. datapipeline/services/scaffold/filter.py +3 -2
  80. datapipeline/services/scaffold/mappers.py +9 -6
  81. datapipeline/services/scaffold/plugin.py +3 -3
  82. datapipeline/services/scaffold/source.py +93 -56
  83. datapipeline/sources/{composed_loader.py → data_loader.py} +9 -9
  84. datapipeline/sources/decoders.py +83 -18
  85. datapipeline/sources/factory.py +26 -16
  86. datapipeline/sources/models/__init__.py +2 -2
  87. datapipeline/sources/models/generator.py +0 -7
  88. datapipeline/sources/models/loader.py +3 -3
  89. datapipeline/sources/models/parsing_error.py +24 -0
  90. datapipeline/sources/models/source.py +6 -6
  91. datapipeline/sources/synthetic/time/loader.py +14 -2
  92. datapipeline/sources/transports.py +74 -37
  93. datapipeline/templates/plugin_skeleton/README.md +74 -30
  94. datapipeline/templates/plugin_skeleton/example/contracts/time.ticks.hour_sin.yaml +31 -0
  95. datapipeline/templates/plugin_skeleton/example/contracts/time.ticks.linear.yaml +30 -0
  96. datapipeline/templates/plugin_skeleton/example/dataset.yaml +18 -0
  97. datapipeline/templates/plugin_skeleton/example/postprocess.yaml +29 -0
  98. datapipeline/templates/plugin_skeleton/{config/datasets/default → example}/project.yaml +11 -8
  99. datapipeline/templates/plugin_skeleton/example/sources/synthetic.ticks.yaml +12 -0
  100. datapipeline/templates/plugin_skeleton/example/tasks/metadata.yaml +3 -0
  101. datapipeline/templates/plugin_skeleton/example/tasks/scaler.yaml +9 -0
  102. datapipeline/templates/plugin_skeleton/example/tasks/schema.yaml +2 -0
  103. datapipeline/templates/plugin_skeleton/example/tasks/serve.test.yaml +4 -0
  104. datapipeline/templates/plugin_skeleton/example/tasks/serve.train.yaml +28 -0
  105. datapipeline/templates/plugin_skeleton/example/tasks/serve.val.yaml +4 -0
  106. datapipeline/templates/plugin_skeleton/jerry.yaml +28 -0
  107. datapipeline/templates/plugin_skeleton/your-dataset/contracts/time.ticks.hour_sin.yaml +31 -0
  108. datapipeline/templates/plugin_skeleton/your-dataset/contracts/time.ticks.linear.yaml +30 -0
  109. datapipeline/templates/plugin_skeleton/your-dataset/dataset.yaml +18 -0
  110. datapipeline/templates/plugin_skeleton/your-dataset/postprocess.yaml +29 -0
  111. datapipeline/templates/plugin_skeleton/your-dataset/project.yaml +22 -0
  112. datapipeline/templates/plugin_skeleton/your-dataset/sources/synthetic.ticks.yaml +12 -0
  113. datapipeline/templates/plugin_skeleton/your-dataset/tasks/metadata.yaml +3 -0
  114. datapipeline/templates/plugin_skeleton/your-dataset/tasks/scaler.yaml +9 -0
  115. datapipeline/templates/plugin_skeleton/your-dataset/tasks/schema.yaml +2 -0
  116. datapipeline/templates/plugin_skeleton/your-dataset/tasks/serve.test.yaml +4 -0
  117. datapipeline/templates/plugin_skeleton/your-dataset/tasks/serve.train.yaml +28 -0
  118. datapipeline/templates/plugin_skeleton/your-dataset/tasks/serve.val.yaml +4 -0
  119. datapipeline/templates/stubs/dto.py.j2 +2 -0
  120. datapipeline/templates/stubs/mapper.py.j2 +5 -4
  121. datapipeline/templates/stubs/parser.py.j2 +2 -0
  122. datapipeline/templates/stubs/record.py.j2 +2 -0
  123. datapipeline/templates/stubs/source.yaml.j2 +2 -3
  124. datapipeline/transforms/debug/lint.py +26 -41
  125. datapipeline/transforms/feature/scaler.py +89 -13
  126. datapipeline/transforms/record/floor_time.py +4 -4
  127. datapipeline/transforms/sequence.py +2 -35
  128. datapipeline/transforms/stream/dedupe.py +24 -0
  129. datapipeline/transforms/stream/ensure_ticks.py +7 -6
  130. datapipeline/transforms/vector/__init__.py +5 -0
  131. datapipeline/transforms/vector/common.py +98 -0
  132. datapipeline/transforms/vector/drop/__init__.py +4 -0
  133. datapipeline/transforms/vector/drop/horizontal.py +79 -0
  134. datapipeline/transforms/vector/drop/orchestrator.py +59 -0
  135. datapipeline/transforms/vector/drop/vertical.py +182 -0
  136. datapipeline/transforms/vector/ensure_schema.py +184 -0
  137. datapipeline/transforms/vector/fill.py +87 -0
  138. datapipeline/transforms/vector/replace.py +62 -0
  139. datapipeline/utils/load.py +24 -3
  140. datapipeline/utils/rich_compat.py +38 -0
  141. datapipeline/utils/window.py +76 -0
  142. jerry_thomas-1.0.0.dist-info/METADATA +825 -0
  143. jerry_thomas-1.0.0.dist-info/RECORD +199 -0
  144. {jerry_thomas-0.3.0.dist-info → jerry_thomas-1.0.0.dist-info}/entry_points.txt +9 -8
  145. datapipeline/build/tasks.py +0 -186
  146. datapipeline/cli/commands/link.py +0 -128
  147. datapipeline/cli/commands/writers.py +0 -138
  148. datapipeline/config/build.py +0 -64
  149. datapipeline/config/run.py +0 -116
  150. datapipeline/templates/plugin_skeleton/config/contracts/time_hour_sin.synthetic.yaml +0 -24
  151. datapipeline/templates/plugin_skeleton/config/contracts/time_linear.synthetic.yaml +0 -23
  152. datapipeline/templates/plugin_skeleton/config/datasets/default/build.yaml +0 -9
  153. datapipeline/templates/plugin_skeleton/config/datasets/default/dataset.yaml +0 -14
  154. datapipeline/templates/plugin_skeleton/config/datasets/default/postprocess.yaml +0 -13
  155. datapipeline/templates/plugin_skeleton/config/datasets/default/runs/run_test.yaml +0 -10
  156. datapipeline/templates/plugin_skeleton/config/datasets/default/runs/run_train.yaml +0 -10
  157. datapipeline/templates/plugin_skeleton/config/datasets/default/runs/run_val.yaml +0 -10
  158. datapipeline/templates/plugin_skeleton/config/sources/time_ticks.yaml +0 -11
  159. datapipeline/transforms/vector.py +0 -210
  160. jerry_thomas-0.3.0.dist-info/METADATA +0 -502
  161. jerry_thomas-0.3.0.dist-info/RECORD +0 -139
  162. {jerry_thomas-0.3.0.dist-info → jerry_thomas-1.0.0.dist-info}/WHEEL +0 -0
  163. {jerry_thomas-0.3.0.dist-info → jerry_thomas-1.0.0.dist-info}/licenses/LICENSE +0 -0
  164. {jerry_thomas-0.3.0.dist-info → jerry_thomas-1.0.0.dist-info}/top_level.txt +0 -0
datapipeline/io/serializers.py
@@ -0,0 +1,219 @@
+ from __future__ import annotations
+
+ import json
+ from dataclasses import asdict, is_dataclass
+ from typing import Any, Dict, Type
+
+ from datapipeline.domain.sample import Sample
+
+
+ class BaseSerializer:
+     payload_mode = "sample"
+
+     def serialize_payload(self, sample: Sample) -> Any:  # pragma: no cover - abstract
+         raise NotImplementedError
+
+
+ class BaseJsonLineSerializer(BaseSerializer):
+     def __call__(self, sample: Sample) -> str:
+         data = self.serialize_payload(sample)
+         return json.dumps(data, ensure_ascii=False, default=str) + "\n"
+
+
+ class SampleJsonLineSerializer(BaseJsonLineSerializer):
+     payload_mode = "sample"
+
+     def serialize_payload(self, sample: Sample) -> Any:
+         return sample.as_full_payload()
+
+
+ class VectorJsonLineSerializer(BaseJsonLineSerializer):
+     payload_mode = "vector"
+
+     def serialize_payload(self, sample: Sample) -> Any:
+         return sample.as_vector_payload()
+
+
+ class BasePrintSerializer(BaseSerializer):
+     def __call__(self, sample: Sample) -> str:
+         value = self.serialize_payload(sample)
+         return f"{value}\n"
+
+
+ class SamplePrintSerializer(BasePrintSerializer):
+     payload_mode = "sample"
+
+     def serialize_payload(self, sample: Sample) -> Any:
+         return sample.as_full_payload()
+
+
+ class VectorPrintSerializer(BasePrintSerializer):
+     payload_mode = "vector"
+
+     def serialize_payload(self, sample: Sample) -> Any:
+         return sample.as_vector_payload()
+
+
+ class BaseCsvRowSerializer(BaseSerializer):
+     def __call__(self, sample: Sample) -> list[str | Any]:
+         key_value = sample.key
+         if isinstance(key_value, tuple):
+             key_struct = list(key_value)
+         else:
+             key_struct = key_value
+         if isinstance(key_struct, (list, dict)):
+             key_text = json.dumps(key_struct, ensure_ascii=False, default=str)
+         else:
+             key_text = "" if key_struct is None else str(key_struct)
+
+         payload_data = self.serialize_payload(sample)
+         if isinstance(payload_data, dict):
+             payload_data.pop("key", None)
+         payload_text = json.dumps(payload_data, ensure_ascii=False, default=str)
+         return [key_text, payload_text]
+
+
+ class SampleCsvRowSerializer(BaseCsvRowSerializer):
+     payload_mode = "sample"
+
+     def serialize_payload(self, sample: Sample) -> Any:
+         return sample.as_full_payload()
+
+
+ class VectorCsvRowSerializer(BaseCsvRowSerializer):
+     payload_mode = "vector"
+
+     def serialize_payload(self, sample: Sample) -> Any:
+         return sample.as_vector_payload()
+
+
+ class BasePickleSerializer(BaseSerializer):
+     def __call__(self, sample: Sample) -> Any:
+         return self.serialize_payload(sample)
+
+
+ class SamplePickleSerializer(BasePickleSerializer):
+     payload_mode = "sample"
+
+     def serialize_payload(self, sample: Sample) -> Any:
+         return sample
+
+
+ class VectorPickleSerializer(BasePickleSerializer):
+     payload_mode = "vector"
+
+     def serialize_payload(self, sample: Sample) -> Any:
+         return sample.as_vector_payload()
+
+
+ def _record_payload(value: Any) -> Any:
+     if value is None:
+         return None
+     if is_dataclass(value):
+         return asdict(value)
+     if isinstance(value, dict):
+         return value
+     attrs = getattr(value, "__dict__", None)
+     if attrs:
+         return {
+             k: v
+             for k, v in attrs.items()
+             if not k.startswith("_")
+         }
+     return value
+
+
+ def _record_key(value: Any) -> Any:
+     direct = getattr(value, "time", None)
+     if direct is not None:
+         return direct
+     record = getattr(value, "record", None)
+     if record is not None:
+         return getattr(record, "time", None)
+     return None
+
+
+ class RecordJsonLineSerializer:
+     def __call__(self, record: Any) -> str:
+         payload = _record_payload(record)
+         return json.dumps(payload, ensure_ascii=False, default=str) + "\n"
+
+
+ class RecordPrintSerializer:
+     def __call__(self, record: Any) -> str:
+         return f"{_record_payload(record)}\n"
+
+
+ class RecordCsvRowSerializer:
+     def __call__(self, record: Any) -> list[str | Any]:
+         key_value = _record_key(record)
+         key_text = "" if key_value is None else str(key_value)
+         payload = json.dumps(_record_payload(record), ensure_ascii=False, default=str)
+         return [key_text, payload]
+
+
+ class RecordPickleSerializer:
+     def __call__(self, record: Any) -> Any:
+         return record
+
+
+ def _serializer_factory(
+     registry: Dict[str, Type[BaseSerializer]],
+     payload: str,
+     default_cls: Type[BaseSerializer],
+ ) -> BaseSerializer:
+     cls = registry.get(payload, default_cls)
+     return cls()
+
+
+ JSON_SERIALIZERS: Dict[str, Type[BaseJsonLineSerializer]] = {
+     SampleJsonLineSerializer.payload_mode: SampleJsonLineSerializer,
+     VectorJsonLineSerializer.payload_mode: VectorJsonLineSerializer,
+ }
+
+ PRINT_SERIALIZERS: Dict[str, Type[BasePrintSerializer]] = {
+     SamplePrintSerializer.payload_mode: SamplePrintSerializer,
+     VectorPrintSerializer.payload_mode: VectorPrintSerializer,
+ }
+
+ CSV_SERIALIZERS: Dict[str, Type[BaseCsvRowSerializer]] = {
+     SampleCsvRowSerializer.payload_mode: SampleCsvRowSerializer,
+     VectorCsvRowSerializer.payload_mode: VectorCsvRowSerializer,
+ }
+
+ PICKLE_SERIALIZERS: Dict[str, Type[BasePickleSerializer]] = {
+     SamplePickleSerializer.payload_mode: SamplePickleSerializer,
+     VectorPickleSerializer.payload_mode: VectorPickleSerializer,
+ }
+
+
+ def json_line_serializer(payload: str) -> BaseJsonLineSerializer:
+     return _serializer_factory(JSON_SERIALIZERS, payload, SampleJsonLineSerializer)
+
+
+ def print_serializer(payload: str) -> BasePrintSerializer:
+     return _serializer_factory(PRINT_SERIALIZERS, payload, SamplePrintSerializer)
+
+
+ def csv_row_serializer(payload: str) -> BaseCsvRowSerializer:
+     return _serializer_factory(CSV_SERIALIZERS, payload, SampleCsvRowSerializer)
+
+
+ def pickle_serializer(payload: str) -> BasePickleSerializer:
+     return _serializer_factory(PICKLE_SERIALIZERS, payload, SamplePickleSerializer)
+
+
+ def record_json_line_serializer() -> RecordJsonLineSerializer:
+     return RecordJsonLineSerializer()
+
+
+ def record_print_serializer() -> RecordPrintSerializer:
+     return RecordPrintSerializer()
+
+
+ def record_csv_row_serializer() -> RecordCsvRowSerializer:
+     return RecordCsvRowSerializer()
+
+
+ def record_pickle_serializer() -> RecordPickleSerializer:
+     return RecordPickleSerializer()
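
The record-level serializers accept arbitrary objects: `_record_payload()` unpacks dataclasses via `asdict()`, passes dicts through, and otherwise falls back to public `__dict__` attributes, while `_record_key()` looks for a `time` attribute. A minimal usage sketch, assuming a hypothetical `Tick` dataclass (not part of the package):

    from dataclasses import dataclass
    from datetime import datetime

    from datapipeline.io.serializers import (
        record_csv_row_serializer,
        record_json_line_serializer,
    )

    @dataclass
    class Tick:  # hypothetical record type, for illustration only
        time: datetime
        value: float

    tick = Tick(time=datetime(2024, 1, 1), value=1.5)

    # asdict() unpacks the dataclass; default=str stringifies the datetime
    print(record_json_line_serializer()(tick), end="")
    # {"time": "2024-01-01 00:00:00", "value": 1.5}

    # _record_key() reads .time for the key column; the payload column is JSON
    print(record_csv_row_serializer()(tick))
    # ['2024-01-01 00:00:00', '{"time": "2024-01-01 00:00:00", "value": 1.5}']

The `Sample`-based factories work the same way: `json_line_serializer("vector")` returns a `VectorJsonLineSerializer`, and an unknown payload mode falls back to the `"sample"` default via `_serializer_factory`.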
datapipeline/io/sinks/__init__.py
@@ -0,0 +1,23 @@
+ from .base import BaseSink
+ from .stdout import StdoutTextSink
+ from .rich import (
+     RichFormatter,
+     ReprRichFormatter,
+     JsonRichFormatter,
+     PlainRichFormatter,
+     RichStdoutSink,
+ )
+ from .files import AtomicTextFileSink, AtomicBinaryFileSink, GzipBinarySink
+
+ __all__ = [
+     "BaseSink",
+     "StdoutTextSink",
+     "RichFormatter",
+     "ReprRichFormatter",
+     "JsonRichFormatter",
+     "PlainRichFormatter",
+     "RichStdoutSink",
+     "AtomicTextFileSink",
+     "AtomicBinaryFileSink",
+     "GzipBinarySink",
+ ]
datapipeline/io/sinks/base.py
@@ -0,0 +1,2 @@
+ class BaseSink:
+     def close(self) -> None: ...
datapipeline/io/sinks/files.py
@@ -0,0 +1,79 @@
+ from pathlib import Path
+ import os
+ import tempfile
+ import gzip
+
+ from .base import BaseSink
+
+
+ class AtomicTextFileSink(BaseSink):
+     def __init__(self, dest: Path):
+         self._dest = dest
+         dest.parent.mkdir(parents=True, exist_ok=True)
+         self._tmp = Path(
+             tempfile.NamedTemporaryFile(dir=str(dest.parent), delete=False).name
+         )
+         self._fh = open(self._tmp, "w", encoding="utf-8")
+
+     @property
+     def file_path(self) -> Path:
+         return self._dest
+
+     @property
+     def fh(self):
+         return self._fh
+
+     def write_text(self, s: str) -> None:
+         self._fh.write(s)
+
+     def close(self) -> None:
+         self._fh.close()
+         os.replace(self._tmp, self._dest)
+
+
+ class AtomicBinaryFileSink(BaseSink):
+     def __init__(self, dest: Path):
+         self._dest = dest
+         dest.parent.mkdir(parents=True, exist_ok=True)
+         self._tmp = Path(
+             tempfile.NamedTemporaryFile(dir=str(dest.parent), delete=False).name
+         )
+         self._fh = open(self._tmp, "wb")
+
+     @property
+     def file_path(self) -> Path:
+         return self._dest
+
+     @property
+     def fh(self):
+         return self._fh
+
+     def write_bytes(self, b: bytes) -> None:
+         self._fh.write(b)
+
+     def close(self) -> None:
+         self._fh.close()
+         os.replace(self._tmp, self._dest)
+
+
+ class GzipBinarySink(BaseSink):
+     def __init__(self, dest: Path):
+         self._dest = dest
+         dest.parent.mkdir(parents=True, exist_ok=True)
+         self._tmp = Path(
+             tempfile.NamedTemporaryFile(dir=str(dest.parent), delete=False).name
+         )
+         self._raw = open(self._tmp, "wb")
+         self._fh = gzip.GzipFile(fileobj=self._raw, mode="wb")
+
+     @property
+     def file_path(self) -> Path:
+         return self._dest
+
+     def write_bytes(self, b: bytes) -> None:
+         self._fh.write(b)
+
+     def close(self) -> None:
+         self._fh.close()
+         self._raw.close()
+         os.replace(self._tmp, self._dest)
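
All three sinks share the same atomic-write contract: bytes go to a `NamedTemporaryFile` created next to the destination, and the final name only appears when `close()` runs `os.replace()`. A minimal sketch, assuming `out/data.txt` does not already exist:

    from pathlib import Path

    from datapipeline.io.sinks import AtomicTextFileSink

    dest = Path("out/data.txt")      # hypothetical destination
    sink = AtomicTextFileSink(dest)  # creates out/ and a temp file beside dest
    sink.write_text("hello\n")
    assert not dest.exists()         # nothing visible until close()
    sink.close()                     # os.replace() publishes the file atomically
    assert dest.read_text(encoding="utf-8") == "hello\n"

Creating the temp file in the destination directory (rather than the system temp dir) keeps the final rename on the same filesystem, which is what makes `os.replace()` atomic.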
datapipeline/io/sinks/rich.py
@@ -0,0 +1,57 @@
+ from typing import Protocol
+
+ from .stdout import StdoutTextSink
+
+
+ class RichFormatter(Protocol):
+     def render(self, console, text: str) -> None: ...
+
+
+ class ReprRichFormatter:
+     def __init__(self):
+         from rich.highlighter import ReprHighlighter
+
+         self._highlighter = ReprHighlighter()
+
+     def render(self, console, text: str) -> None:
+         console.print(self._highlighter(text))
+
+
+ class JsonRichFormatter:
+     def render(self, console, text: str) -> None:
+         import json as _json
+
+         stripped = text.strip()
+         if not stripped:
+             return
+         try:
+             data = _json.loads(stripped)
+             console.print_json(data=data)
+         except Exception:
+             console.print(stripped)
+
+
+ class PlainRichFormatter:
+     def render(self, console, text: str) -> None:
+         console.print(text)
+
+
+ class RichStdoutSink(StdoutTextSink):
+     def __init__(self, formatter: RichFormatter):
+         super().__init__()
+         try:
+             from rich.console import Console
+         except Exception:  # pragma: no cover
+             self.console = None
+         else:
+             self.console = Console(
+                 file=self.stream, markup=False, highlight=False, soft_wrap=True
+             )
+         self._formatter = formatter
+
+     def write_text(self, s: str) -> None:
+         if not self.console:
+             super().write_text(s)
+             return
+         text = s.rstrip("\n")
+         self._formatter.render(self.console, text)
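
A short sketch of the formatter/sink pairing, assuming `rich` is installed (when the import fails, `RichStdoutSink` silently falls back to the plain `StdoutTextSink.write_text` path):

    from datapipeline.io.sinks import JsonRichFormatter, RichStdoutSink

    sink = RichStdoutSink(JsonRichFormatter())
    sink.write_text('{"key": "2024-01-01", "values": [1.5]}\n')  # pretty-printed JSON
    sink.write_text("not json\n")  # json.loads fails, so it prints as plain text
    sink.close()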
datapipeline/io/sinks/stdout.py
@@ -0,0 +1,18 @@
+ import sys
+ from typing import Optional
+
+ from .base import BaseSink
+
+
+ class StdoutTextSink(BaseSink):
+     def __init__(self, stream: Optional[object] = None):
+         self.stream = stream or sys.stdout
+
+     def write_text(self, s: str) -> None:
+         self.stream.write(s)
+
+     def flush(self) -> None:
+         self.stream.flush()
+
+     def close(self) -> None:
+         self.flush()
datapipeline/io/writers/__init__.py
@@ -0,0 +1,14 @@
+ from .base import LineWriter, HeaderJsonlMixin
+ from .jsonl import JsonLinesStdoutWriter, JsonLinesFileWriter, GzipJsonLinesWriter
+ from .csv_writer import CsvFileWriter
+ from .pickle_writer import PickleFileWriter
+
+ __all__ = [
+     "LineWriter",
+     "HeaderJsonlMixin",
+     "JsonLinesStdoutWriter",
+     "JsonLinesFileWriter",
+     "GzipJsonLinesWriter",
+     "CsvFileWriter",
+     "PickleFileWriter",
+ ]
datapipeline/io/writers/base.py
@@ -0,0 +1,28 @@
+ import json
+ from typing import Optional
+
+ from datapipeline.io.protocols import HeaderCapable, Writer
+ from datapipeline.io.sinks import StdoutTextSink, AtomicTextFileSink
+
+
+ class LineWriter(Writer):
+     """Text line writer (uses a text sink + serializer)."""
+
+     def __init__(self, sink: StdoutTextSink | AtomicTextFileSink, serializer):
+         self.sink = sink
+         self.serializer = serializer
+
+     def write(self, item) -> None:
+         self.sink.write_text(self.serializer(item))
+
+     def close(self) -> None:
+         self.sink.close()
+
+
+ class HeaderJsonlMixin(HeaderCapable):
+     """Provide a header write by emitting one JSON line."""
+
+     def write_header(self, header: dict) -> None:
+         self.sink.write_text(
+             json.dumps({"__checkpoint__": header}, ensure_ascii=False) + "\n"
+         )
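
A `LineWriter` is just a text sink plus a callable serializer, and `HeaderJsonlMixin` layers a `__checkpoint__` header line on whatever sink the writer holds. A minimal composition sketch:

    from datapipeline.io.serializers import record_json_line_serializer
    from datapipeline.io.sinks import StdoutTextSink
    from datapipeline.io.writers import LineWriter

    writer = LineWriter(StdoutTextSink(), record_json_line_serializer())
    writer.write({"stage": 6, "ok": True})  # dicts pass through _record_payload as-is
    writer.close()
    # {"stage": 6, "ok": true}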
datapipeline/io/writers/csv_writer.py
@@ -0,0 +1,25 @@
+ import csv
+ from pathlib import Path
+ from typing import Optional
+
+ from datapipeline.io.serializers import csv_row_serializer, BaseCsvRowSerializer
+ from datapipeline.io.protocols import HasFilePath, Writer
+ from datapipeline.io.sinks import AtomicTextFileSink
+
+
+ class CsvFileWriter(Writer, HasFilePath):
+     def __init__(self, dest: Path, serializer: BaseCsvRowSerializer | None = None):
+         self.sink = AtomicTextFileSink(dest)
+         self.writer = csv.writer(self.sink.fh)
+         self.writer.writerow(["key", "values"])
+         self._serializer = serializer or csv_row_serializer("sample")
+
+     @property
+     def file_path(self) -> Optional[Path]:
+         return self.sink.file_path
+
+     def write(self, item) -> None:
+         self.writer.writerow(self._serializer(item))
+
+     def close(self) -> None:
+         self.sink.close()
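
The result is a two-column CSV whose `values` column holds a JSON document. A sketch using the record serializer (duck-typed here: the annotation expects a `BaseCsvRowSerializer`, but only `__call__` is used):

    from pathlib import Path

    from datapipeline.io.serializers import RecordCsvRowSerializer
    from datapipeline.io.writers import CsvFileWriter

    writer = CsvFileWriter(Path("out/rows.csv"), serializer=RecordCsvRowSerializer())
    writer.write({"value": 1.5})  # dict has no .time, so the key column stays empty
    writer.close()
    # out/rows.csv:
    # key,values
    # ,"{""value"": 1.5}"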
datapipeline/io/writers/jsonl.py
@@ -0,0 +1,52 @@
+ from pathlib import Path
+ from typing import Optional
+ import json
+
+ from datapipeline.io.serializers import json_line_serializer, BaseJsonLineSerializer
+ from datapipeline.io.protocols import HeaderCapable, HasFilePath, Writer
+ from datapipeline.io.sinks import (
+     AtomicTextFileSink,
+     GzipBinarySink,
+     StdoutTextSink,
+ )
+
+ from .base import HeaderJsonlMixin, LineWriter
+
+
+ class JsonLinesStdoutWriter(LineWriter, HeaderJsonlMixin):
+     def __init__(self, serializer: BaseJsonLineSerializer | None = None):
+         super().__init__(StdoutTextSink(), serializer or json_line_serializer("sample"))
+
+
+ class JsonLinesFileWriter(LineWriter, HeaderJsonlMixin, HasFilePath):
+     def __init__(self, dest: Path, serializer: BaseJsonLineSerializer | None = None):
+         self._sink = AtomicTextFileSink(dest)
+         super().__init__(self._sink, serializer or json_line_serializer("sample"))
+
+     @property
+     def file_path(self) -> Optional[Path]:
+         return self._sink.file_path
+
+
+ class GzipJsonLinesWriter(Writer, HeaderCapable, HasFilePath):
+     def __init__(self, dest: Path, serializer: BaseJsonLineSerializer | None = None):
+         self.sink = GzipBinarySink(dest)
+         self._serializer = serializer or json_line_serializer("sample")
+
+     @property
+     def file_path(self) -> Optional[Path]:
+         return self.sink.file_path
+
+     def write_header(self, header: dict) -> None:
+         self.sink.write_bytes(
+             (json.dumps({"__checkpoint__": header}, ensure_ascii=False) + "\n").encode(
+                 "utf-8"
+             )
+         )
+
+     def write(self, item) -> None:
+         line = self._serializer(item)
+         self.sink.write_bytes(line.encode("utf-8"))
+
+     def close(self) -> None:
+         self.sink.close()
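
A sketch of checkpoint-style output, constructing the writer directly (the header fields shown are hypothetical): the header lands as the first `__checkpoint__` JSON line, each item follows as its own JSON line, and the gzip stream is published atomically on `close()`:

    from pathlib import Path

    from datapipeline.io.serializers import record_json_line_serializer
    from datapipeline.io.writers import GzipJsonLinesWriter

    writer = GzipJsonLinesWriter(
        Path("out/run.jsonl.gz"), serializer=record_json_line_serializer()
    )
    writer.write_header({"run_id": "demo", "stage": 6})  # hypothetical header
    writer.write({"t": 0, "value": 1.0})
    writer.close()  # closes the gzip stream, then renames the temp file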
datapipeline/io/writers/pickle_writer.py
@@ -0,0 +1,30 @@
+ import pickle
+ from pathlib import Path
+ from typing import Optional
+
+ from datapipeline.io.serializers import pickle_serializer, BasePickleSerializer
+ from datapipeline.io.protocols import HasFilePath, Writer
+ from datapipeline.io.sinks import AtomicBinaryFileSink
+
+
+ class PickleFileWriter(Writer, HasFilePath):
+     def __init__(
+         self,
+         dest: Path,
+         serializer: BasePickleSerializer | None = None,
+         protocol: int = pickle.HIGHEST_PROTOCOL,
+     ):
+         self.sink = AtomicBinaryFileSink(dest)
+         self.pickler = pickle.Pickler(self.sink.fh, protocol=protocol)
+         self._serializer = serializer or pickle_serializer("sample")
+
+     @property
+     def file_path(self) -> Optional[Path]:
+         return self.sink.file_path
+
+     def write(self, item) -> None:
+         self.pickler.dump(self._serializer(item))
+         self.pickler.clear_memo()
+
+     def close(self) -> None:
+         self.sink.close()
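
The per-item `clear_memo()` keeps the pickler from accumulating object references across dumps, so the output is a plain sequence of independent pickle frames. Reading it back is repeated `pickle.load` until EOF; a sketch with a hypothetical path:

    import pickle
    from pathlib import Path

    def iter_pickles(path: Path):
        # Each PickleFileWriter.write() produced one self-contained frame.
        with open(path, "rb") as fh:
            while True:
                try:
                    yield pickle.load(fh)
                except EOFError:
                    break

    for item in iter_pickles(Path("out/data.pkl")):
        print(item)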
datapipeline/pipeline/artifacts.py
@@ -0,0 +1,58 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Iterable
+
+ from datapipeline.config.dataset.dataset import FeatureDatasetConfig
+ from datapipeline.config.dataset.feature import FeatureRecordConfig
+ from datapipeline.services.constants import (
+     SCALER_STATISTICS,
+     VECTOR_SCHEMA,
+     VECTOR_SCHEMA_METADATA,
+ )
+
+
+ @dataclass(frozen=True)
+ class StageDemand:
+     stage: int | None
+
+
+ def _needs_scaler(configs: Iterable[FeatureRecordConfig]) -> bool:
+     for cfg in configs:
+         scale = getattr(cfg, "scale", False)
+         if isinstance(scale, dict):
+             return True
+         if bool(scale):
+             return True
+     return False
+
+
+ def _requires_scaler(dataset: FeatureDatasetConfig) -> bool:
+     if _needs_scaler(dataset.features or []):
+         return True
+     if dataset.targets:
+         return _needs_scaler(dataset.targets)
+     return False
+
+
+ def required_artifacts_for(
+     dataset: FeatureDatasetConfig,
+     demands: Iterable[StageDemand],
+ ) -> set[str]:
+     required: set[str] = set()
+     needs_metadata = False
+     for demand in demands:
+         stage = demand.stage
+         effective_stage = 7 if stage is None else stage
+
+         if effective_stage >= 5 and _requires_scaler(dataset):
+             required.add(SCALER_STATISTICS)
+
+         if effective_stage >= 6:
+             required.add(VECTOR_SCHEMA)
+             needs_metadata = True
+
+     if needs_metadata:
+         required.add(VECTOR_SCHEMA_METADATA)
+
+     return required
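
A sketch of how demands resolve, assuming `dataset` is an already-loaded `FeatureDatasetConfig` (the artifact names are the string constants from `datapipeline.services.constants`):

    from datapipeline.pipeline.artifacts import StageDemand, required_artifacts_for

    demands = [StageDemand(stage=4), StageDemand(stage=None)]  # None acts as stage 7
    artifacts = required_artifacts_for(dataset, demands)
    # stage 4 contributes nothing; the stage-7 demand adds VECTOR_SCHEMA plus
    # VECTOR_SCHEMA_METADATA, and SCALER_STATISTICS only if some feature or
    # target sets `scale`.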