ai-pipeline-core 0.1.7__py3-none-any.whl → 0.1.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_pipeline_core/__init__.py +7 -5
- ai_pipeline_core/documents/__init__.py +2 -0
- ai_pipeline_core/documents/document.py +131 -23
- ai_pipeline_core/documents/temporary_document.py +16 -0
- ai_pipeline_core/flow/config.py +40 -1
- ai_pipeline_core/llm/model_options.py +4 -0
- ai_pipeline_core/pipeline.py +313 -317
- ai_pipeline_core/prompt_manager.py +7 -1
- ai_pipeline_core/simple_runner/cli.py +87 -12
- ai_pipeline_core/simple_runner/simple_runner.py +7 -2
- {ai_pipeline_core-0.1.7.dist-info → ai_pipeline_core-0.1.10.dist-info}/METADATA +35 -38
- {ai_pipeline_core-0.1.7.dist-info → ai_pipeline_core-0.1.10.dist-info}/RECORD +14 -13
- {ai_pipeline_core-0.1.7.dist-info → ai_pipeline_core-0.1.10.dist-info}/WHEEL +0 -0
- {ai_pipeline_core-0.1.7.dist-info → ai_pipeline_core-0.1.10.dist-info}/licenses/LICENSE +0 -0
ai_pipeline_core/__init__.py
CHANGED
|
@@ -6,6 +6,7 @@ from .documents import (
|
|
|
6
6
|
DocumentList,
|
|
7
7
|
FlowDocument,
|
|
8
8
|
TaskDocument,
|
|
9
|
+
TemporaryDocument,
|
|
9
10
|
canonical_name_key,
|
|
10
11
|
sanitize_url,
|
|
11
12
|
)
|
|
@@ -27,12 +28,12 @@ from .logging import (
|
|
|
27
28
|
)
|
|
28
29
|
from .logging import get_pipeline_logger as get_logger
|
|
29
30
|
from .pipeline import pipeline_flow, pipeline_task
|
|
30
|
-
from .prefect import
|
|
31
|
+
from .prefect import disable_run_logger, prefect_test_harness
|
|
31
32
|
from .prompt_manager import PromptManager
|
|
32
33
|
from .settings import settings
|
|
33
34
|
from .tracing import TraceInfo, TraceLevel, trace
|
|
34
35
|
|
|
35
|
-
__version__ = "0.1.
|
|
36
|
+
__version__ = "0.1.10"
|
|
36
37
|
|
|
37
38
|
__all__ = [
|
|
38
39
|
# Config/Settings
|
|
@@ -49,17 +50,18 @@ __all__ = [
|
|
|
49
50
|
"DocumentList",
|
|
50
51
|
"FlowDocument",
|
|
51
52
|
"TaskDocument",
|
|
53
|
+
"TemporaryDocument",
|
|
52
54
|
"canonical_name_key",
|
|
53
55
|
"sanitize_url",
|
|
54
56
|
# Flow/Task
|
|
55
57
|
"FlowConfig",
|
|
56
58
|
"FlowOptions",
|
|
57
|
-
# Prefect decorators (clean, no tracing)
|
|
58
|
-
"task",
|
|
59
|
-
"flow",
|
|
60
59
|
# Pipeline decorators (with tracing)
|
|
61
60
|
"pipeline_task",
|
|
62
61
|
"pipeline_flow",
|
|
62
|
+
# Prefect decorators (clean, no tracing)
|
|
63
|
+
"prefect_test_harness",
|
|
64
|
+
"disable_run_logger",
|
|
63
65
|
# LLM
|
|
64
66
|
"llm",
|
|
65
67
|
"ModelName",
|
|
@@ -2,6 +2,7 @@ from .document import Document
|
|
|
2
2
|
from .document_list import DocumentList
|
|
3
3
|
from .flow_document import FlowDocument
|
|
4
4
|
from .task_document import TaskDocument
|
|
5
|
+
from .temporary_document import TemporaryDocument
|
|
5
6
|
from .utils import canonical_name_key, sanitize_url
|
|
6
7
|
|
|
7
8
|
__all__ = [
|
|
@@ -9,6 +10,7 @@ __all__ = [
|
|
|
9
10
|
"DocumentList",
|
|
10
11
|
"FlowDocument",
|
|
11
12
|
"TaskDocument",
|
|
13
|
+
"TemporaryDocument",
|
|
12
14
|
"canonical_name_key",
|
|
13
15
|
"sanitize_url",
|
|
14
16
|
]
|
|
@@ -6,7 +6,19 @@ from abc import ABC, abstractmethod
|
|
|
6
6
|
from base64 import b32encode
|
|
7
7
|
from enum import StrEnum
|
|
8
8
|
from functools import cached_property
|
|
9
|
-
from
|
|
9
|
+
from io import BytesIO
|
|
10
|
+
from typing import (
|
|
11
|
+
Any,
|
|
12
|
+
ClassVar,
|
|
13
|
+
Literal,
|
|
14
|
+
Self,
|
|
15
|
+
TypeVar,
|
|
16
|
+
cast,
|
|
17
|
+
final,
|
|
18
|
+
get_args,
|
|
19
|
+
get_origin,
|
|
20
|
+
overload,
|
|
21
|
+
)
|
|
10
22
|
|
|
11
23
|
from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
|
12
24
|
from ruamel.yaml import YAML
|
|
@@ -23,64 +35,107 @@ from .mime_type import (
|
|
|
23
35
|
)
|
|
24
36
|
|
|
25
37
|
TModel = TypeVar("TModel", bound=BaseModel)
|
|
38
|
+
ContentInput = bytes | str | BaseModel | list[str] | Any
|
|
26
39
|
|
|
27
40
|
|
|
28
41
|
class Document(BaseModel, ABC):
|
|
29
|
-
"""Abstract base class for all documents
|
|
42
|
+
"""Abstract base class for all documents.
|
|
43
|
+
|
|
44
|
+
Warning: Document subclasses should NOT start with 'Test' prefix as this
|
|
45
|
+
causes conflicts with pytest test discovery. Classes with 'Test' prefix
|
|
46
|
+
will be rejected at definition time.
|
|
47
|
+
"""
|
|
30
48
|
|
|
31
49
|
MAX_CONTENT_SIZE: ClassVar[int] = 25 * 1024 * 1024 # 25MB default
|
|
32
50
|
DESCRIPTION_EXTENSION: ClassVar[str] = ".description.md"
|
|
33
51
|
MARKDOWN_LIST_SEPARATOR: ClassVar[str] = "\n\n---\n\n"
|
|
34
52
|
|
|
53
|
+
def __init_subclass__(cls, **kwargs: Any) -> None:
|
|
54
|
+
"""Validate subclass names to prevent pytest conflicts."""
|
|
55
|
+
super().__init_subclass__(**kwargs)
|
|
56
|
+
if cls.__name__.startswith("Test"):
|
|
57
|
+
raise TypeError(
|
|
58
|
+
f"Document subclass '{cls.__name__}' cannot start with 'Test' prefix. "
|
|
59
|
+
"This causes conflicts with pytest test discovery. "
|
|
60
|
+
"Please use a different name (e.g., 'SampleDocument', 'ExampleDocument')."
|
|
61
|
+
)
|
|
62
|
+
if hasattr(cls, "FILES"):
|
|
63
|
+
files = getattr(cls, "FILES")
|
|
64
|
+
if not issubclass(files, StrEnum):
|
|
65
|
+
raise TypeError(
|
|
66
|
+
f"Document subclass '{cls.__name__}'.FILES must be an Enum of string values"
|
|
67
|
+
)
|
|
68
|
+
# Check that the Document's model_fields only contain the allowed fields
|
|
69
|
+
# It prevents AI models from adding additional fields to documents
|
|
70
|
+
allowed = {"name", "description", "content"}
|
|
71
|
+
current = set(getattr(cls, "model_fields", {}).keys())
|
|
72
|
+
extras = current - allowed
|
|
73
|
+
if extras:
|
|
74
|
+
raise TypeError(
|
|
75
|
+
f"Document subclass '{cls.__name__}' cannot declare additional fields: "
|
|
76
|
+
f"{', '.join(sorted(extras))}. Only {', '.join(sorted(allowed))} are allowed."
|
|
77
|
+
)
|
|
78
|
+
|
|
35
79
|
def __init__(self, **data: Any) -> None:
|
|
36
80
|
"""Prevent direct instantiation of abstract Document class."""
|
|
37
81
|
if type(self) is Document:
|
|
38
82
|
raise TypeError("Cannot instantiate abstract Document class directly")
|
|
39
83
|
super().__init__(**data)
|
|
40
84
|
|
|
41
|
-
# Optional enum of allowed file names. Subclasses may set this.
|
|
42
|
-
# This is used to validate the document name.
|
|
43
|
-
FILES: ClassVar[type[StrEnum] | None] = None
|
|
44
|
-
|
|
45
85
|
name: str
|
|
46
86
|
description: str | None = None
|
|
47
87
|
content: bytes
|
|
48
88
|
|
|
49
89
|
# Pydantic configuration
|
|
50
90
|
model_config = ConfigDict(
|
|
51
|
-
frozen=True,
|
|
91
|
+
frozen=True,
|
|
52
92
|
arbitrary_types_allowed=True,
|
|
93
|
+
extra="forbid",
|
|
53
94
|
)
|
|
54
95
|
|
|
55
96
|
@abstractmethod
|
|
56
|
-
def get_base_type(self) -> Literal["flow", "task"]:
|
|
97
|
+
def get_base_type(self) -> Literal["flow", "task", "temporary"]:
|
|
57
98
|
"""Get the type of the document - must be implemented by subclasses"""
|
|
58
99
|
raise NotImplementedError("Subclasses must implement this method")
|
|
59
100
|
|
|
101
|
+
@final
|
|
60
102
|
@property
|
|
61
|
-
def base_type(self) -> Literal["flow", "task"]:
|
|
103
|
+
def base_type(self) -> Literal["flow", "task", "temporary"]:
|
|
62
104
|
"""Alias for document_type for backward compatibility"""
|
|
63
105
|
return self.get_base_type()
|
|
64
106
|
|
|
107
|
+
@final
|
|
65
108
|
@property
|
|
66
109
|
def is_flow(self) -> bool:
|
|
67
110
|
"""Check if document is a flow document"""
|
|
68
111
|
return self.get_base_type() == "flow"
|
|
69
112
|
|
|
113
|
+
@final
|
|
70
114
|
@property
|
|
71
115
|
def is_task(self) -> bool:
|
|
72
116
|
"""Check if document is a task document"""
|
|
73
117
|
return self.get_base_type() == "task"
|
|
74
118
|
|
|
119
|
+
@final
|
|
120
|
+
@property
|
|
121
|
+
def is_temporary(self) -> bool:
|
|
122
|
+
"""Check if document is a temporary document"""
|
|
123
|
+
return self.get_base_type() == "temporary"
|
|
124
|
+
|
|
125
|
+
@final
|
|
75
126
|
@classmethod
|
|
76
127
|
def get_expected_files(cls) -> list[str] | None:
|
|
77
128
|
"""
|
|
78
129
|
Return the list of allowed file names for this document class, or None if unrestricted.
|
|
79
130
|
"""
|
|
80
|
-
if cls
|
|
131
|
+
if not hasattr(cls, "FILES"):
|
|
132
|
+
return None
|
|
133
|
+
files = getattr(cls, "FILES")
|
|
134
|
+
if not files:
|
|
81
135
|
return None
|
|
136
|
+
assert issubclass(files, StrEnum)
|
|
82
137
|
try:
|
|
83
|
-
values = [member.value for member in
|
|
138
|
+
values = [member.value for member in files]
|
|
84
139
|
except TypeError:
|
|
85
140
|
raise DocumentNameError(f"{cls.__name__}.FILES must be an Enum of string values")
|
|
86
141
|
if len(values) == 0:
|
|
@@ -100,14 +155,10 @@ class Document(BaseModel, ABC):
|
|
|
100
155
|
Override this method in subclasses for custom conventions (regex, prefixes, etc.).
|
|
101
156
|
Raise DocumentNameError when invalid.
|
|
102
157
|
"""
|
|
103
|
-
|
|
158
|
+
allowed = cls.get_expected_files()
|
|
159
|
+
if not allowed:
|
|
104
160
|
return
|
|
105
161
|
|
|
106
|
-
try:
|
|
107
|
-
allowed = {str(member.value) for member in cls.FILES} # type: ignore[arg-type]
|
|
108
|
-
except TypeError:
|
|
109
|
-
raise DocumentNameError(f"{cls.__name__}.FILES must be an Enum of string values")
|
|
110
|
-
|
|
111
162
|
if len(allowed) > 0 and name not in allowed:
|
|
112
163
|
allowed_str = ", ".join(sorted(allowed))
|
|
113
164
|
raise DocumentNameError(f"Invalid filename '{name}'. Allowed names: {allowed_str}")
|
|
@@ -151,16 +202,19 @@ class Document(BaseModel, ABC):
|
|
|
151
202
|
# Fall back to base64 for binary content
|
|
152
203
|
return base64.b64encode(v).decode("ascii")
|
|
153
204
|
|
|
205
|
+
@final
|
|
154
206
|
@property
|
|
155
207
|
def id(self) -> str:
|
|
156
208
|
"""Return the first 6 characters of the SHA256 hash of the content, encoded in base32"""
|
|
157
209
|
return self.sha256[:6]
|
|
158
210
|
|
|
211
|
+
@final
|
|
159
212
|
@cached_property
|
|
160
213
|
def sha256(self) -> str:
|
|
161
214
|
"""Full SHA256 hash of content, encoded in base32"""
|
|
162
215
|
return b32encode(hashlib.sha256(self.content).digest()).decode("ascii").upper()
|
|
163
216
|
|
|
217
|
+
@final
|
|
164
218
|
@property
|
|
165
219
|
def size(self) -> int:
|
|
166
220
|
"""Size of content in bytes"""
|
|
@@ -210,23 +264,61 @@ class Document(BaseModel, ABC):
|
|
|
210
264
|
"""Parse document as JSON"""
|
|
211
265
|
return json.loads(self.as_text())
|
|
212
266
|
|
|
213
|
-
|
|
267
|
+
@overload
|
|
268
|
+
def as_pydantic_model(self, model_type: type[TModel]) -> TModel: ...
|
|
269
|
+
|
|
270
|
+
@overload
|
|
271
|
+
def as_pydantic_model(self, model_type: type[list[TModel]]) -> list[TModel]: ...
|
|
272
|
+
|
|
273
|
+
def as_pydantic_model(
|
|
274
|
+
self, model_type: type[TModel] | type[list[TModel]]
|
|
275
|
+
) -> TModel | list[TModel]:
|
|
214
276
|
"""Parse document as a pydantic model and return the validated instance"""
|
|
215
277
|
data = self.as_yaml() if is_yaml_mime_type(self.mime_type) else self.as_json()
|
|
216
|
-
|
|
278
|
+
|
|
279
|
+
if get_origin(model_type) is list:
|
|
280
|
+
if not isinstance(data, list):
|
|
281
|
+
raise ValueError(f"Expected list data for {model_type}, got {type(data)}")
|
|
282
|
+
item_type = get_args(model_type)[0]
|
|
283
|
+
return [item_type.model_validate(item) for item in data]
|
|
284
|
+
|
|
285
|
+
# At this point model_type must be type[TModel], not type[list[TModel]]
|
|
286
|
+
single_model = cast(type[TModel], model_type)
|
|
287
|
+
return single_model.model_validate(data)
|
|
217
288
|
|
|
218
289
|
def as_markdown_list(self) -> list[str]:
|
|
219
290
|
"""Parse document as a markdown list"""
|
|
220
291
|
return self.as_text().split(self.MARKDOWN_LIST_SEPARATOR)
|
|
221
292
|
|
|
293
|
+
@overload
|
|
294
|
+
@classmethod
|
|
295
|
+
def create(cls, name: str, content: ContentInput, /) -> Self: ...
|
|
296
|
+
@overload
|
|
297
|
+
@classmethod
|
|
298
|
+
def create(cls, name: str, *, content: ContentInput) -> Self: ...
|
|
299
|
+
@overload
|
|
300
|
+
@classmethod
|
|
301
|
+
def create(cls, name: str, description: str | None, content: ContentInput, /) -> Self: ...
|
|
302
|
+
@overload
|
|
303
|
+
@classmethod
|
|
304
|
+
def create(cls, name: str, description: str | None, *, content: ContentInput) -> Self: ...
|
|
305
|
+
|
|
222
306
|
@classmethod
|
|
223
307
|
def create(
|
|
224
308
|
cls,
|
|
225
309
|
name: str,
|
|
226
|
-
description:
|
|
227
|
-
content:
|
|
310
|
+
description: ContentInput = None,
|
|
311
|
+
content: ContentInput = None,
|
|
228
312
|
) -> Self:
|
|
229
313
|
"""Create a document from a name, description, and content"""
|
|
314
|
+
if content is None:
|
|
315
|
+
if description is None:
|
|
316
|
+
raise ValueError(f"Unsupported content type: {type(content)} for {name}")
|
|
317
|
+
content = description
|
|
318
|
+
description = None
|
|
319
|
+
else:
|
|
320
|
+
assert description is None or isinstance(description, str)
|
|
321
|
+
|
|
230
322
|
is_yaml_extension = name.endswith(".yaml") or name.endswith(".yml")
|
|
231
323
|
is_json_extension = name.endswith(".json")
|
|
232
324
|
is_markdown_extension = name.endswith(".md")
|
|
@@ -237,6 +329,14 @@ class Document(BaseModel, ABC):
|
|
|
237
329
|
content = content.encode("utf-8")
|
|
238
330
|
elif is_str_list and is_markdown_extension:
|
|
239
331
|
return cls.create_as_markdown_list(name, description, content) # type: ignore[arg-type]
|
|
332
|
+
elif isinstance(content, list) and all(isinstance(item, BaseModel) for item in content):
|
|
333
|
+
# Handle list[BaseModel] for JSON/YAML files
|
|
334
|
+
if is_yaml_extension:
|
|
335
|
+
return cls.create_as_yaml(name, description, content)
|
|
336
|
+
elif is_json_extension:
|
|
337
|
+
return cls.create_as_json(name, description, content)
|
|
338
|
+
else:
|
|
339
|
+
raise ValueError(f"list[BaseModel] requires .json or .yaml extension, got {name}")
|
|
240
340
|
elif is_yaml_extension:
|
|
241
341
|
return cls.create_as_yaml(name, description, content)
|
|
242
342
|
elif is_json_extension:
|
|
@@ -246,6 +346,7 @@ class Document(BaseModel, ABC):
|
|
|
246
346
|
|
|
247
347
|
return cls(name=name, description=description, content=content)
|
|
248
348
|
|
|
349
|
+
@final
|
|
249
350
|
@classmethod
|
|
250
351
|
def create_as_markdown_list(cls, name: str, description: str | None, items: list[str]) -> Self:
|
|
251
352
|
"""Create a document from a name, description, and list of strings"""
|
|
@@ -258,15 +359,19 @@ class Document(BaseModel, ABC):
|
|
|
258
359
|
content = Document.MARKDOWN_LIST_SEPARATOR.join(cleaned_items)
|
|
259
360
|
return cls.create(name, description, content)
|
|
260
361
|
|
|
362
|
+
@final
|
|
261
363
|
@classmethod
|
|
262
364
|
def create_as_json(cls, name: str, description: str | None, data: Any) -> Self:
|
|
263
365
|
"""Create a document from a name, description, and JSON data"""
|
|
264
366
|
assert name.endswith(".json"), f"Document name must end with .json: {name}"
|
|
265
367
|
if isinstance(data, BaseModel):
|
|
266
368
|
data = data.model_dump(mode="json")
|
|
369
|
+
elif isinstance(data, list) and all(isinstance(item, BaseModel) for item in data):
|
|
370
|
+
data = [item.model_dump(mode="json") for item in data]
|
|
267
371
|
content = json.dumps(data, indent=2).encode("utf-8")
|
|
268
372
|
return cls.create(name, description, content)
|
|
269
373
|
|
|
374
|
+
@final
|
|
270
375
|
@classmethod
|
|
271
376
|
def create_as_yaml(cls, name: str, description: str | None, data: Any) -> Self:
|
|
272
377
|
"""Create a document from a name, description, and YAML data"""
|
|
@@ -274,16 +379,18 @@ class Document(BaseModel, ABC):
|
|
|
274
379
|
f"Document name must end with .yaml or .yml: {name}"
|
|
275
380
|
)
|
|
276
381
|
if isinstance(data, BaseModel):
|
|
277
|
-
data = data.model_dump()
|
|
382
|
+
data = data.model_dump(mode="json")
|
|
383
|
+
elif isinstance(data, list) and all(isinstance(item, BaseModel) for item in data):
|
|
384
|
+
data = [item.model_dump(mode="json") for item in data]
|
|
278
385
|
yaml = YAML()
|
|
279
386
|
yaml.indent(mapping=2, sequence=4, offset=2)
|
|
280
|
-
from io import BytesIO
|
|
281
387
|
|
|
282
388
|
stream = BytesIO()
|
|
283
389
|
yaml.dump(data, stream)
|
|
284
390
|
content = stream.getvalue()
|
|
285
391
|
return cls.create(name, description, content)
|
|
286
392
|
|
|
393
|
+
@final
|
|
287
394
|
def serialize_model(self) -> dict[str, Any]:
|
|
288
395
|
"""Serialize document to a dictionary with proper encoding."""
|
|
289
396
|
result = {
|
|
@@ -312,6 +419,7 @@ class Document(BaseModel, ABC):
|
|
|
312
419
|
|
|
313
420
|
return result
|
|
314
421
|
|
|
422
|
+
@final
|
|
315
423
|
@classmethod
|
|
316
424
|
def from_dict(cls, data: dict[str, Any]) -> Self:
|
|
317
425
|
"""Deserialize document from dictionary."""
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""Task-specific document base class."""
|
|
2
|
+
|
|
3
|
+
from typing import Literal, final
|
|
4
|
+
|
|
5
|
+
from .document import Document
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@final
|
|
9
|
+
class TemporaryDocument(Document):
|
|
10
|
+
"""
|
|
11
|
+
Temporary document is a document that is not persisted in any case.
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
def get_base_type(self) -> Literal["temporary"]:
|
|
15
|
+
"""Get the document type."""
|
|
16
|
+
return "temporary"
|
ai_pipeline_core/flow/config.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
"""Flow configuration base class."""
|
|
2
2
|
|
|
3
3
|
from abc import ABC
|
|
4
|
-
from typing import ClassVar
|
|
4
|
+
from typing import Any, ClassVar
|
|
5
5
|
|
|
6
6
|
from ai_pipeline_core.documents import DocumentList, FlowDocument
|
|
7
7
|
|
|
@@ -14,6 +14,27 @@ class FlowConfig(ABC):
|
|
|
14
14
|
INPUT_DOCUMENT_TYPES: ClassVar[list[type[FlowDocument]]]
|
|
15
15
|
OUTPUT_DOCUMENT_TYPE: ClassVar[type[FlowDocument]]
|
|
16
16
|
|
|
17
|
+
def __init_subclass__(cls, **kwargs: Any):
|
|
18
|
+
"""Validate that OUTPUT_DOCUMENT_TYPE is not in INPUT_DOCUMENT_TYPES."""
|
|
19
|
+
super().__init_subclass__(**kwargs)
|
|
20
|
+
|
|
21
|
+
# Skip validation for the abstract base class itself
|
|
22
|
+
if cls.__name__ == "FlowConfig":
|
|
23
|
+
return
|
|
24
|
+
|
|
25
|
+
# Ensure required attributes are defined
|
|
26
|
+
if not hasattr(cls, "INPUT_DOCUMENT_TYPES"):
|
|
27
|
+
raise TypeError(f"FlowConfig {cls.__name__} must define INPUT_DOCUMENT_TYPES")
|
|
28
|
+
if not hasattr(cls, "OUTPUT_DOCUMENT_TYPE"):
|
|
29
|
+
raise TypeError(f"FlowConfig {cls.__name__} must define OUTPUT_DOCUMENT_TYPE")
|
|
30
|
+
|
|
31
|
+
# Validate that output type is not in input types
|
|
32
|
+
if cls.OUTPUT_DOCUMENT_TYPE in cls.INPUT_DOCUMENT_TYPES:
|
|
33
|
+
raise TypeError(
|
|
34
|
+
f"FlowConfig {cls.__name__}: OUTPUT_DOCUMENT_TYPE "
|
|
35
|
+
f"({cls.OUTPUT_DOCUMENT_TYPE.__name__}) cannot be in INPUT_DOCUMENT_TYPES"
|
|
36
|
+
)
|
|
37
|
+
|
|
17
38
|
@classmethod
|
|
18
39
|
def get_input_document_types(cls) -> list[type[FlowDocument]]:
|
|
19
40
|
"""
|
|
@@ -64,3 +85,21 @@ class FlowConfig(ABC):
|
|
|
64
85
|
"Documents must be of the correct type. "
|
|
65
86
|
f"Expected: {output_document_class.__name__}, Got invalid: {invalid}"
|
|
66
87
|
)
|
|
88
|
+
|
|
89
|
+
@classmethod
|
|
90
|
+
def create_and_validate_output(
|
|
91
|
+
cls, output: FlowDocument | list[FlowDocument] | DocumentList
|
|
92
|
+
) -> DocumentList:
|
|
93
|
+
"""
|
|
94
|
+
Create the output documents for the flow.
|
|
95
|
+
"""
|
|
96
|
+
documents: DocumentList
|
|
97
|
+
if isinstance(output, FlowDocument):
|
|
98
|
+
documents = DocumentList([output])
|
|
99
|
+
elif isinstance(output, DocumentList):
|
|
100
|
+
documents = output
|
|
101
|
+
else:
|
|
102
|
+
assert isinstance(output, list)
|
|
103
|
+
documents = DocumentList(output) # type: ignore[arg-type]
|
|
104
|
+
cls.validate_output_documents(documents)
|
|
105
|
+
return documents
|
|
@@ -4,6 +4,7 @@ from pydantic import BaseModel
|
|
|
4
4
|
|
|
5
5
|
|
|
6
6
|
class ModelOptions(BaseModel):
|
|
7
|
+
temperature: float | None = None
|
|
7
8
|
system_prompt: str | None = None
|
|
8
9
|
search_context_size: Literal["low", "medium", "high"] | None = None
|
|
9
10
|
reasoning_effort: Literal["low", "medium", "high"] | None = None
|
|
@@ -21,6 +22,9 @@ class ModelOptions(BaseModel):
|
|
|
21
22
|
"extra_body": {},
|
|
22
23
|
}
|
|
23
24
|
|
|
25
|
+
if self.temperature:
|
|
26
|
+
kwargs["temperature"] = self.temperature
|
|
27
|
+
|
|
24
28
|
if self.max_completion_tokens:
|
|
25
29
|
kwargs["max_completion_tokens"] = self.max_completion_tokens
|
|
26
30
|
|