ai-pipeline-core 0.1.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,66 @@
+ """Flow configuration base class."""
+
+ from abc import ABC
+ from typing import ClassVar
+
+ from ai_pipeline_core.documents import DocumentList, FlowDocument
+
+
+ class FlowConfig(ABC):
+     """Configuration for a flow. It makes flows easier to implement and test."""
+
+     INPUT_DOCUMENT_TYPES: ClassVar[list[type[FlowDocument]]]
+     OUTPUT_DOCUMENT_TYPE: ClassVar[type[FlowDocument]]
+
+     @classmethod
+     def get_input_document_types(cls) -> list[type[FlowDocument]]:
+         """Get the input document types for the flow."""
+         return cls.INPUT_DOCUMENT_TYPES
+
+     @classmethod
+     def get_output_document_type(cls) -> type[FlowDocument]:
+         """Get the output document type for the flow."""
+         return cls.OUTPUT_DOCUMENT_TYPE
+
+     @classmethod
+     def has_input_documents(cls, documents: DocumentList) -> bool:
+         """Check if the flow has all required input documents."""
+         for doc_cls in cls.INPUT_DOCUMENT_TYPES:
+             if not any(isinstance(doc, doc_cls) for doc in documents):
+                 return False
+         return True
+
+     @classmethod
+     def get_input_documents(cls, documents: DocumentList) -> DocumentList:
+         """Get the input documents for the flow."""
+         input_documents = DocumentList()
+         for doc_cls in cls.INPUT_DOCUMENT_TYPES:
+             filtered_documents = [doc for doc in documents if isinstance(doc, doc_cls)]
+             if not filtered_documents:
+                 raise ValueError(f"No input document found for class {doc_cls.__name__}")
+             input_documents.extend(filtered_documents)
+         return input_documents
+
+     @classmethod
+     def validate_output_documents(cls, documents: DocumentList) -> None:
+         """Validate that all output documents are of the flow's output type."""
+         assert isinstance(documents, DocumentList), "Documents must be a DocumentList"
+         output_document_class = cls.get_output_document_type()
+
+         invalid = [type(d).__name__ for d in documents if not isinstance(d, output_document_class)]
+         assert not invalid, (
+             "Documents must be of the correct type. "
+             f"Expected: {output_document_class.__name__}, Got invalid: {invalid}"
+         )
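
A minimal usage sketch of FlowConfig. The document subclasses and the import path `ai_pipeline_core.flow_config` are assumptions for illustration; only the FlowConfig API shown in the diff above is taken from the package.

from ai_pipeline_core.documents import DocumentList, FlowDocument
from ai_pipeline_core.flow_config import FlowConfig  # import path assumed

# Hypothetical document types, not part of the package.
class TranscriptDocument(FlowDocument):
    pass

class SummaryDocument(FlowDocument):
    pass

class SummarizeFlowConfig(FlowConfig):
    INPUT_DOCUMENT_TYPES = [TranscriptDocument]
    OUTPUT_DOCUMENT_TYPE = SummaryDocument

def check(documents: DocumentList, outputs: DocumentList) -> DocumentList:
    # Raises ValueError if a required input type is missing from `documents`.
    inputs = SummarizeFlowConfig.get_input_documents(documents)
    # Asserts every produced document is a SummaryDocument.
    SummarizeFlowConfig.validate_output_documents(outputs)
    return inputs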
@@ -0,0 +1,19 @@
+ from .ai_messages import AIMessages, AIMessageType
+ from .client import (
+     generate,
+     generate_structured,
+ )
+ from .model_options import ModelOptions
+ from .model_response import ModelResponse, StructuredModelResponse
+ from .model_types import ModelName
+
+ __all__ = [
+     "AIMessages",
+     "AIMessageType",
+     "ModelName",
+     "ModelOptions",
+     "ModelResponse",
+     "StructuredModelResponse",
+     "generate",
+     "generate_structured",
+ ]
@@ -0,0 +1,129 @@
+ import base64
+ import json
+
+ from openai.types.chat import (
+     ChatCompletionContentPartParam,
+     ChatCompletionMessageParam,
+ )
+ from prefect.logging import get_logger
+
+ from ai_pipeline_core.documents import Document
+
+ from .model_response import ModelResponse
+
+ AIMessageType = str | Document | ModelResponse
+
+
+ class AIMessages(list[AIMessageType]):
+     """A list of messages (strings, Documents, or ModelResponses) exchanged with an LLM."""
+
+     def get_last_message(self) -> AIMessageType:
+         return self[-1]
+
+     def get_last_message_as_str(self) -> str:
+         last_message = self.get_last_message()
+         if isinstance(last_message, str):
+             return last_message
+         raise ValueError(f"Wrong message type: {type(last_message)}")
+
+     def to_prompt(self) -> list[ChatCompletionMessageParam]:
+         """Convert AIMessages to OpenAI-compatible format.
+
+         Returns:
+             List of ChatCompletionMessageParam for the OpenAI API
+         """
+         messages: list[ChatCompletionMessageParam] = []
+
+         for message in self:
+             if isinstance(message, str):
+                 messages.append({"role": "user", "content": message})
+             elif isinstance(message, Document):
+                 messages.append({"role": "user", "content": AIMessages.document_to_prompt(message)})
+             elif isinstance(message, ModelResponse):  # type: ignore
+                 messages.append({"role": "assistant", "content": message.content})
+             else:
+                 raise ValueError(f"Unsupported message type: {type(message)}")
+
+         return messages
+
+     def to_tracing_log(self) -> list[str]:
+         """Convert AIMessages to a list of strings for tracing."""
+         messages: list[str] = []
+         for message in self:
+             if isinstance(message, Document):
+                 serialized_document = message.serialize_model()
+                 del serialized_document["content"]
+                 messages.append(json.dumps(serialized_document, indent=2))
+             elif isinstance(message, ModelResponse):
+                 messages.append(message.content)
+             else:
+                 assert isinstance(message, str)
+                 messages.append(message)
+         return messages
+
+     @staticmethod
+     def document_to_prompt(document: Document) -> list[ChatCompletionContentPartParam]:
+         """Convert a document to prompt format for LLM consumption.
+
+         Args:
+             document: The document to convert
+
+         Returns:
+             List of chat completion content parts for the prompt
+         """
+         prompt: list[ChatCompletionContentPartParam] = []
+
+         # Build the text header
+         description = (
+             f"<description>{document.description}</description>\n" if document.description else ""
+         )
+         header_text = (
+             f"<document>\n<id>{document.id}</id>\n<name>{document.name}</name>\n{description}"
+         )
+
+         # Handle text documents
+         if document.is_text:
+             content_text = (
+                 f"{header_text}<content>\n{document.as_text()}\n</content>\n</document>\n"
+             )
+             prompt.append({"type": "text", "text": content_text})
+             return prompt
+
+         # Handle non-text documents
+         if not document.is_image and not document.is_pdf:
+             get_logger(__name__).error(
+                 f"Document is not a text, image or PDF: {document.name} - {document.mime_type}"
+             )
+             return []
+
+         # Add header for binary content
+         prompt.append(
+             {
+                 "type": "text",
+                 "text": f"{header_text}<content>\n",
+             }
+         )
+
+         # Encode binary content
+         base64_content = base64.b64encode(document.content).decode("utf-8")
+         data_uri = f"data:{document.mime_type};base64,{base64_content}"
+
+         # Add appropriate content type
+         if document.is_pdf:
+             prompt.append(
+                 {
+                     "type": "file",
+                     "file": {"file_data": data_uri},
+                 }
+             )
+         else:  # is_image
+             prompt.append(
+                 {
+                     "type": "image_url",
+                     "image_url": {"url": data_uri, "detail": "high"},
+                 }
+             )
+
+         # Close the document tag
+         prompt.append({"type": "text", "text": "</content>\n</document>\n"})
+
+         return prompt
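
A short usage sketch for AIMessages. The subpackage path `ai_pipeline_core.llm` is inferred from the `__init__` shown above; the prompt text is illustrative.

from ai_pipeline_core.llm import AIMessages

# A bare string becomes a single "user" message in OpenAI format.
messages = AIMessages(["Summarize the attached report in three bullet points."])
openai_messages = messages.to_prompt()
# -> [{"role": "user", "content": "Summarize the attached report in three bullet points."}]

# Documents in the list are expanded via document_to_prompt();
# ModelResponse entries become "assistant" messages.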
@@ -0,0 +1,218 @@
+ import asyncio
+ from typing import Any, TypeVar
+
+ from lmnr import Laminar
+ from openai import AsyncOpenAI
+ from openai.types.chat import (
+     ChatCompletionMessageParam,
+ )
+ from prefect.logging import get_logger
+ from pydantic import BaseModel
+
+ from ai_pipeline_core.exceptions import LLMError
+ from ai_pipeline_core.settings import settings
+ from ai_pipeline_core.tracing import trace
+
+ from .ai_messages import AIMessages
+ from .model_options import ModelOptions
+ from .model_response import ModelResponse, StructuredModelResponse
+ from .model_types import ModelName
+
+ logger = get_logger()
+
+
+ def _process_messages(
+     context: AIMessages,
+     messages: AIMessages,
+     system_prompt: str | None = None,
+ ) -> list[ChatCompletionMessageParam]:
+     """Convert context and messages to OpenAI-compatible format.
+
+     Args:
+         context: Messages to be cached (optional)
+         messages: Regular messages that won't be cached
+         system_prompt: Optional system prompt
+
+     Returns:
+         List of formatted messages for the OpenAI API
+     """
+     processed_messages: list[ChatCompletionMessageParam] = []
+
+     # Add system prompt if provided
+     if system_prompt:
+         processed_messages.append({"role": "system", "content": system_prompt})
+
+     # Process context messages with caching if provided
+     if context:
+         # Use AIMessages.to_prompt() for context
+         context_messages = context.to_prompt()
+
+         # Apply caching to context messages
+         for msg in context_messages:
+             if msg.get("role") == "user":
+                 # Add cache control to user messages in context
+                 msg["cache_control"] = {  # type: ignore
+                     "type": "ephemeral",
+                     "ttl": "120s",  # Cache for 2 minutes
+                 }
+             processed_messages.append(msg)
+
+     # Process regular messages without caching
+     if messages:
+         regular_messages = messages.to_prompt()
+         processed_messages.extend(regular_messages)
+
+     return processed_messages
+
+
+ async def _generate(
+     model: str, messages: list[ChatCompletionMessageParam], completion_kwargs: dict[str, Any]
+ ) -> ModelResponse:
+     async with AsyncOpenAI(
+         api_key=settings.openai_api_key,
+         base_url=settings.openai_base_url,
+     ) as client:
+         # Use parse for structured output, create for regular completions
+         if completion_kwargs.get("response_format"):
+             raw_response = await client.chat.completions.with_raw_response.parse(  # type: ignore[var-annotated]
+                 **completion_kwargs,
+             )
+         else:
+             raw_response = await client.chat.completions.with_raw_response.create(  # type: ignore[var-annotated]
+                 **completion_kwargs
+             )
+
+         response = ModelResponse(raw_response.parse())  # type: ignore[arg-type]
+         response.set_model_options(completion_kwargs)
+         response.set_headers(dict(raw_response.headers.items()))  # type: ignore[arg-type]
+         return response
+
+
+ async def _generate_with_retry(
+     model: str,
+     context: AIMessages,
+     messages: AIMessages,
+     options: ModelOptions,
+ ) -> ModelResponse:
+     """Core generation logic with retries and a fixed delay between attempts."""
+     if not model:
+         raise ValueError("Model must be provided")
+     if not context and not messages:
+         raise ValueError("Either context or messages must be provided")
+
+     processed_messages = _process_messages(context, messages, options.system_prompt)
+     completion_kwargs: dict[str, Any] = {
+         "model": model,
+         "messages": processed_messages,
+         **options.to_openai_completion_kwargs(),
+     }
+
+     for attempt in range(options.retries):
+         try:
+             with Laminar.start_as_current_span(model, span_type="LLM", input=messages) as span:
+                 response = await _generate(model, processed_messages, completion_kwargs)
+                 span.set_attributes(response.get_laminar_metadata())
+                 Laminar.set_span_output(response.content)
+                 if not response.content:
+                     # Disable caching for the retry in case of an empty response
+                     completion_kwargs["extra_body"]["cache"] = {"no-cache": True}
+                     raise ValueError(f"Model {model} returned an empty response.")
+                 return response
+         except Exception as e:  # covers timeouts, empty responses and API errors
+             logger.warning(
+                 "LLM generation failed (attempt %d/%d): %s",
+                 attempt + 1,
+                 options.retries,
+                 e,
+             )
+             if attempt == options.retries - 1:
+                 raise LLMError("Exhausted all retry attempts for LLM generation.") from e
+
+             await asyncio.sleep(options.retry_delay_seconds)
+
+     raise LLMError("Unknown error occurred during LLM generation.")
+
+
+ @trace(ignore_inputs=["context"])
+ async def generate(
+     model: ModelName | str,
+     *,
+     context: AIMessages = AIMessages(),
+     messages: AIMessages | str,
+     options: ModelOptions = ModelOptions(),
+ ) -> ModelResponse:
+     """Generate a response using a large or small model.
+
+     Args:
+         model: The model to use for generation
+         context: Messages to be cached (optional) - keyword only
+         messages: Regular messages that won't be cached - keyword only
+         options: Model options - keyword only
+
+     Returns:
+         Model response
+     """
+     if isinstance(messages, str):
+         messages = AIMessages([messages])
+
+     return await _generate_with_retry(model, context, messages, options)
+
+
+ T = TypeVar("T", bound=BaseModel)
+
+
+ @trace
+ async def generate_structured(
+     model: ModelName,
+     response_format: type[T],
+     *,
+     context: AIMessages = AIMessages(),
+     messages: AIMessages | str,
+     options: ModelOptions = ModelOptions(),
+ ) -> StructuredModelResponse[T]:
+     """Generate a structured response using Pydantic models.
+
+     Args:
+         model: The model to use for generation
+         response_format: A Pydantic model class
+         context: Messages to be cached (optional) - keyword only
+         messages: Regular messages that won't be cached - keyword only
+         options: Model options - keyword only
+
+     Returns:
+         A StructuredModelResponse containing the parsed Pydantic model instance
+     """
+     # Copy the options so the caller's (or the shared default) ModelOptions is not mutated
+     options = options.model_copy(update={"response_format": response_format})
+
+     if isinstance(messages, str):
+         messages = AIMessages([messages])
+
+     # Call the internal generate function with structured output enabled
+     response = await _generate_with_retry(model, context, messages, options)
+
+     # Extract the parsed value from the response
+     parsed_value: T | None = None
+
+     # Check if response has choices and parsed content
+     if response.choices and hasattr(response.choices[0].message, "parsed"):
+         parsed: Any = response.choices[0].message.parsed  # type: ignore[attr-defined]
+
+         # If parsed is a dict, instantiate it as the response format class
+         if isinstance(parsed, dict):
+             parsed_value = response_format(**parsed)
+         # If it's already the right type, use it
+         elif isinstance(parsed, response_format):
+             parsed_value = parsed
+         else:
+             # Anything else cannot be converted
+             raise TypeError(
+                 f"Unable to convert parsed response to {response_format.__name__}: "
+                 f"got type {type(parsed).__name__}"  # type: ignore[reportUnknownArgumentType]
+             )
+
+     if parsed_value is None:
+         raise ValueError("No parsed content available from the model response")
+
+     # Create a StructuredModelResponse with the parsed value
+     return StructuredModelResponse[T](chat_completion=response, parsed_value=parsed_value)
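
A hedged usage sketch for generate and generate_structured. The prompt text and the Verdict model are illustrative, and it assumes the OpenAI-compatible endpoint and tracing configured via `settings` are available in the environment.

import asyncio
from pydantic import BaseModel

from ai_pipeline_core.llm import ModelOptions, generate, generate_structured

class Verdict(BaseModel):  # illustrative response format, not part of the package
    label: str
    confidence: float

async def main() -> None:
    # Plain generation; a bare string is wrapped into AIMessages internally.
    response = await generate("gpt-5-mini", messages="Say hello in one word.")
    print(response.content)

    # Structured generation parsed into the Pydantic model.
    structured = await generate_structured(
        "gpt-5-mini",
        Verdict,
        messages="Classify the sentiment of: 'great release!'",
        options=ModelOptions(system_prompt="Reply with a sentiment verdict."),
    )
    print(structured.parsed.label, structured.parsed.confidence)

asyncio.run(main())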
@@ -0,0 +1,39 @@
+ from typing import Any, Literal
+
+ from pydantic import BaseModel
+
+
+ class ModelOptions(BaseModel):
+     system_prompt: str | None = None
+     search_context_size: Literal["low", "medium", "high"] | None = None
+     reasoning_effort: Literal["low", "medium", "high"] | None = None
+     retries: int = 3
+     retry_delay_seconds: int = 10
+     timeout: int = 300
+     service_tier: Literal["auto", "default", "flex", "scale", "priority"] | None = None
+     max_completion_tokens: int | None = None
+     response_format: type[BaseModel] | None = None
+
+     def to_openai_completion_kwargs(self) -> dict[str, Any]:
+         """Convert ModelOptions to OpenAI completion kwargs."""
+         kwargs: dict[str, Any] = {
+             "timeout": self.timeout,
+             "extra_body": {},
+         }
+
+         if self.max_completion_tokens:
+             kwargs["max_completion_tokens"] = self.max_completion_tokens
+
+         if self.reasoning_effort:
+             kwargs["reasoning_effort"] = self.reasoning_effort
+
+         if self.search_context_size:
+             kwargs["web_search_options"] = {"search_context_size": self.search_context_size}
+
+         if self.response_format:
+             kwargs["response_format"] = self.response_format
+
+         if self.service_tier:
+             kwargs["service_tier"] = self.service_tier
+
+         return kwargs
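
A quick sketch of what to_openai_completion_kwargs produces for a typical configuration (values illustrative):

from ai_pipeline_core.llm import ModelOptions

opts = ModelOptions(reasoning_effort="high", max_completion_tokens=2048)
print(opts.to_openai_completion_kwargs())
# Expected shape (timeout defaults to 300; extra_body is always present):
# {"timeout": 300, "extra_body": {}, "max_completion_tokens": 2048, "reasoning_effort": "high"}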
@@ -0,0 +1,149 @@
+ import copy
+ from typing import Any, Generic, TypeVar
+
+ from openai.types.chat import ChatCompletion, ParsedChatCompletion
+ from pydantic import BaseModel, Field
+
+ T = TypeVar("T", bound=BaseModel)
+
+
+ class ModelResponse(ChatCompletion):
+     """Response from an LLM without structured output."""
+
+     headers: dict[str, str] = Field(default_factory=dict)
+     model_options: dict[str, Any] = Field(default_factory=dict)
+
+     def __init__(self, chat_completion: ChatCompletion | None = None, **kwargs: Any) -> None:
+         """Initialize ModelResponse from a ChatCompletion."""
+         if chat_completion:
+             # Copy all attributes from the ChatCompletion instance
+             data = chat_completion.model_dump()
+             data["headers"] = {}  # Add default headers
+             super().__init__(**data)
+         else:
+             # Initialize from kwargs
+             if "headers" not in kwargs:
+                 kwargs["headers"] = {}
+             super().__init__(**kwargs)
+
+     @property
+     def content(self) -> str:
+         """Get the text content of the response."""
+         return self.choices[0].message.content or ""
+
+     def set_model_options(self, options: dict[str, Any]) -> None:
+         """Set the model options."""
+         self.model_options = copy.deepcopy(options)
+         if "messages" in self.model_options:
+             del self.model_options["messages"]
+
+     def set_headers(self, headers: dict[str, str]) -> None:
+         """Set the response headers."""
+         self.headers = copy.deepcopy(headers)
+
+     def get_laminar_metadata(self) -> dict[str, str | int | float]:
+         """Extract metadata for Laminar observability logging."""
+         metadata: dict[str, str | int | float] = {}
+
+         litellm_id = self.headers.get("x-litellm-call-id")
+         cost = float(self.headers.get("x-litellm-response-cost") or 0)
+
+         # Add all x-litellm-* headers
+         for header, value in self.headers.items():
+             if header.startswith("x-litellm-"):
+                 header_name = header.replace("x-litellm-", "").lower()
+                 metadata[f"litellm.{header_name}"] = value
+
+         # Add base metadata
+         metadata.update(
+             {
+                 "gen_ai.response.id": litellm_id or self.id,
+                 "gen_ai.response.model": self.model,
+                 "gen_ai.system": "litellm",
+             }
+         )
+
+         # Add usage metadata if available
+         if self.usage:
+             metadata.update(
+                 {
+                     "gen_ai.usage.prompt_tokens": self.usage.prompt_tokens,
+                     "gen_ai.usage.completion_tokens": self.usage.completion_tokens,
+                     "gen_ai.usage.total_tokens": self.usage.total_tokens,
+                 }
+             )
+
+             # Check for cost in the usage object
+             if hasattr(self.usage, "cost"):
+                 # The 'cost' attribute is added by LiteLLM but not in OpenAI types
+                 cost = float(self.usage.cost)  # type: ignore[attr-defined]
+
+             # Add reasoning tokens if available
+             if completion_details := self.usage.completion_tokens_details:
+                 if reasoning_tokens := completion_details.reasoning_tokens:
+                     metadata["gen_ai.usage.reasoning_tokens"] = reasoning_tokens
+
+             # Add cached tokens if available
+             if prompt_details := self.usage.prompt_tokens_details:
+                 if cached_tokens := prompt_details.cached_tokens:
+                     metadata["gen_ai.usage.cached_tokens"] = cached_tokens
+
+         # Add cost metadata if available
+         if cost and cost > 0:
+             metadata.update(
+                 {
+                     "gen_ai.usage.output_cost": cost,
+                     "gen_ai.usage.cost": cost,
+                     "gen_ai.cost": cost,
+                 }
+             )
+
+         if self.model_options:
+             for key, value in self.model_options.items():
+                 metadata[f"model_options.{key}"] = str(value)
+
+         return metadata
+
+
+ class StructuredModelResponse(ModelResponse, Generic[T]):
+     """Response from an LLM with structured output of type T."""
+
+     def __init__(
+         self,
+         chat_completion: ChatCompletion | None = None,
+         parsed_value: T | None = None,
+         **kwargs: Any,
+     ) -> None:
+         """Initialize StructuredModelResponse with a parsed value.
+
+         Args:
+             chat_completion: The base chat completion
+             parsed_value: The parsed structured output
+             **kwargs: Additional arguments for ChatCompletion
+         """
+         super().__init__(chat_completion, **kwargs)
+         self._parsed_value: T | None = parsed_value
+
+         # Extract parsed value from ParsedChatCompletion if available
+         if chat_completion and isinstance(chat_completion, ParsedChatCompletion):
+             if chat_completion.choices:  # type: ignore[attr-defined]
+                 message = chat_completion.choices[0].message  # type: ignore[attr-defined]
+                 if hasattr(message, "parsed"):  # type: ignore
+                     self._parsed_value = message.parsed  # type: ignore[attr-defined]
+
+     @property
+     def parsed(self) -> T:
+         """Get the parsed structured output.
+
+         Returns:
+             The parsed value of type T.
+
+         Raises:
+             ValueError: If no parsed content is available.
+         """
+         if self._parsed_value is not None:
+             return self._parsed_value
+
+         raise ValueError(
+             "No parsed content available. This should not happen for StructuredModelResponse."
+         )
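
A minimal sketch of ModelResponse in isolation, using a hand-constructed ChatCompletion payload and illustrative header values:

from ai_pipeline_core.llm import ModelResponse

response = ModelResponse(
    id="chatcmpl-123",
    object="chat.completion",
    created=0,
    model="gpt-5-mini",
    choices=[
        {
            "index": 0,
            "finish_reason": "stop",
            "message": {"role": "assistant", "content": "hello"},
        }
    ],
)
response.set_headers({"x-litellm-call-id": "abc", "x-litellm-response-cost": "0.0012"})
print(response.content)  # "hello"
print(response.get_laminar_metadata()["gen_ai.usage.cost"])  # 0.0012, taken from the header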
@@ -0,0 +1,17 @@
+ from typing import Literal, TypeAlias
+
+ ModelName: TypeAlias = Literal[
+     # Core models
+     "gemini-2.5-pro",
+     "gpt-5",
+     "grok-4",
+     # Small models
+     "gemini-2.5-flash",
+     "gpt-5-mini",
+     "grok-3-mini",
+     # Search models
+     "gemini-2.5-flash-search",
+     "sonar-pro-search",
+     "gpt-4o-search",
+     "grok-3-mini-search",
+ ]
@@ -0,0 +1,10 @@
+ from .logging_config import LoggingConfig, get_pipeline_logger, setup_logging
+ from .logging_mixin import LoggerMixin, StructuredLoggerMixin
+
+ __all__ = [
+     "LoggerMixin",
+     "StructuredLoggerMixin",
+     "LoggingConfig",
+     "setup_logging",
+     "get_pipeline_logger",
+ ]