ai-pipeline-core 0.2.9__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_pipeline_core/__init__.py +14 -4
- ai_pipeline_core/deployment/__init__.py +46 -0
- ai_pipeline_core/deployment/base.py +681 -0
- ai_pipeline_core/deployment/contract.py +84 -0
- ai_pipeline_core/deployment/helpers.py +98 -0
- ai_pipeline_core/documents/flow_document.py +1 -1
- ai_pipeline_core/documents/task_document.py +1 -1
- ai_pipeline_core/documents/temporary_document.py +1 -1
- ai_pipeline_core/flow/config.py +13 -2
- ai_pipeline_core/flow/options.py +1 -1
- ai_pipeline_core/llm/client.py +1 -3
- ai_pipeline_core/llm/model_types.py +0 -1
- ai_pipeline_core/pipeline.py +1 -1
- ai_pipeline_core/progress.py +127 -0
- ai_pipeline_core/prompt_builder/__init__.py +5 -0
- ai_pipeline_core/prompt_builder/documents_prompt.jinja2 +23 -0
- ai_pipeline_core/prompt_builder/global_cache.py +78 -0
- ai_pipeline_core/prompt_builder/new_core_documents_prompt.jinja2 +6 -0
- ai_pipeline_core/prompt_builder/prompt_builder.py +253 -0
- ai_pipeline_core/prompt_builder/system_prompt.jinja2 +41 -0
- ai_pipeline_core/tracing.py +1 -1
- ai_pipeline_core/utils/remote_deployment.py +37 -187
- {ai_pipeline_core-0.2.9.dist-info → ai_pipeline_core-0.3.0.dist-info}/METADATA +23 -20
- ai_pipeline_core-0.3.0.dist-info/RECORD +49 -0
- {ai_pipeline_core-0.2.9.dist-info → ai_pipeline_core-0.3.0.dist-info}/WHEEL +1 -1
- ai_pipeline_core/simple_runner/__init__.py +0 -14
- ai_pipeline_core/simple_runner/cli.py +0 -254
- ai_pipeline_core/simple_runner/simple_runner.py +0 -247
- ai_pipeline_core-0.2.9.dist-info/RECORD +0 -41
- {ai_pipeline_core-0.2.9.dist-info → ai_pipeline_core-0.3.0.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""Unified pipeline run response contract.
|
|
2
|
+
|
|
3
|
+
@public
|
|
4
|
+
|
|
5
|
+
Single source of truth for the response shape used by both
|
|
6
|
+
webhook push (ai-pipeline-core) and polling pull (unified-middleware).
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from datetime import datetime
|
|
10
|
+
from typing import Annotated, Literal
|
|
11
|
+
from uuid import UUID
|
|
12
|
+
|
|
13
|
+
from pydantic import BaseModel, ConfigDict, Discriminator
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class _RunBase(BaseModel):
    """Common fields on every run response variant."""

    # Discriminator field; each concrete variant narrows this to a Literal tag.
    type: str
    flow_run_id: UUID
    project_name: str
    state: str  # PENDING, RUNNING, COMPLETED, FAILED, CRASHED, CANCELLED
    timestamp: datetime
    # Empty string when no storage location applies.
    storage_uri: str = ""

    # Immutable after construction.
    model_config = ConfigDict(frozen=True)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class PendingRun(_RunBase):
    """Pipeline queued or running but no progress reported yet."""

    # Literal tag used by the RunResponse discriminated union.
    type: Literal["pending"] = "pending"  # pyright: ignore[reportIncompatibleVariableOverride]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class ProgressRun(_RunBase):
    """Pipeline running with step-level progress data."""

    # Literal tag used by the RunResponse discriminated union.
    type: Literal["progress"] = "progress"  # pyright: ignore[reportIncompatibleVariableOverride]
    # Index of the current step within the pipeline (presumably 1-based,
    # matching the progress-context convention -- TODO confirm).
    step: int
    total_steps: int
    flow_name: str
    status: str  # "started", "completed", "cached"
    progress: float  # overall 0.0–1.0
    step_progress: float  # within step 0.0–1.0
    # Free-form human-readable progress message.
    message: str
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
class DeploymentResultData(BaseModel):
    """Typed result payload — always has success + optional error."""

    success: bool
    error: str | None = None

    # extra="allow": deployments may attach additional result fields
    # beyond success/error without breaking validation.
    model_config = ConfigDict(frozen=True, extra="allow")
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class CompletedRun(_RunBase):
    """Pipeline finished (Prefect COMPLETED). Check result.success for business outcome."""

    # Literal tag used by the RunResponse discriminated union.
    type: Literal["completed"] = "completed"  # pyright: ignore[reportIncompatibleVariableOverride]
    result: DeploymentResultData
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class FailedRun(_RunBase):
    """Pipeline crashed — execution error, not business logic."""

    # Literal tag used by the RunResponse discriminated union.
    type: Literal["failed"] = "failed"  # pyright: ignore[reportIncompatibleVariableOverride]
    error: str
    # A result may still be present if the pipeline produced one before crashing.
    result: DeploymentResultData | None = None
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
# Discriminated union over the "type" field: pydantic picks the concrete
# variant from the Literal tag when validating a run response.
RunResponse = Annotated[
    PendingRun | ProgressRun | CompletedRun | FailedRun,
    Discriminator("type"),
]

__all__ = [
    "CompletedRun",
    "DeploymentResultData",
    "FailedRun",
    "PendingRun",
    "ProgressRun",
    "RunResponse",
]
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
"""Helper functions for pipeline deployments."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import re
|
|
5
|
+
from typing import Any, Literal, TypedDict
|
|
6
|
+
|
|
7
|
+
import httpx
|
|
8
|
+
|
|
9
|
+
from ai_pipeline_core.deployment.contract import CompletedRun, FailedRun, ProgressRun
|
|
10
|
+
from ai_pipeline_core.documents import Document, DocumentList, FlowDocument
|
|
11
|
+
from ai_pipeline_core.logging import get_pipeline_logger
|
|
12
|
+
|
|
13
|
+
logger = get_pipeline_logger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class StatusPayload(TypedDict):
    """Webhook payload for Prefect state transitions (sub-flow level)."""

    type: Literal["status"]
    # Flow-run id as a string (payload is JSON-bound, so no UUID type here).
    flow_run_id: str
    project_name: str
    # Index of the sub-flow within the pipeline (presumably 1-based, matching
    # the progress-context convention -- TODO confirm).
    step: int
    total_steps: int
    flow_name: str
    state: str  # RUNNING, COMPLETED, FAILED, CRASHED, CANCELLED
    state_name: str
    # ISO-format timestamp string rather than datetime -- JSON-serialized payload.
    timestamp: str
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def class_name_to_deployment_name(class_name: str) -> str:
    """Convert a PascalCase class name to its kebab-case deployment name.

    Example: ``ResearchPipeline`` -> ``research-pipeline``.
    """
    # A dash is inserted before every uppercase letter except a leading one,
    # then the whole string is lowercased.
    return re.sub(r"(?<!^)(?=[A-Z])", "-", class_name).lower()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def extract_generic_params(cls: type) -> tuple[type | None, type | None]:
    """Extract TOptions and TResult from PipelineDeployment generic args."""
    from ai_pipeline_core.deployment.base import PipelineDeployment  # noqa: PLC0415

    # Walk the original (unsubscripted-at-runtime) bases looking for the
    # parameterized PipelineDeployment[TOptions, TResult] entry.
    for candidate in getattr(cls, "__orig_bases__", []):
        if getattr(candidate, "__origin__", None) is not PipelineDeployment:
            continue
        type_args = getattr(candidate, "__args__", ())
        if len(type_args) == 2:
            options_type, result_type = type_args
            return options_type, result_type

    # No parameterized PipelineDeployment base found.
    return None, None
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
async def download_documents(
    urls: list[str],
    document_type: type[FlowDocument],
) -> DocumentList:
    """Download documents from URLs and return as DocumentList.

    Each URL is fetched sequentially over a single shared client; a non-2xx
    response raises via ``raise_for_status``.
    """
    fetched: list[Document] = []
    async with httpx.AsyncClient(timeout=60, follow_redirects=True) as http:
        for source_url in urls:
            resp = await http.get(source_url)
            resp.raise_for_status()
            # Filename = last path segment with any query string stripped;
            # fall back to a generic name for URLs ending in "/".
            name = source_url.split("/")[-1].split("?")[0] or "document"
            fetched.append(document_type(name=name, content=resp.content))
    return DocumentList(fetched)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
async def upload_documents(documents: DocumentList, url_mapping: dict[str, str]) -> None:
    """Upload documents to their mapped URLs.

    Documents whose name has no entry in ``url_mapping`` are skipped silently.
    A non-2xx response raises via ``raise_for_status``.
    """
    async with httpx.AsyncClient(timeout=60, follow_redirects=True) as http:
        for document in documents:
            target = url_mapping.get(document.name)
            if target is None:
                continue  # no destination mapped for this document
            resp = await http.put(
                target,
                content=document.content,
                headers={"Content-Type": document.mime_type},
            )
            resp.raise_for_status()
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
async def send_webhook(
    url: str,
    payload: ProgressRun | CompletedRun | FailedRun,
    max_retries: int = 3,
    retry_delay: float = 10.0,
) -> None:
    """Send webhook with retries.

    Args:
        url: Webhook endpoint the payload is POSTed to.
        payload: Run-status model, serialized via ``model_dump(mode="json")``.
        max_retries: Total number of attempts before giving up.
        retry_delay: Seconds to sleep between attempts.

    Raises:
        Exception: Re-raises the last error once all attempts are exhausted.
    """
    data: dict[str, Any] = payload.model_dump(mode="json")
    # Reuse a single client (and its connection pool) across all attempts
    # instead of constructing a fresh AsyncClient per retry.
    async with httpx.AsyncClient(timeout=30) as client:
        for attempt in range(max_retries):
            try:
                response = await client.post(url, json=data, follow_redirects=True)
                response.raise_for_status()
                return
            except Exception as e:
                if attempt < max_retries - 1:
                    logger.warning(f"Webhook retry {attempt + 1}/{max_retries}: {e}")
                    await asyncio.sleep(retry_delay)
                else:
                    logger.error(f"Webhook failed after {max_retries} attempts: {e}")
                    raise
|
|
@@ -18,7 +18,7 @@ class FlowDocument(Document):
|
|
|
18
18
|
|
|
19
19
|
FlowDocument is used for data that needs to be saved between pipeline
|
|
20
20
|
steps and across multiple flow executions. These documents are typically
|
|
21
|
-
written to the file system using the
|
|
21
|
+
written to the file system using the deployment utilities.
|
|
22
22
|
|
|
23
23
|
Key characteristics:
|
|
24
24
|
- Persisted to file system between pipeline steps
|
|
@@ -40,7 +40,7 @@ class TaskDocument(Document):
|
|
|
40
40
|
|
|
41
41
|
Note:
|
|
42
42
|
- Cannot instantiate TaskDocument directly - must subclass
|
|
43
|
-
- Not saved by
|
|
43
|
+
- Not saved by deployment utilities
|
|
44
44
|
- Reduces I/O overhead for temporary data
|
|
45
45
|
- No additional abstract methods to implement
|
|
46
46
|
"""
|
|
@@ -23,7 +23,7 @@ class TemporaryDocument(Document):
|
|
|
23
23
|
- Can be instantiated directly (not abstract)
|
|
24
24
|
- Cannot be subclassed (annotated with Python's @final decorator in code)
|
|
25
25
|
- Useful for transient data like API responses or intermediate calculations
|
|
26
|
-
- Ignored by
|
|
26
|
+
- Ignored by deployment save operations
|
|
27
27
|
- Useful for tests and debugging
|
|
28
28
|
|
|
29
29
|
Creating TemporaryDocuments:
|
ai_pipeline_core/flow/config.py
CHANGED
|
@@ -39,11 +39,13 @@ class FlowConfig(ABC):
|
|
|
39
39
|
Class Variables:
|
|
40
40
|
INPUT_DOCUMENT_TYPES: List of FlowDocument types this flow accepts
|
|
41
41
|
OUTPUT_DOCUMENT_TYPE: Single FlowDocument type this flow produces
|
|
42
|
+
WEIGHT: Weight for progress calculation (default 1.0, based on avg duration)
|
|
42
43
|
|
|
43
44
|
Validation Rules:
|
|
44
45
|
- INPUT_DOCUMENT_TYPES and OUTPUT_DOCUMENT_TYPE must be defined
|
|
45
46
|
- OUTPUT_DOCUMENT_TYPE cannot be in INPUT_DOCUMENT_TYPES (prevents cycles)
|
|
46
47
|
- Field names must be exact (common typos are detected)
|
|
48
|
+
- WEIGHT must be a positive number
|
|
47
49
|
|
|
48
50
|
Why this matters:
|
|
49
51
|
Flows connect in pipelines where one flow's output becomes another's input.
|
|
@@ -54,6 +56,7 @@ class FlowConfig(ABC):
|
|
|
54
56
|
>>> class ProcessingFlowConfig(FlowConfig):
|
|
55
57
|
... INPUT_DOCUMENT_TYPES = [RawDataDocument]
|
|
56
58
|
... OUTPUT_DOCUMENT_TYPE = ProcessedDocument # Different type!
|
|
59
|
+
... WEIGHT = 45.0 # Average ~45 minutes
|
|
57
60
|
>>>
|
|
58
61
|
>>> # Use in @pipeline_flow - RECOMMENDED PATTERN
|
|
59
62
|
>>> @pipeline_flow(config=ProcessingFlowConfig, name="processing")
|
|
@@ -72,11 +75,12 @@ class FlowConfig(ABC):
|
|
|
72
75
|
Note:
|
|
73
76
|
- Validation happens at class definition time
|
|
74
77
|
- Helps catch configuration errors early
|
|
75
|
-
- Used by
|
|
78
|
+
- Used by PipelineDeployment to manage document flow
|
|
76
79
|
"""
|
|
77
80
|
|
|
78
81
|
INPUT_DOCUMENT_TYPES: ClassVar[list[type[FlowDocument]]]
|
|
79
82
|
OUTPUT_DOCUMENT_TYPE: ClassVar[type[FlowDocument]]
|
|
83
|
+
WEIGHT: ClassVar[float] = 1.0
|
|
80
84
|
|
|
81
85
|
def __init_subclass__(cls, **kwargs: Any):
|
|
82
86
|
"""Validate flow configuration at subclass definition time.
|
|
@@ -106,7 +110,7 @@ class FlowConfig(ABC):
|
|
|
106
110
|
return
|
|
107
111
|
|
|
108
112
|
# Check for invalid field names (common mistakes)
|
|
109
|
-
allowed_fields = {"INPUT_DOCUMENT_TYPES", "OUTPUT_DOCUMENT_TYPE"}
|
|
113
|
+
allowed_fields = {"INPUT_DOCUMENT_TYPES", "OUTPUT_DOCUMENT_TYPE", "WEIGHT"}
|
|
110
114
|
class_attrs = {name for name in dir(cls) if not name.startswith("_") and name.isupper()}
|
|
111
115
|
|
|
112
116
|
# Find fields that look like they might be mistakes
|
|
@@ -145,6 +149,13 @@ class FlowConfig(ABC):
|
|
|
145
149
|
f"({cls.OUTPUT_DOCUMENT_TYPE.__name__}) cannot be in INPUT_DOCUMENT_TYPES"
|
|
146
150
|
)
|
|
147
151
|
|
|
152
|
+
# Validate WEIGHT
|
|
153
|
+
weight = getattr(cls, "WEIGHT", 1.0)
|
|
154
|
+
if not isinstance(weight, (int, float)) or weight <= 0:
|
|
155
|
+
raise TypeError(
|
|
156
|
+
f"FlowConfig {cls.__name__}: WEIGHT must be a positive number, got {weight}"
|
|
157
|
+
)
|
|
158
|
+
|
|
148
159
|
@classmethod
|
|
149
160
|
def get_input_document_types(cls) -> list[type[FlowDocument]]:
|
|
150
161
|
"""Get the list of input document types this flow accepts.
|
ai_pipeline_core/flow/options.py
CHANGED
|
@@ -53,7 +53,7 @@ class FlowOptions(BaseSettings):
|
|
|
53
53
|
- Frozen (immutable) after creation
|
|
54
54
|
- Extra fields ignored (not strict)
|
|
55
55
|
- Can be populated from environment variables
|
|
56
|
-
- Used by
|
|
56
|
+
- Used by PipelineDeployment.run_cli for command-line parsing
|
|
57
57
|
|
|
58
58
|
Note:
|
|
59
59
|
The base class provides model selection. Subclasses should
|
ai_pipeline_core/llm/client.py
CHANGED
|
@@ -154,8 +154,6 @@ def _model_name_to_openrouter_model(model: ModelName) -> str:
|
|
|
154
154
|
return "openai/gpt-4o-search-preview"
|
|
155
155
|
if model == "gemini-2.5-flash-search":
|
|
156
156
|
return "google/gemini-2.5-flash:online"
|
|
157
|
-
if model == "grok-4-fast-search":
|
|
158
|
-
return "x-ai/grok-4-fast:online"
|
|
159
157
|
if model == "sonar-pro-search":
|
|
160
158
|
return "perplexity/sonar-pro-search"
|
|
161
159
|
if model.startswith("gemini"):
|
|
@@ -295,7 +293,7 @@ async def _generate_with_retry(
|
|
|
295
293
|
model, span_type="LLM", input=processed_messages
|
|
296
294
|
) as span:
|
|
297
295
|
response = await _generate(model, processed_messages, completion_kwargs)
|
|
298
|
-
span.set_attributes(response.get_laminar_metadata())
|
|
296
|
+
span.set_attributes(response.get_laminar_metadata()) # pyright: ignore[reportArgumentType]
|
|
299
297
|
Laminar.set_span_output([
|
|
300
298
|
r for r in (response.reasoning_content, response.content) if r
|
|
301
299
|
])
|
ai_pipeline_core/pipeline.py
CHANGED
|
@@ -605,7 +605,7 @@ def pipeline_flow(
|
|
|
605
605
|
- pipeline_task: For task-level decoration
|
|
606
606
|
- FlowConfig: Type-safe flow configuration
|
|
607
607
|
- FlowOptions: Base class for flow options
|
|
608
|
-
-
|
|
608
|
+
- PipelineDeployment: Execute flows locally or remotely
|
|
609
609
|
"""
|
|
610
610
|
flow_decorator: Callable[..., Any] = _prefect_flow
|
|
611
611
|
|
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
"""@public Intra-flow progress tracking with order-preserving webhook delivery."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
from collections.abc import Generator
|
|
5
|
+
from contextlib import contextmanager
|
|
6
|
+
from contextvars import ContextVar
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from uuid import UUID
|
|
10
|
+
|
|
11
|
+
from ai_pipeline_core.deployment.contract import ProgressRun
|
|
12
|
+
from ai_pipeline_core.logging import get_pipeline_logger
|
|
13
|
+
|
|
14
|
+
logger = get_pipeline_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@dataclass(frozen=True, slots=True)
class ProgressContext:
    """Internal context holding state for progress calculation and webhook delivery."""

    # Destination URL; empty string disables progress reporting entirely.
    webhook_url: str
    project_name: str
    run_id: str
    # Prefect flow-run id in string form; may be empty.
    flow_run_id: str
    flow_name: str
    # 1-based index of the current flow (weights are indexed by step - 1).
    step: int
    total_steps: int
    # Per-flow progress weights for the whole pipeline.
    weights: tuple[float, ...]
    # Sum of weights of flows already finished before this one.
    completed_weight: float
    # Weight of the flow currently executing.
    current_flow_weight: float
    # Consumed by webhook_worker; None is the shutdown sentinel.
    queue: asyncio.Queue[ProgressRun | None]


# Per-task progress context; None means progress reporting is inactive.
_context: ContextVar[ProgressContext | None] = ContextVar("progress_context", default=None)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
async def update(fraction: float, message: str = "") -> None:
    """@public Report intra-flow progress (0.0-1.0). No-op without context."""
    ctx = _context.get()
    if ctx is None or not ctx.webhook_url:
        return

    # Clamp the caller-supplied fraction into [0, 1].
    clamped = min(1.0, max(0.0, fraction))

    # Overall progress = (finished weight + weighted share of this flow) / total.
    denominator = sum(ctx.weights)
    if denominator > 0:
        combined = (ctx.completed_weight + ctx.current_flow_weight * clamped) / denominator
    else:
        combined = clamped
    combined = round(min(1.0, max(0.0, combined)), 4)

    ctx.queue.put_nowait(
        ProgressRun(
            flow_run_id=UUID(ctx.flow_run_id) if ctx.flow_run_id else UUID(int=0),
            project_name=ctx.project_name,
            state="RUNNING",
            timestamp=datetime.now(timezone.utc),
            step=ctx.step,
            total_steps=ctx.total_steps,
            flow_name=ctx.flow_name,
            status="progress",
            progress=combined,
            step_progress=round(clamped, 4),
            message=message,
        )
    )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
async def webhook_worker(
    queue: asyncio.Queue[ProgressRun | None],
    webhook_url: str,
    max_retries: int = 3,
    retry_delay: float = 10.0,
) -> None:
    """Process webhooks sequentially with retries, preserving order.

    Consumes payloads one at a time until the ``None`` sentinel arrives;
    every dequeued item (including the sentinel) is acknowledged with
    ``task_done`` so ``queue.join()`` can complete.
    """
    from ai_pipeline_core.deployment.helpers import send_webhook  # noqa: PLC0415

    while (item := await queue.get()) is not None:
        try:
            await send_webhook(webhook_url, item, max_retries, retry_delay)
        except Exception:
            pass  # Already logged in send_webhook
        finally:
            queue.task_done()

    # Acknowledge the None sentinel that terminated the loop.
    queue.task_done()
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@contextmanager
def flow_context(
    webhook_url: str,
    project_name: str,
    run_id: str,
    flow_run_id: str,
    flow_name: str,
    step: int,
    total_steps: int,
    weights: tuple[float, ...],
    completed_weight: float,
    queue: asyncio.Queue[ProgressRun | None],
) -> Generator[None, None, None]:
    """Set up progress context for a flow. Framework internal use.

    Args:
        webhook_url: Destination for progress webhooks ("" disables reporting).
        project_name: Project identifier included in every payload.
        run_id: Deployment-level run identifier.
        flow_run_id: Prefect flow-run id in string form (may be empty).
        flow_name: Human-readable name of the current flow.
        step: 1-based index of the current flow within the pipeline.
        total_steps: Total number of flows in the pipeline.
        weights: Per-flow progress weights, indexed by ``step - 1``.
        completed_weight: Sum of weights of flows already finished.
        queue: Queue consumed by ``webhook_worker`` for ordered delivery.

    Yields:
        None. The context is exposed to ``update`` via the module-level
        ContextVar and reset on exit.
    """
    # Guard BOTH bounds: step is 1-based, so step <= 0 must not fall through
    # to weights[step - 1], which would silently index weights[-1] (the last
    # flow's weight). Out-of-range steps use a neutral weight of 1.0.
    current_flow_weight = weights[step - 1] if 0 < step <= len(weights) else 1.0
    ctx = ProgressContext(
        webhook_url=webhook_url,
        project_name=project_name,
        run_id=run_id,
        flow_run_id=flow_run_id,
        flow_name=flow_name,
        step=step,
        total_steps=total_steps,
        weights=weights,
        completed_weight=completed_weight,
        current_flow_weight=current_flow_weight,
        queue=queue,
    )
    token = _context.set(ctx)
    try:
        yield
    finally:
        # Always restore the previous context, even if the flow raised.
        _context.reset(token)


__all__ = ["update", "webhook_worker", "flow_context", "ProgressContext"]
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
You were provided with the following documents:
|
|
2
|
+
- **core documents** - these are already reviewed documents that are part of the official project documentation.
|
|
3
|
+
- **source documents** (also called **sources**) - these are not part of the official project documentation and will be deleted after your task is completed.
|
|
4
|
+
|
|
5
|
+
{% if core_documents %}
|
|
6
|
+
There are the following **core documents** available during this session:
|
|
7
|
+
{% for document in core_documents %}
|
|
8
|
+
- {{ document.id }} - {{ document.name }}
|
|
9
|
+
{% endfor %}
|
|
10
|
+
{% else %}
|
|
11
|
+
There are no **core documents** available during this session.
|
|
12
|
+
{% endif %}
|
|
13
|
+
|
|
14
|
+
{% if new_documents %}
|
|
15
|
+
There are the following **source documents** (called also **sources**) available during this session:
|
|
16
|
+
{% for document in new_documents %}
|
|
17
|
+
- {{ document.id }} - {{ document.name }}
|
|
18
|
+
{% endfor %}
|
|
19
|
+
{% else %}
|
|
20
|
+
There are no **source documents** (called also **sources**) available during this session.
|
|
21
|
+
{% endif %}
|
|
22
|
+
|
|
23
|
+
There won't be more **core documents** and **source documents** provided during this conversation; however, **new core documents** may be provided.
|
|
@@ -0,0 +1,78 @@
|
|
|
1
|
+
"""Prompt cache coordination for concurrent LLM calls."""
|
|
2
|
+
|
|
3
|
+
import asyncio
|
|
4
|
+
import time
|
|
5
|
+
from asyncio import Lock
|
|
6
|
+
|
|
7
|
+
from ai_pipeline_core.documents import Document
|
|
8
|
+
from ai_pipeline_core.llm import AIMessages, ModelName
|
|
9
|
+
from ai_pipeline_core.llm.model_response import ModelResponse
|
|
10
|
+
|
|
11
|
+
# Cache-key -> state map. A Lock value means the first call for that key is
# still in flight; an int value is the unix timestamp when the cached prompt
# was last confirmed warm.
CACHED_PROMPTS: dict[str, Lock | int] = {}

# Guards all reads/writes of CACHED_PROMPTS.
_cache_lock = Lock()
# Seconds a cache entry stays valid -- presumably aligned with the provider's
# prompt-cache TTL; TODO confirm.
CACHE_TTL = 600
# Contexts smaller than this (32 KiB) skip cache coordination entirely.
MIN_SIZE_FOR_CACHE = 32 * 1024
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class GlobalCacheLock:
    """Serialize first prompt per cache key so subsequent calls get cache hits.

    Waits for the first caller to complete before allowing others to execute,
    ensuring the prompt cache is populated.
    """

    # Seconds spent waiting in __aenter__ (diagnostic only).
    wait_time: float = 0
    # Whether this call participates in cache coordination at all.
    use_cache: bool = False

    def _context_size(self, context: AIMessages) -> int:
        """Estimate the context's size for the MIN_SIZE_FOR_CACHE threshold."""
        length = 0
        for msg in context:
            if isinstance(msg, Document):
                if msg.is_text:
                    length += msg.size
                else:
                    # Non-text documents count a flat 1024 -- a rough stand-in
                    # rather than their actual byte size.
                    length += 1024
            elif isinstance(msg, str):
                length += len(msg)
            elif isinstance(msg, ModelResponse):  # type: ignore[arg-type]
                length += len(msg.content)
        return length

    def __init__(self, model: ModelName, context: AIMessages, cache_lock: bool):  # noqa: D107
        # Coordinate only when requested AND the context is large enough to be
        # worth warming the provider-side prompt cache.
        self.use_cache = cache_lock and self._context_size(context) > MIN_SIZE_FOR_CACHE
        self.cache_key = f"{model}-{context.get_prompt_cache_key()}"
        # True when this caller is elected to populate the cache (holds the Lock).
        self.new_cache = False

    async def __aenter__(self) -> "GlobalCacheLock":
        wait_start = time.time()
        if not self.use_cache:
            # Coordination disabled: proceed immediately.
            return self

        async with _cache_lock:
            cache = CACHED_PROMPTS.get(self.cache_key)
            if isinstance(cache, int):
                if time.time() > cache + CACHE_TTL:
                    # Entry expired -- treat as absent so it is re-warmed below.
                    cache = None
                else:
                    # Entry still warm: refresh its timestamp and proceed.
                    CACHED_PROMPTS[self.cache_key] = int(time.time())
                    self.wait_time = time.time() - wait_start
                    return self
            if not cache:
                # Elect this caller as the warmer; later callers will see the
                # Lock and block below until __aexit__ releases it.
                self.new_cache = True
                CACHED_PROMPTS[self.cache_key] = Lock()
                await CACHED_PROMPTS[self.cache_key].acquire()  # type: ignore[union-attr]

        if not self.new_cache and isinstance(cache, Lock):
            # Another caller is warming this key -- wait for its release.
            # Acquired OUTSIDE _cache_lock to avoid blocking unrelated keys.
            async with cache:
                pass  # waiting for lock to be released

        self.wait_time = time.time() - wait_start
        return self

    async def __aexit__(self, exc_type: type | None, exc: BaseException | None, tb: object) -> None:
        if self.new_cache:
            await asyncio.sleep(1)  # give time for cache to be prepared
            async with _cache_lock:
                # Release all waiters, then replace the Lock entry with a
                # warm-cache timestamp for subsequent callers.
                CACHED_PROMPTS[self.cache_key].release()  # type: ignore[union-attr]
                CACHED_PROMPTS[self.cache_key] = int(time.time())
|