ai-pipeline-core 0.1.14__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. ai_pipeline_core/__init__.py +21 -13
  2. ai_pipeline_core/documents/document.py +202 -51
  3. ai_pipeline_core/documents/document_list.py +148 -24
  4. ai_pipeline_core/documents/flow_document.py +2 -6
  5. ai_pipeline_core/documents/task_document.py +0 -4
  6. ai_pipeline_core/documents/temporary_document.py +1 -8
  7. ai_pipeline_core/flow/config.py +174 -5
  8. ai_pipeline_core/llm/__init__.py +1 -6
  9. ai_pipeline_core/llm/ai_messages.py +137 -4
  10. ai_pipeline_core/llm/client.py +118 -65
  11. ai_pipeline_core/llm/model_options.py +6 -7
  12. ai_pipeline_core/llm/model_response.py +17 -16
  13. ai_pipeline_core/llm/model_types.py +3 -7
  14. ai_pipeline_core/logging/__init__.py +0 -2
  15. ai_pipeline_core/logging/logging_config.py +0 -6
  16. ai_pipeline_core/logging/logging_mixin.py +2 -10
  17. ai_pipeline_core/pipeline.py +54 -68
  18. ai_pipeline_core/prefect.py +12 -3
  19. ai_pipeline_core/prompt_manager.py +14 -7
  20. ai_pipeline_core/settings.py +13 -5
  21. ai_pipeline_core/simple_runner/__init__.py +1 -11
  22. ai_pipeline_core/simple_runner/cli.py +13 -12
  23. ai_pipeline_core/simple_runner/simple_runner.py +34 -189
  24. ai_pipeline_core/storage/__init__.py +8 -0
  25. ai_pipeline_core/storage/storage.py +628 -0
  26. ai_pipeline_core/tracing.py +234 -30
  27. {ai_pipeline_core-0.1.14.dist-info → ai_pipeline_core-0.2.1.dist-info}/METADATA +35 -20
  28. ai_pipeline_core-0.2.1.dist-info/RECORD +38 -0
  29. ai_pipeline_core-0.1.14.dist-info/RECORD +0 -36
  30. {ai_pipeline_core-0.1.14.dist-info → ai_pipeline_core-0.2.1.dist-info}/WHEEL +0 -0
  31. {ai_pipeline_core-0.1.14.dist-info → ai_pipeline_core-0.2.1.dist-info}/licenses/LICENSE +0 -0
ai_pipeline_core/logging/__init__.py
@@ -1,7 +1,5 @@
 """Logging infrastructure for AI Pipeline Core.
 
-@public
-
 Provides a Prefect-integrated logging facade for unified logging across pipelines.
 Prefer get_pipeline_logger instead of logging.getLogger to ensure proper integration.
 
ai_pipeline_core/logging/logging_config.py
@@ -1,7 +1,5 @@
 """Centralized logging configuration for AI Pipeline Core.
 
-@public
-
 Provides logging configuration management that integrates with Prefect's logging system.
 """
 
@@ -26,8 +24,6 @@ DEFAULT_LOG_LEVELS = {
 class LoggingConfig:
     """Manages logging configuration for the pipeline.
 
-    @public
-
     Provides centralized logging configuration with Prefect integration.
 
     Configuration precedence:
@@ -144,8 +140,6 @@ _logging_config: Optional[LoggingConfig] = None
 def setup_logging(config_path: Optional[Path] = None, level: Optional[str] = None):
     """Setup logging for the AI Pipeline Core library.
 
-    @public
-
     Initializes logging configuration for the pipeline system.
 
     IMPORTANT: Call setup_logging exactly once in your application entry point
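These hunks only remove the `@public` markers; usage is unchanged: configure once, then fetch loggers through the facade. A minimal sketch (the import path is taken from the cli.py hunk below; the level value is an assumption):

```python
from ai_pipeline_core.logging import get_pipeline_logger, setup_logging

setup_logging(level="INFO")  # call exactly once, at the application entry point

logger = get_pipeline_logger(__name__)  # preferred over logging.getLogger
logger.info("pipeline starting")
```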
ai_pipeline_core/logging/logging_mixin.py
@@ -1,7 +1,4 @@
-"""Logging mixin for consistent logging across components using Prefect logging.
-
-@public
-"""
+"""Logging mixin for consistent logging across components using Prefect logging."""
 
 import contextlib
 import time
@@ -17,8 +14,6 @@ from prefect.logging import get_logger
 class LoggerMixin:
     """Mixin class that provides consistent logging functionality using Prefect's logging system.
 
-    @public
-
     Note for users: In your code, always obtain loggers via get_pipeline_logger(__name__).
     The mixin's internal behavior routes to the appropriate backend; you should not call
     logging.getLogger directly.
@@ -94,10 +89,7 @@ class LoggerMixin:
 
 
 class StructuredLoggerMixin(LoggerMixin):
-    """Extended mixin for structured logging with Prefect.
-
-    @public
-    """
+    """Extended mixin for structured logging with Prefect."""
 
     def log_event(self, event: str, **kwargs: Any) -> None:
         """Log a structured event.
ai_pipeline_core/pipeline.py
@@ -36,6 +36,7 @@ from prefect.utilities.annotations import NotSet
 from typing_extensions import TypeAlias
 
 from ai_pipeline_core.documents import DocumentList
+from ai_pipeline_core.flow.config import FlowConfig
 from ai_pipeline_core.flow.options import FlowOptions
 from ai_pipeline_core.tracing import TraceLevel, set_trace_cost, trace
 
@@ -100,7 +101,6 @@ class _DocumentsFlowCallable(Protocol[FO_contra]):
        project_name: Name of the project/pipeline.
        documents: Input DocumentList to process.
        flow_options: Configuration options (FlowOptions or subclass).
-       *args, **kwargs: Additional flow-specific parameters.
 
     Returns:
        DocumentList: Processed documents.
@@ -114,8 +114,6 @@ class _DocumentsFlowCallable(Protocol[FO_contra]):
         project_name: str,
         documents: DocumentList,
         flow_options: FO_contra,
-        *args: Any,
-        **kwargs: Any,
     ) -> Coroutine[Any, Any, DocumentList]: ...
 
 
@@ -146,8 +144,6 @@ class _FlowLike(Protocol[FO_contra]):
         project_name: str,
         documents: DocumentList,
         flow_options: FO_contra,
-        *args: Any,
-        **kwargs: Any,
     ) -> Coroutine[Any, Any, DocumentList]: ...
 
     name: str | None
@@ -226,6 +222,7 @@ def pipeline_task(
     trace_input_formatter: Callable[..., str] | None = None,
     trace_output_formatter: Callable[..., str] | None = None,
     trace_cost: float | None = None,
+    trace_trim_documents: bool = True,
     # prefect passthrough
     name: str | None = None,
     description: str | None = None,
@@ -266,6 +263,7 @@ def pipeline_task(
     trace_input_formatter: Callable[..., str] | None = None,
     trace_output_formatter: Callable[..., str] | None = None,
     trace_cost: float | None = None,
+    trace_trim_documents: bool = True,
     # prefect passthrough
     name: str | None = None,
     description: str | None = None,
@@ -322,6 +320,8 @@ def pipeline_task(
        trace_cost: Optional cost value to track in metadata. When provided and > 0,
            sets gen_ai.usage.output_cost, gen_ai.usage.cost, and cost metadata.
            Also forces trace level to "always" if not already set.
+       trace_trim_documents: Trim document content in traces to first 100 chars (default True).
+           Reduces trace size with large documents.
 
     Prefect task parameters:
        name: Task name (defaults to function name).
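For context, a hedged sketch of the new `trace_trim_documents` knob on a task; the decorator and `DocumentList` follow the library's documented usage, while the task body and top-level import path are assumptions:

```python
from ai_pipeline_core import pipeline_task
from ai_pipeline_core.documents import DocumentList

# Default (True) trims each document to its first 100 chars in traces;
# pass False only when full document content is worth the trace size.
@pipeline_task(trace_trim_documents=False)
async def passthrough(documents: DocumentList) -> DocumentList:
    return documents
```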
@@ -420,10 +420,6 @@ def pipeline_task(
                set_trace_cost(trace_cost)
            return result
 
-        # Preserve the original function name for Prefect
-        _wrapper.__name__ = fname
-        _wrapper.__qualname__ = getattr(fn, "__qualname__", fname)
-
        traced_fn = trace(
            level=trace_level,
            name=name or fname,
@@ -432,6 +428,7 @@ def pipeline_task(
            ignore_inputs=trace_ignore_inputs,
            input_formatter=trace_input_formatter,
            output_formatter=trace_output_formatter,
+           trim_documents=trace_trim_documents,
        )(_wrapper)
 
        return cast(
@@ -470,11 +467,10 @@ def pipeline_task(
 # --------------------------------------------------------------------------- #
 # @pipeline_flow — async-only, traced, returns Prefect's flow wrapper
 # --------------------------------------------------------------------------- #
-@overload
-def pipeline_flow(__fn: _DocumentsFlowCallable[FO_contra], /) -> _FlowLike[FO_contra]: ...
-@overload
 def pipeline_flow(
     *,
+    # config
+    config: type[FlowConfig],
     # tracing
     trace_level: TraceLevel = "always",
     trace_ignore_input: bool = False,
@@ -483,6 +479,7 @@ def pipeline_flow(
     trace_input_formatter: Callable[..., str] | None = None,
     trace_output_formatter: Callable[..., str] | None = None,
     trace_cost: float | None = None,
+    trace_trim_documents: bool = True,
     # prefect passthrough
     name: str | None = None,
     version: str | None = None,
@@ -503,42 +500,7 @@ def pipeline_flow(
     on_cancellation: list[FlowStateHook[Any, Any]] | None = None,
     on_crashed: list[FlowStateHook[Any, Any]] | None = None,
     on_running: list[FlowStateHook[Any, Any]] | None = None,
-) -> Callable[[_DocumentsFlowCallable[FO_contra]], _FlowLike[FO_contra]]: ...
-
-
-def pipeline_flow(
-    __fn: _DocumentsFlowCallable[FO_contra] | None = None,
-    /,
-    *,
-    # tracing
-    trace_level: TraceLevel = "always",
-    trace_ignore_input: bool = False,
-    trace_ignore_output: bool = False,
-    trace_ignore_inputs: list[str] | None = None,
-    trace_input_formatter: Callable[..., str] | None = None,
-    trace_output_formatter: Callable[..., str] | None = None,
-    trace_cost: float | None = None,
-    # prefect passthrough
-    name: str | None = None,
-    version: str | None = None,
-    flow_run_name: Union[Callable[[], str], str] | None = None,
-    retries: int | None = None,
-    retry_delay_seconds: int | float | None = None,
-    task_runner: TaskRunner[PrefectFuture[Any]] | None = None,
-    description: str | None = None,
-    timeout_seconds: int | float | None = None,
-    validate_parameters: bool = True,
-    persist_result: bool | None = None,
-    result_storage: ResultStorage | str | None = None,
-    result_serializer: ResultSerializer | str | None = None,
-    cache_result_in_memory: bool = True,
-    log_prints: bool | None = None,
-    on_completion: list[FlowStateHook[Any, Any]] | None = None,
-    on_failure: list[FlowStateHook[Any, Any]] | None = None,
-    on_cancellation: list[FlowStateHook[Any, Any]] | None = None,
-    on_crashed: list[FlowStateHook[Any, Any]] | None = None,
-    on_running: list[FlowStateHook[Any, Any]] | None = None,
-) -> _FlowLike[FO_contra] | Callable[[_DocumentsFlowCallable[FO_contra]], _FlowLike[FO_contra]]:
+) -> Callable[[_DocumentsFlowCallable[FO_contra]], _FlowLike[FO_contra]]:
     """Decorate an async flow for document processing.
 
     @public
@@ -558,16 +520,15 @@ def pipeline_flow(
        project_name: str,  # Project/pipeline identifier
        documents: DocumentList,  # Input documents to process
        flow_options: FlowOptions,  # Configuration (or subclass)
-       *args,  # Additional positional args for custom parameters
-       **kwargs  # Additional keyword args for custom parameters
     ) -> DocumentList  # Must return DocumentList
 
-    Note: *args and **kwargs allow for defining custom parameters on your flow
-    function, which can be passed during execution for flow-specific needs.
-
     Args:
        __fn: Function to decorate (when used without parentheses).
 
+    Config parameter:
+       config: Required FlowConfig class for document loading/saving. Enables
+           automatic loading from string paths and saving outputs.
+
     Tracing parameters:
        trace_level: When to trace ("always", "debug", "off").
            - "always": Always trace (default)
@@ -581,6 +542,8 @@ def pipeline_flow(
        trace_cost: Optional cost value to track in metadata. When provided and > 0,
            sets gen_ai.usage.output_cost, gen_ai.usage.cost, and cost metadata.
            Also forces trace level to "always" if not already set.
+       trace_trim_documents: Trim document content in traces to first 100 chars (default True).
+           Reduces trace size with large documents.
 
     Prefect flow parameters:
        name: Flow name (defaults to function name).
@@ -608,10 +571,14 @@ def pipeline_flow(
     while enforcing document processing conventions.
 
     Example:
-        >>> from ai_pipeline_core import FlowOptions
+        >>> from ai_pipeline_core import FlowOptions, FlowConfig
        >>>
-        >>> # RECOMMENDED - No parameters needed!
-        >>> @pipeline_flow
+        >>> class MyFlowConfig(FlowConfig):
+        ...     INPUT_DOCUMENT_TYPES = [InputDoc]
+        ...     OUTPUT_DOCUMENT_TYPE = OutputDoc
+        >>>
+        >>> # Standard usage with config
+        >>> @pipeline_flow(config=MyFlowConfig)
        >>> async def analyze_documents(
        ...     project_name: str,
        ...     documents: DocumentList,
@@ -624,8 +591,8 @@ def pipeline_flow(
        ...     results.append(result)
        ...     return DocumentList(results)
        >>>
-        >>> # With parameters (only when necessary):
-        >>> @pipeline_flow(retries=2)  # Only for flows that need retry logic
+        >>> # With additional parameters:
+        >>> @pipeline_flow(config=MyFlowConfig, retries=2)
        >>> async def critical_flow(
        ...     project_name: str,
        ...     documents: DocumentList,
@@ -682,14 +649,19 @@ def pipeline_flow(
                "'project_name, documents, flow_options' as its first three parameters"
            )
 
+        @wraps(fn)
        async def _wrapper(
            project_name: str,
-           documents: DocumentList,
+           documents: str | DocumentList,
            flow_options: FO_contra,
-           *args: Any,
-           **kwargs: Any,
        ) -> DocumentList:
-           result = await fn(project_name, documents, flow_options, *args, **kwargs)
+           save_path: str | None = None
+           if isinstance(documents, str):
+               save_path = documents
+               documents = await config.load_documents(documents)
+           result = await fn(project_name, documents, flow_options)
+           if save_path:
+               await config.save_documents(save_path, result)
            if trace_cost is not None and trace_cost > 0:
                set_trace_cost(trace_cost)
            if not isinstance(result, DocumentList):  # pyright: ignore[reportUnnecessaryIsInstance]
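The wrapper above changes the calling convention: `documents` may now be a string path, in which case inputs are loaded through `config.load_documents()` and results written back via `config.save_documents()`. A hedged sketch reusing `analyze_documents` and `MyFlowConfig` from the docstring example; the path and options values are illustrative:

```python
import asyncio

from ai_pipeline_core import FlowOptions

async def main() -> None:
    # A DocumentList argument behaves as before; a str triggers load-then-save:
    # inputs are read from, and outputs persisted to, "./work" by MyFlowConfig.
    await analyze_documents("demo-project", "./work", FlowOptions())

asyncio.run(main())
```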
@@ -698,10 +670,6 @@ def pipeline_flow(
            )
            return result
 
-        # Preserve the original function name for Prefect
-        _wrapper.__name__ = fname
-        _wrapper.__qualname__ = getattr(fn, "__qualname__", fname)
-
        traced = trace(
            level=trace_level,
            name=name or fname,
@@ -710,9 +678,24 @@ def pipeline_flow(
            ignore_inputs=trace_ignore_inputs,
            input_formatter=trace_input_formatter,
            output_formatter=trace_output_formatter,
+           trim_documents=trace_trim_documents,
        )(_wrapper)
 
-        return cast(
+        # --- Publish a schema where `documents` accepts str (path) OR DocumentList ---
+        _sig = inspect.signature(fn)
+        _params = [
+            p.replace(annotation=(str | DocumentList)) if p.name == "documents" else p
+            for p in _sig.parameters.values()
+        ]
+        if hasattr(traced, "__signature__"):
+            setattr(traced, "__signature__", _sig.replace(parameters=_params))
+        if hasattr(traced, "__annotations__"):
+            traced.__annotations__ = {
+                **getattr(traced, "__annotations__", {}),
+                "documents": str | DocumentList,
+            }
+
+        flow_obj = cast(
            _FlowLike[FO_contra],
            flow_decorator(
                name=name or fname,
@@ -736,8 +719,11 @@ def pipeline_flow(
                on_running=on_running,
            )(traced),
        )
+        # Attach config to the flow object for later access
+        flow_obj.config = config  # type: ignore[attr-defined]
+        return flow_obj
 
-    return _apply(__fn) if __fn else _apply
+    return _apply
 
 
 __all__ = ["pipeline_task", "pipeline_flow"]
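Because `_apply` now attaches the config class to the returned flow object, downstream code can discover it by attribute; the cli.py hunk below relies on exactly this. A one-line sketch reusing the docstring example's names:

```python
assert analyze_documents.config is MyFlowConfig  # set by @pipeline_flow(config=...)
```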
ai_pipeline_core/prefect.py
@@ -47,8 +47,17 @@ Note:
     integrated LMNR tracing and are the standard for this library.
 """
 
-from prefect import flow, task
+from prefect import deploy, flow, serve, task
 from prefect.logging import disable_run_logger
 from prefect.testing.utilities import prefect_test_harness
-
-__all__ = ["task", "flow", "disable_run_logger", "prefect_test_harness"]
+from prefect.types.entrypoint import EntrypointType
+
+__all__ = [
+    "task",
+    "flow",
+    "disable_run_logger",
+    "prefect_test_harness",
+    "serve",
+    "deploy",
+    "EntrypointType",
+]
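With `serve` and `deploy` re-exported, a flow can be served without importing Prefect directly. A hedged sketch using standard Prefect deployment API (`to_deployment`); the flow itself is illustrative:

```python
from ai_pipeline_core.prefect import flow, serve

@flow
def heartbeat() -> str:
    return "ok"

if __name__ == "__main__":
    # serve() blocks and executes runs for this deployment
    serve(heartbeat.to_deployment(name="heartbeat-local"))
```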
ai_pipeline_core/prompt_manager.py
@@ -10,13 +10,16 @@ directories.
 Search strategy:
 1. Local directory (same as calling module)
 2. Local 'prompts' subdirectory
-3. Parent 'prompts' directories (up to package boundary)
+3. Parent 'prompts' directories (search ascends parent packages up to the package
+   boundary or after 4 parent levels, whichever comes first)
 
 Key features:
 - Automatic template discovery
 - Jinja2 template rendering with context
 - Smart path resolution (.jinja2/.jinja extension handling)
 - Clear error messages for missing templates
+- Built-in global variables:
+  - current_date: Current date in format "03 January 2025" (string)
 
 Example:
     >>> from ai_pipeline_core import PromptManager
@@ -44,6 +47,7 @@ Note:
     The extension can be omitted when calling get().
 """
 
+from datetime import datetime
 from pathlib import Path
 from typing import Any
 
@@ -69,7 +73,8 @@ class PromptManager:
     Search hierarchy:
     1. Same directory as the calling module (for local templates)
     2. 'prompts' subdirectory in the calling module's directory
-    3. 'prompts' directories in parent packages (up to package boundary)
+    3. 'prompts' directories in parent packages (search ascends parent packages up to the
+       package boundary or after 4 parent levels, whichever comes first)
 
     Attributes:
        search_paths: List of directories where templates are searched.
@@ -101,6 +106,8 @@ class PromptManager:
        {% if instructions %}
        Instructions: {{ instructions }}
        {% endif %}
+
+       Date: {{ current_date }}  # Current date in format "03 January 2025"
        ```
 
     Note:
@@ -144,7 +151,8 @@ class PromptManager:
           2. /project/flows/prompts/ (if exists)
           3. /project/prompts/ (if /project has __init__.py)
 
-        Search stops when no __init__.py is found (package boundary).
+        Search ascends parent packages up to the package boundary or after 4 parent
+        levels, whichever comes first.
 
        Example:
           >>> # Correct usage
@@ -155,10 +163,6 @@ class PromptManager:
           >>>
           >>> # Common mistake (will raise PromptError)
           >>> pm = PromptManager(__name__)  # Wrong!
-
-        Note:
-            The search is limited to 4 parent levels to prevent
-            excessive filesystem traversal.
        """
        search_paths: list[Path] = []
 
@@ -215,6 +219,9 @@ class PromptManager:
            autoescape=False,  # Important for prompt engineering
        )
 
+        # Add current_date as a global string (format: "03 January 2025")
+        self.env.globals["current_date"] = datetime.now().strftime("%d %B %Y")  # type: ignore[assignment]
+
    def get(self, prompt_path: str, **kwargs: Any) -> str:
        """Load and render a Jinja2 template with the given context.
 
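A short sketch of the new `current_date` global; the template name and context variable are illustrative, and `PromptManager(__file__)` follows the documented correct usage:

```python
from ai_pipeline_core import PromptManager

pm = PromptManager(__file__)
# prompts/summary.jinja2 could contain:
#   Today is {{ current_date }}. Summarize: {{ topic }}
rendered = pm.get("summary", topic="quarterly results")
# current_date renders like "03 January 2025" without being passed in
```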
ai_pipeline_core/settings.py
@@ -12,6 +12,7 @@ Environment variables:
    PREFECT_API_URL: Prefect server endpoint for flow orchestration
    PREFECT_API_KEY: Prefect API authentication key
    LMNR_PROJECT_API_KEY: Laminar project key for observability
+   GCS_SERVICE_ACCOUNT_FILE: Path to GCS service account JSON file
 
 Configuration precedence:
    1. Environment variables (highest priority)
@@ -39,6 +40,7 @@ Example:
    PREFECT_API_URL=http://localhost:4200/api
    PREFECT_API_KEY=pnu_abc123
    LMNR_PROJECT_API_KEY=lmnr_proj_xyz
+   GCS_SERVICE_ACCOUNT_FILE=/path/to/service-account.json
    APP_NAME=production-app
    DEBUG_MODE=false
 
@@ -90,12 +92,15 @@ class Settings(BaseSettings):
        prefect_api_key: Prefect API authentication key. Required only
            when connecting to Prefect Cloud or secured server.
 
-        lmnr_project_api_key: Laminar (LMNR) project API key for tracing
-                              and observability. Optional but recommended
-                              for production monitoring.
+        lmnr_project_api_key: Laminar (LMNR) project API key for observability.
+            Optional but recommended for production monitoring.
 
-        lmnr_debug: Debug mode flag for Laminar tracing. Set to "true" to
-                    enable debug-level traces. Empty string by default.
+        lmnr_debug: Debug mode flag for Laminar. Set to "true" to
+            enable debug-level logging. Empty string by default.
+
+        gcs_service_account_file: Path to GCS service account JSON file.
+            Used for authenticating with Google Cloud Storage.
+            Optional - if not set, default credentials will be used.
 
     Configuration sources:
        - Environment variables (highest priority)
@@ -126,6 +131,9 @@ class Settings(BaseSettings):
    lmnr_project_api_key: str = ""
    lmnr_debug: str = ""
 
+    # Storage Configuration
+    gcs_service_account_file: str = ""  # Path to GCS service account JSON file
+
 
 # Legacy: Module-level instance for backwards compatibility
 # Applications should create their own settings instance
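A hedged sketch of consuming the new setting; the explicit fallback mirrors the docstring ("if not set, default credentials will be used"), and google-cloud-storage is an assumed dependency of the new storage module:

```python
from google.cloud import storage

from ai_pipeline_core.settings import settings

if settings.gcs_service_account_file:
    # Authenticate with the configured service-account JSON file
    client = storage.Client.from_service_account_json(settings.gcs_service_account_file)
else:
    client = storage.Client()  # application default credentials
```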
ai_pipeline_core/simple_runner/__init__.py
@@ -4,21 +4,11 @@ Utilities for running AI pipelines locally without full Prefect orchestration.
 """
 
 from .cli import run_cli
-from .simple_runner import (
-    ConfigSequence,
-    FlowSequence,
-    load_documents_from_directory,
-    run_pipeline,
-    run_pipelines,
-    save_documents_to_directory,
-)
+from .simple_runner import FlowSequence, run_pipeline, run_pipelines
 
 __all__ = [
     "run_cli",
     "run_pipeline",
     "run_pipelines",
-    "load_documents_from_directory",
-    "save_documents_to_directory",
     "FlowSequence",
-    "ConfigSequence",
 ]
ai_pipeline_core/simple_runner/cli.py
@@ -19,7 +19,7 @@ from ai_pipeline_core.logging import get_pipeline_logger, setup_logging
 from ai_pipeline_core.prefect import disable_run_logger, prefect_test_harness
 from ai_pipeline_core.settings import settings
 
-from .simple_runner import ConfigSequence, FlowSequence, run_pipelines, save_documents_to_directory
+from .simple_runner import FlowSequence, run_pipelines
 
 logger = get_pipeline_logger(__name__)
 
@@ -87,7 +87,6 @@ def _running_under_pytest() -> bool:
 def run_cli(
     *,
     flows: FlowSequence,
-    flow_configs: ConfigSequence,
     options_cls: Type[TOptions],
     initializer: InitializerFunc = None,
     trace_name: str | None = None,
@@ -105,17 +104,13 @@
 
     Example:
        >>> # In __main__.py
-        >>> from ai_pipeline_core.simple_runner import run_cli
+        >>> from ai_pipeline_core import simple_runner
        >>> from .flows import AnalysisFlow, SummaryFlow
-        >>> from .config import AnalysisConfig, AnalysisOptions
+        >>> from .config import AnalysisOptions
        >>>
        >>> if __name__ == "__main__":
-        ...     run_cli(
+        ...     simple_runner.run_cli(
        ...         flows=[AnalysisFlow, SummaryFlow],
-        ...         flow_configs=[
-        ...             (AnalysisConfig, AnalysisOptions),
-        ...             (AnalysisConfig, AnalysisOptions)
-        ...         ],
        ...         options_cls=AnalysisOptions,
        ...         trace_name="document-analysis"
        ...     )
@@ -226,8 +221,15 @@ def run_cli(
    _, initial_documents = init_result  # Ignore project name from initializer
 
    # Save initial documents if starting from first step
-    if getattr(opts, "start", 1) == 1 and initial_documents:
-        save_documents_to_directory(wd, initial_documents)
+    if getattr(opts, "start", 1) == 1 and initial_documents and flows:
+        # Get config from the first flow
+        first_flow_config = getattr(flows[0], "config", None)
+        if first_flow_config:
+            asyncio.run(
+                first_flow_config.save_documents(
+                    str(wd), initial_documents, validate_output_type=False
+                )
+            )
 
    # Setup context stack with optional test harness and tracing
    with ExitStack() as stack:
@@ -247,7 +249,6 @@
            project_name=project_name,
            output_dir=wd,
            flows=flows,
-            flow_configs=flow_configs,
            flow_options=opts,
            start_step=getattr(opts, "start", 1),
            end_step=getattr(opts, "end", None),