ai-pipeline-core 0.2.8__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. ai_pipeline_core/__init__.py +14 -4
  2. ai_pipeline_core/deployment/__init__.py +46 -0
  3. ai_pipeline_core/deployment/base.py +681 -0
  4. ai_pipeline_core/deployment/contract.py +84 -0
  5. ai_pipeline_core/deployment/helpers.py +98 -0
  6. ai_pipeline_core/documents/flow_document.py +1 -1
  7. ai_pipeline_core/documents/task_document.py +1 -1
  8. ai_pipeline_core/documents/temporary_document.py +1 -1
  9. ai_pipeline_core/flow/config.py +13 -2
  10. ai_pipeline_core/flow/options.py +1 -1
  11. ai_pipeline_core/llm/client.py +22 -23
  12. ai_pipeline_core/llm/model_response.py +6 -3
  13. ai_pipeline_core/llm/model_types.py +0 -1
  14. ai_pipeline_core/pipeline.py +1 -1
  15. ai_pipeline_core/progress.py +127 -0
  16. ai_pipeline_core/prompt_builder/__init__.py +5 -0
  17. ai_pipeline_core/prompt_builder/documents_prompt.jinja2 +23 -0
  18. ai_pipeline_core/prompt_builder/global_cache.py +78 -0
  19. ai_pipeline_core/prompt_builder/new_core_documents_prompt.jinja2 +6 -0
  20. ai_pipeline_core/prompt_builder/prompt_builder.py +253 -0
  21. ai_pipeline_core/prompt_builder/system_prompt.jinja2 +41 -0
  22. ai_pipeline_core/tracing.py +1 -1
  23. ai_pipeline_core/utils/remote_deployment.py +37 -187
  24. {ai_pipeline_core-0.2.8.dist-info → ai_pipeline_core-0.3.0.dist-info}/METADATA +23 -20
  25. ai_pipeline_core-0.3.0.dist-info/RECORD +49 -0
  26. {ai_pipeline_core-0.2.8.dist-info → ai_pipeline_core-0.3.0.dist-info}/WHEEL +1 -1
  27. ai_pipeline_core/simple_runner/__init__.py +0 -14
  28. ai_pipeline_core/simple_runner/cli.py +0 -254
  29. ai_pipeline_core/simple_runner/simple_runner.py +0 -247
  30. ai_pipeline_core-0.2.8.dist-info/RECORD +0 -41
  31. {ai_pipeline_core-0.2.8.dist-info → ai_pipeline_core-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,84 @@
+ """Unified pipeline run response contract.
+
+ @public
+
+ Single source of truth for the response shape used by both
+ webhook push (ai-pipeline-core) and polling pull (unified-middleware).
+ """
+
+ from datetime import datetime
+ from typing import Annotated, Literal
+ from uuid import UUID
+
+ from pydantic import BaseModel, ConfigDict, Discriminator
+
+
+ class _RunBase(BaseModel):
+     """Common fields on every run response variant."""
+
+     type: str
+     flow_run_id: UUID
+     project_name: str
+     state: str  # PENDING, RUNNING, COMPLETED, FAILED, CRASHED, CANCELLED
+     timestamp: datetime
+     storage_uri: str = ""
+
+     model_config = ConfigDict(frozen=True)
+
+
+ class PendingRun(_RunBase):
+     """Pipeline queued or running but no progress reported yet."""
+
+     type: Literal["pending"] = "pending"  # pyright: ignore[reportIncompatibleVariableOverride]
+
+
+ class ProgressRun(_RunBase):
+     """Pipeline running with step-level progress data."""
+
+     type: Literal["progress"] = "progress"  # pyright: ignore[reportIncompatibleVariableOverride]
+     step: int
+     total_steps: int
+     flow_name: str
+     status: str  # "started", "completed", "cached"
+     progress: float  # overall 0.0–1.0
+     step_progress: float  # within step 0.0–1.0
+     message: str
+
+
+ class DeploymentResultData(BaseModel):
+     """Typed result payload — always has success + optional error."""
+
+     success: bool
+     error: str | None = None
+
+     model_config = ConfigDict(frozen=True, extra="allow")
+
+
+ class CompletedRun(_RunBase):
+     """Pipeline finished (Prefect COMPLETED). Check result.success for business outcome."""
+
+     type: Literal["completed"] = "completed"  # pyright: ignore[reportIncompatibleVariableOverride]
+     result: DeploymentResultData
+
+
+ class FailedRun(_RunBase):
+     """Pipeline crashed — execution error, not business logic."""
+
+     type: Literal["failed"] = "failed"  # pyright: ignore[reportIncompatibleVariableOverride]
+     error: str
+     result: DeploymentResultData | None = None
+
+
+ RunResponse = Annotated[
+     PendingRun | ProgressRun | CompletedRun | FailedRun,
+     Discriminator("type"),
+ ]
+
+ __all__ = [
+     "CompletedRun",
+     "DeploymentResultData",
+     "FailedRun",
+     "PendingRun",
+     "ProgressRun",
+     "RunResponse",
+ ]
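The RunResponse union above is discriminated on the "type" field, so a consumer can validate a webhook or polling body in a single step with pydantic's TypeAdapter. A minimal sketch; the payload values are invented for illustration:

    from pydantic import TypeAdapter

    from ai_pipeline_core.deployment.contract import CompletedRun, RunResponse

    adapter = TypeAdapter(RunResponse)

    payload = {  # hypothetical webhook body
        "type": "completed",
        "flow_run_id": "123e4567-e89b-12d3-a456-426614174000",
        "project_name": "demo-project",
        "state": "COMPLETED",
        "timestamp": "2025-01-01T00:00:00Z",
        "result": {"success": True},
    }

    run = adapter.validate_python(payload)  # the "type" discriminator selects CompletedRun
    assert isinstance(run, CompletedRun)
    assert run.result.success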
@@ -0,0 +1,98 @@
+ """Helper functions for pipeline deployments."""
+
+ import asyncio
+ import re
+ from typing import Any, Literal, TypedDict
+
+ import httpx
+
+ from ai_pipeline_core.deployment.contract import CompletedRun, FailedRun, ProgressRun
+ from ai_pipeline_core.documents import Document, DocumentList, FlowDocument
+ from ai_pipeline_core.logging import get_pipeline_logger
+
+ logger = get_pipeline_logger(__name__)
+
+
+ class StatusPayload(TypedDict):
+     """Webhook payload for Prefect state transitions (sub-flow level)."""
+
+     type: Literal["status"]
+     flow_run_id: str
+     project_name: str
+     step: int
+     total_steps: int
+     flow_name: str
+     state: str  # RUNNING, COMPLETED, FAILED, CRASHED, CANCELLED
+     state_name: str
+     timestamp: str
+
+
+ def class_name_to_deployment_name(class_name: str) -> str:
+     """Convert PascalCase to kebab-case: ResearchPipeline → research-pipeline."""
+     name = re.sub(r"(?<!^)(?=[A-Z])", "-", class_name)
+     return name.lower()
+
+
+ def extract_generic_params(cls: type) -> tuple[type | None, type | None]:
+     """Extract TOptions and TResult from PipelineDeployment generic args."""
+     from ai_pipeline_core.deployment.base import PipelineDeployment  # noqa: PLC0415
+
+     for base in getattr(cls, "__orig_bases__", []):
+         origin = getattr(base, "__origin__", None)
+         if origin is PipelineDeployment:
+             args = getattr(base, "__args__", ())
+             if len(args) == 2:
+                 return args[0], args[1]
+
+     return None, None
+
+
+ async def download_documents(
+     urls: list[str],
+     document_type: type[FlowDocument],
+ ) -> DocumentList:
+     """Download documents from URLs and return as DocumentList."""
+     documents: list[Document] = []
+     async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client:
+         for url in urls:
+             response = await client.get(url)
+             response.raise_for_status()
+             filename = url.split("/")[-1].split("?")[0] or "document"
+             documents.append(document_type(name=filename, content=response.content))
+     return DocumentList(documents)
+
+
+ async def upload_documents(documents: DocumentList, url_mapping: dict[str, str]) -> None:
+     """Upload documents to their mapped URLs."""
+     async with httpx.AsyncClient(timeout=60, follow_redirects=True) as client:
+         for doc in documents:
+             if doc.name in url_mapping:
+                 response = await client.put(
+                     url_mapping[doc.name],
+                     content=doc.content,
+                     headers={"Content-Type": doc.mime_type},
+                 )
+                 response.raise_for_status()
+
+
+ async def send_webhook(
+     url: str,
+     payload: ProgressRun | CompletedRun | FailedRun,
+     max_retries: int = 3,
+     retry_delay: float = 10.0,
+ ) -> None:
+     """Send webhook with retries."""
+     data: dict[str, Any] = payload.model_dump(mode="json")
+     for attempt in range(max_retries):
+         try:
+             async with httpx.AsyncClient(timeout=30) as client:
+                 response = await client.post(url, json=data, follow_redirects=True)
+                 response.raise_for_status()
+             return
+         except Exception as e:
+             if attempt < max_retries - 1:
+                 logger.warning(f"Webhook retry {attempt + 1}/{max_retries}: {e}")
+                 await asyncio.sleep(retry_delay)
+             else:
+                 logger.error(f"Webhook failed after {max_retries} attempts: {e}")
+                 raise
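A hedged usage sketch of these helpers; the concrete document class and both URLs are placeholders, and in practice the deployment base class calls these itself:

    import asyncio

    from ai_pipeline_core.deployment.helpers import (
        class_name_to_deployment_name,
        download_documents,
        upload_documents,
    )
    from ai_pipeline_core.documents import FlowDocument


    class ReportDocument(FlowDocument):  # hypothetical concrete FlowDocument subclass
        pass


    async def main() -> None:
        # PascalCase class name to kebab-case deployment name
        assert class_name_to_deployment_name("ResearchPipeline") == "research-pipeline"

        # Pull inputs from pre-signed URLs, then push results back out
        docs = await download_documents(["https://example.com/input.md"], ReportDocument)
        await upload_documents(docs, {"input.md": "https://example.com/upload/input.md"})


    asyncio.run(main())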
@@ -18,7 +18,7 @@ class FlowDocument(Document):
 
  FlowDocument is used for data that needs to be saved between pipeline
  steps and across multiple flow executions. These documents are typically
- written to the file system using the simple_runner utilities.
+ written to the file system using the deployment utilities.
 
  Key characteristics:
  - Persisted to file system between pipeline steps
@@ -40,7 +40,7 @@ class TaskDocument(Document):
 
  Note:
  - Cannot instantiate TaskDocument directly - must subclass
- - Not saved by simple_runner utilities
+ - Not saved by deployment utilities
  - Reduces I/O overhead for temporary data
  - No additional abstract methods to implement
  """
@@ -23,7 +23,7 @@ class TemporaryDocument(Document):
  - Can be instantiated directly (not abstract)
  - Cannot be subclassed (annotated with Python's @final decorator in code)
  - Useful for transient data like API responses or intermediate calculations
- - Ignored by simple_runner save operations
+ - Ignored by deployment save operations
  - Useful for tests and debugging
 
  Creating TemporaryDocuments:
@@ -39,11 +39,13 @@ class FlowConfig(ABC):
  Class Variables:
  INPUT_DOCUMENT_TYPES: List of FlowDocument types this flow accepts
  OUTPUT_DOCUMENT_TYPE: Single FlowDocument type this flow produces
+ WEIGHT: Weight for progress calculation (default 1.0, based on avg duration)
 
  Validation Rules:
  - INPUT_DOCUMENT_TYPES and OUTPUT_DOCUMENT_TYPE must be defined
  - OUTPUT_DOCUMENT_TYPE cannot be in INPUT_DOCUMENT_TYPES (prevents cycles)
  - Field names must be exact (common typos are detected)
+ - WEIGHT must be a positive number
 
  Why this matters:
  Flows connect in pipelines where one flow's output becomes another's input.
@@ -54,6 +56,7 @@ class FlowConfig(ABC):
  >>> class ProcessingFlowConfig(FlowConfig):
  ...     INPUT_DOCUMENT_TYPES = [RawDataDocument]
  ...     OUTPUT_DOCUMENT_TYPE = ProcessedDocument  # Different type!
+ ...     WEIGHT = 45.0  # Average ~45 minutes
  >>>
  >>> # Use in @pipeline_flow - RECOMMENDED PATTERN
  >>> @pipeline_flow(config=ProcessingFlowConfig, name="processing")
@@ -72,11 +75,12 @@ class FlowConfig(ABC):
  Note:
  - Validation happens at class definition time
  - Helps catch configuration errors early
- - Used by simple_runner to manage document flow
+ - Used by PipelineDeployment to manage document flow
  """
 
  INPUT_DOCUMENT_TYPES: ClassVar[list[type[FlowDocument]]]
  OUTPUT_DOCUMENT_TYPE: ClassVar[type[FlowDocument]]
+ WEIGHT: ClassVar[float] = 1.0
 
  def __init_subclass__(cls, **kwargs: Any):
      """Validate flow configuration at subclass definition time.
@@ -106,7 +110,7 @@ class FlowConfig(ABC):
  return
 
  # Check for invalid field names (common mistakes)
- allowed_fields = {"INPUT_DOCUMENT_TYPES", "OUTPUT_DOCUMENT_TYPE"}
+ allowed_fields = {"INPUT_DOCUMENT_TYPES", "OUTPUT_DOCUMENT_TYPE", "WEIGHT"}
  class_attrs = {name for name in dir(cls) if not name.startswith("_") and name.isupper()}
 
  # Find fields that look like they might be mistakes
@@ -145,6 +149,13 @@ class FlowConfig(ABC):
  f"({cls.OUTPUT_DOCUMENT_TYPE.__name__}) cannot be in INPUT_DOCUMENT_TYPES"
  )
 
+ # Validate WEIGHT
+ weight = getattr(cls, "WEIGHT", 1.0)
+ if not isinstance(weight, (int, float)) or weight <= 0:
+     raise TypeError(
+         f"FlowConfig {cls.__name__}: WEIGHT must be a positive number, got {weight}"
+     )
+
  @classmethod
  def get_input_document_types(cls) -> list[type[FlowDocument]]:
      """Get the list of input document types this flow accepts.
@@ -53,7 +53,7 @@ class FlowOptions(BaseSettings):
  - Frozen (immutable) after creation
  - Extra fields ignored (not strict)
  - Can be populated from environment variables
- - Used by simple_runner.cli for command-line parsing
+ - Used by PipelineDeployment.run_cli for command-line parsing
 
  Note:
  The base class provides model selection. Subclasses should
@@ -45,31 +45,30 @@ def _process_messages(
 
  Internal function that combines context and messages into a single
  list of API-compatible messages. Applies caching directives to
- context messages for efficiency.
+ system prompt and context messages for efficiency.
 
  Args:
  context: Messages to be cached (typically expensive/static content).
  messages: Regular messages without caching (dynamic queries).
  system_prompt: Optional system instructions for the model.
- cache_ttl: Cache TTL for context messages (e.g. "120s", "300s", "1h").
+ cache_ttl: Cache TTL for system and context messages (e.g. "120s", "300s", "1h").
  Set to None or empty string to disable caching.
 
  Returns:
  List of formatted messages ready for API calls, with:
- - System prompt at the beginning (if provided)
- - Context messages with cache_control on the last one (if cache_ttl)
+ - System prompt at the beginning with cache_control (if provided and cache_ttl set)
+ - Context messages with cache_control on all messages (if cache_ttl set)
  - Regular messages without caching
 
  System Prompt Location:
  The system prompt parameter is always injected as the FIRST message
- with role="system". It is NOT cached with context, allowing dynamic
- system prompts without breaking cache efficiency.
+ with role="system". It is cached along with context when cache_ttl is set.
 
  Cache behavior:
- The last context message gets ephemeral caching with specified TTL
+ All system and context messages get ephemeral caching with specified TTL
  to reduce token usage on repeated calls with same context.
  If cache_ttl is None or empty string (falsy), no caching is applied.
- Only the last context message receives cache_control to maximize efficiency.
+ All system and context messages receive cache_control to maximize cache efficiency.
 
  Note:
  This is an internal function used by _generate_with_retry().
@@ -79,26 +78,28 @@ def _process_messages(
 
  # Add system prompt if provided
  if system_prompt:
- processed_messages.append({"role": "system", "content": system_prompt})
+ processed_messages.append({
+ "role": "system",
+ "content": [{"type": "text", "text": system_prompt}],
+ })
 
  # Process context messages with caching if provided
  if context:
  # Use AIMessages.to_prompt() for context
  context_messages = context.to_prompt()
+ processed_messages.extend(context_messages)
 
- # Apply caching to last context message and last content part if cache_ttl is set
- if cache_ttl:
- context_messages[-1]["cache_control"] = { # type: ignore
- "type": "ephemeral",
- "ttl": cache_ttl,
- }
- assert isinstance(context_messages[-1]["content"], list) # type: ignore
- context_messages[-1]["content"][-1]["cache_control"] = { # type: ignore
+ if cache_ttl:
+ for message in processed_messages:
+ message["cache_control"] = { # type: ignore
  "type": "ephemeral",
  "ttl": cache_ttl,
  }
-
- processed_messages.extend(context_messages)
+ if isinstance(message["content"], list): # type: ignore
+ message["content"][-1]["cache_control"] = { # type: ignore
+ "type": "ephemeral",
+ "ttl": cache_ttl,
+ }
 
  # Process regular messages without caching
  if messages:
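With this change, every system and context message carries the ephemeral cache directive, and when the content is a list of parts the last part is tagged as well. A sketch of the resulting message shape; the values are illustrative, not taken from the library:

    cache_ttl = "300s"

    system_message = {
        "role": "system",
        "content": [{"type": "text", "text": "You are a research assistant."}],
    }

    # Message-level directive, mirrored onto the last content part.
    system_message["cache_control"] = {"type": "ephemeral", "ttl": cache_ttl}
    system_message["content"][-1]["cache_control"] = {"type": "ephemeral", "ttl": cache_ttl}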
@@ -153,10 +154,8 @@ def _model_name_to_openrouter_model(model: ModelName) -> str:
  return "openai/gpt-4o-search-preview"
  if model == "gemini-2.5-flash-search":
  return "google/gemini-2.5-flash:online"
- if model == "grok-4-fast-search":
- return "x-ai/grok-4-fast:online"
  if model == "sonar-pro-search":
- return "perplexity/sonar-reasoning-pro"
+ return "perplexity/sonar-pro-search"
  if model.startswith("gemini"):
  return f"google/{model}"
  elif model.startswith("gpt"):
@@ -294,7 +293,7 @@ async def _generate_with_retry(
  model, span_type="LLM", input=processed_messages
  ) as span:
  response = await _generate(model, processed_messages, completion_kwargs)
- span.set_attributes(response.get_laminar_metadata())
+ span.set_attributes(response.get_laminar_metadata()) # pyright: ignore[reportArgumentType]
  Laminar.set_span_output([
  r for r in (response.reasoning_content, response.content) if r
  ])
@@ -88,10 +88,13 @@ class ModelResponse(ChatCompletion):
  data = chat_completion.model_dump()
 
  # fixes issue where the role is "assistantassistant" instead of "assistant"
+ valid_finish_reasons = {"stop", "length", "tool_calls", "content_filter", "function_call"}
  for i in range(len(data["choices"])):
- if role := data["choices"][i]["message"].get("role"):
- if role.startswith("assistant") and role != "assistant":
- data["choices"][i]["message"]["role"] = "assistant"
+ data["choices"][i]["message"]["role"] = "assistant"
+ # Only update finish_reason if it's not already a valid value
+ current_finish_reason = data["choices"][i].get("finish_reason")
+ if current_finish_reason not in valid_finish_reasons:
+ data["choices"][i]["finish_reason"] = "stop"
 
  super().__init__(**data)
 
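The constructor now always normalizes the role to "assistant" and coerces unrecognized finish_reason values to "stop". A standalone sketch of that normalization on a raw choice dict; the input values are invented:

    valid_finish_reasons = {"stop", "length", "tool_calls", "content_filter", "function_call"}

    choice = {"message": {"role": "assistantassistant", "content": "hi"}, "finish_reason": "eos"}

    choice["message"]["role"] = "assistant"  # role is normalized unconditionally
    if choice.get("finish_reason") not in valid_finish_reasons:
        choice["finish_reason"] = "stop"  # unknown values fall back to "stop"

    assert choice["finish_reason"] == "stop"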
@@ -26,7 +26,6 @@ ModelName: TypeAlias = (
  "gemini-2.5-flash-search",
  "sonar-pro-search",
  "gpt-4o-search",
- "grok-4-fast-search",
  ]
  | str
  )
@@ -605,7 +605,7 @@ def pipeline_flow(
  - pipeline_task: For task-level decoration
  - FlowConfig: Type-safe flow configuration
  - FlowOptions: Base class for flow options
- - simple_runner.run_pipeline: Execute flows locally
+ - PipelineDeployment: Execute flows locally or remotely
  """
  flow_decorator: Callable[..., Any] = _prefect_flow
 
@@ -0,0 +1,127 @@
+ """@public Intra-flow progress tracking with order-preserving webhook delivery."""
+
+ import asyncio
+ from collections.abc import Generator
+ from contextlib import contextmanager
+ from contextvars import ContextVar
+ from dataclasses import dataclass
+ from datetime import datetime, timezone
+ from uuid import UUID
+
+ from ai_pipeline_core.deployment.contract import ProgressRun
+ from ai_pipeline_core.logging import get_pipeline_logger
+
+ logger = get_pipeline_logger(__name__)
+
+
+ @dataclass(frozen=True, slots=True)
+ class ProgressContext:
+     """Internal context holding state for progress calculation and webhook delivery."""
+
+     webhook_url: str
+     project_name: str
+     run_id: str
+     flow_run_id: str
+     flow_name: str
+     step: int
+     total_steps: int
+     weights: tuple[float, ...]
+     completed_weight: float
+     current_flow_weight: float
+     queue: asyncio.Queue[ProgressRun | None]
+
+
+ _context: ContextVar[ProgressContext | None] = ContextVar("progress_context", default=None)
+
+
+ async def update(fraction: float, message: str = "") -> None:
+     """@public Report intra-flow progress (0.0-1.0). No-op without context."""
+     ctx = _context.get()
+     if ctx is None or not ctx.webhook_url:
+         return
+
+     fraction = max(0.0, min(1.0, fraction))
+
+     total_weight = sum(ctx.weights)
+     if total_weight > 0:
+         overall = (ctx.completed_weight + ctx.current_flow_weight * fraction) / total_weight
+     else:
+         overall = fraction
+     overall = round(max(0.0, min(1.0, overall)), 4)
+
+     payload = ProgressRun(
+         flow_run_id=UUID(ctx.flow_run_id) if ctx.flow_run_id else UUID(int=0),
+         project_name=ctx.project_name,
+         state="RUNNING",
+         timestamp=datetime.now(timezone.utc),
+         step=ctx.step,
+         total_steps=ctx.total_steps,
+         flow_name=ctx.flow_name,
+         status="progress",
+         progress=overall,
+         step_progress=round(fraction, 4),
+         message=message,
+     )
+
+     ctx.queue.put_nowait(payload)
+
+
+ async def webhook_worker(
+     queue: asyncio.Queue[ProgressRun | None],
+     webhook_url: str,
+     max_retries: int = 3,
+     retry_delay: float = 10.0,
+ ) -> None:
+     """Process webhooks sequentially with retries, preserving order."""
+     from ai_pipeline_core.deployment.helpers import send_webhook  # noqa: PLC0415
+
+     while True:
+         payload = await queue.get()
+         if payload is None:
+             queue.task_done()
+             break
+
+         try:
+             await send_webhook(webhook_url, payload, max_retries, retry_delay)
+         except Exception:
+             pass  # Already logged in send_webhook
+
+         queue.task_done()
+
+
+ @contextmanager
+ def flow_context(
+     webhook_url: str,
+     project_name: str,
+     run_id: str,
+     flow_run_id: str,
+     flow_name: str,
+     step: int,
+     total_steps: int,
+     weights: tuple[float, ...],
+     completed_weight: float,
+     queue: asyncio.Queue[ProgressRun | None],
+ ) -> Generator[None, None, None]:
+     """Set up progress context for a flow. Framework internal use."""
+     current_flow_weight = weights[step - 1] if step <= len(weights) else 1.0
+     ctx = ProgressContext(
+         webhook_url=webhook_url,
+         project_name=project_name,
+         run_id=run_id,
+         flow_run_id=flow_run_id,
+         flow_name=flow_name,
+         step=step,
+         total_steps=total_steps,
+         weights=weights,
+         completed_weight=completed_weight,
+         current_flow_weight=current_flow_weight,
+         queue=queue,
+     )
+     token = _context.set(ctx)
+     try:
+         yield
+     finally:
+         _context.reset(token)
+
+
+ __all__ = ["update", "webhook_worker", "flow_context", "ProgressContext"]
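Flow code reports progress through the module-level update() coroutine; without an active deployment context the call is a no-op, so it is safe to keep in code that also runs locally. A usage sketch in which the flow body itself is hypothetical:

    from ai_pipeline_core import progress


    async def summarize_documents(documents: list[str]) -> None:
        total = len(documents)
        for index, document in enumerate(documents, start=1):
            ...  # hypothetical per-document work
            await progress.update(index / total, message=f"summarized {index}/{total}")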
@@ -0,0 +1,5 @@
+ """@public Prompt builder for document-aware LLM interactions with caching."""
+
+ from .prompt_builder import EnvironmentVariable, PromptBuilder
+
+ __all__ = ["EnvironmentVariable", "PromptBuilder"]
@@ -0,0 +1,23 @@
+ You were provided with the following documents:
+ - **core documents** - these are already a reviewed documents which are part of official project documentation.
+ - **source documents** (called also **sources**) - these are not part of official project documentation and they will be deleted after your task is completed.
+
+ {% if core_documents %}
+ There are the following **core documents** available during this session:
+ {% for document in core_documents %}
+ - {{ document.id }} - {{ document.name }}
+ {% endfor %}
+ {% else %}
+ There are no **core documents** available during this session.
+ {% endif %}
+
+ {% if new_documents %}
+ There are the following **source documents** (called also **sources**) available during this session:
+ {% for document in new_documents %}
+ - {{ document.id }} - {{ document.name }}
+ {% endfor %}
+ {% else %}
+ There are no **source documents** (called also **sources**) available during this session.
+ {% endif %}
+
+ There won't be more **core documents** and **source documents** provided during this conversation, however **new core documents** may be provided.
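PromptBuilder presumably renders this template internally; for reference, a standalone rendering sketch with plain jinja2. The loader path and the document stub (only .id and .name are read by the template) are assumptions:

    from types import SimpleNamespace

    from jinja2 import Environment, PackageLoader

    env = Environment(loader=PackageLoader("ai_pipeline_core", "prompt_builder"))
    template = env.get_template("documents_prompt.jinja2")

    doc = SimpleNamespace(id="doc-001", name="overview.md")
    print(template.render(core_documents=[doc], new_documents=[]))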
@@ -0,0 +1,78 @@
+ """Prompt cache coordination for concurrent LLM calls."""
+
+ import asyncio
+ import time
+ from asyncio import Lock
+
+ from ai_pipeline_core.documents import Document
+ from ai_pipeline_core.llm import AIMessages, ModelName
+ from ai_pipeline_core.llm.model_response import ModelResponse
+
+ CACHED_PROMPTS: dict[str, Lock | int] = {}
+
+ _cache_lock = Lock()
+ CACHE_TTL = 600
+ MIN_SIZE_FOR_CACHE = 32 * 1024
+
+
+ class GlobalCacheLock:
+     """Serialize first prompt per cache key so subsequent calls get cache hits.
+
+     Waits for the first caller to complete before allowing others to execute,
+     ensuring the prompt cache is populated.
+     """
+
+     wait_time: float = 0
+     use_cache: bool = False
+
+     def _context_size(self, context: AIMessages) -> int:
+         length = 0
+         for msg in context:
+             if isinstance(msg, Document):
+                 if msg.is_text:
+                     length += msg.size
+                 else:
+                     length += 1024
+             elif isinstance(msg, str):
+                 length += len(msg)
+             elif isinstance(msg, ModelResponse):  # type: ignore[arg-type]
+                 length += len(msg.content)
+         return length
+
+     def __init__(self, model: ModelName, context: AIMessages, cache_lock: bool):  # noqa: D107
+         self.use_cache = cache_lock and self._context_size(context) > MIN_SIZE_FOR_CACHE
+         self.cache_key = f"{model}-{context.get_prompt_cache_key()}"
+         self.new_cache = False
+
+     async def __aenter__(self) -> "GlobalCacheLock":
+         wait_start = time.time()
+         if not self.use_cache:
+             return self
+
+         async with _cache_lock:
+             cache = CACHED_PROMPTS.get(self.cache_key)
+             if isinstance(cache, int):
+                 if time.time() > cache + CACHE_TTL:
+                     cache = None
+                 else:
+                     CACHED_PROMPTS[self.cache_key] = int(time.time())
+                     self.wait_time = time.time() - wait_start
+                     return self
+             if not cache:
+                 self.new_cache = True
+                 CACHED_PROMPTS[self.cache_key] = Lock()
+                 await CACHED_PROMPTS[self.cache_key].acquire()  # type: ignore[union-attr]
+
+         if not self.new_cache and isinstance(cache, Lock):
+             async with cache:
+                 pass  # waiting for lock to be released
+
+         self.wait_time = time.time() - wait_start
+         return self
+
+     async def __aexit__(self, exc_type: type | None, exc: BaseException | None, tb: object) -> None:
+         if self.new_cache:
+             await asyncio.sleep(1)  # give time for cache to be prepared
+             async with _cache_lock:
+                 CACHED_PROMPTS[self.cache_key].release()  # type: ignore[union-attr]
+                 CACHED_PROMPTS[self.cache_key] = int(time.time())
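GlobalCacheLock is an async context manager keyed on the model plus the context's prompt cache key: the first caller over the size threshold holds the key while it warms the provider-side prompt cache, and concurrent callers with the same key wait and then reuse it. A usage sketch; generate_answer stands in for a real LLM call and is not part of this module:

    from ai_pipeline_core.prompt_builder.global_cache import GlobalCacheLock


    async def generate_answer(model, context, question) -> str:
        # Stand-in for the real LLM call; work awaited inside the lock benefits
        # from the warmed prompt cache.
        return f"answer to {question!r}"


    async def answer_questions(model, context, questions) -> list[str]:
        answers = []
        for question in questions:
            async with GlobalCacheLock(model, context, cache_lock=True):
                answers.append(await generate_answer(model, context, question))
        return answers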
@@ -0,0 +1,6 @@
+ During this session the followiing **new core documents** were created:
+ {% for document in new_core_documents %}
+ - {{ document.id }} - {{ document.name }}
+ {% endfor %}
+
+ There won't be more documents provided during this session.