polos_sdk-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- polos/__init__.py +105 -0
- polos/agents/__init__.py +7 -0
- polos/agents/agent.py +746 -0
- polos/agents/conversation_history.py +121 -0
- polos/agents/stop_conditions.py +280 -0
- polos/agents/stream.py +635 -0
- polos/core/__init__.py +0 -0
- polos/core/context.py +143 -0
- polos/core/state.py +26 -0
- polos/core/step.py +1380 -0
- polos/core/workflow.py +1192 -0
- polos/features/__init__.py +0 -0
- polos/features/events.py +456 -0
- polos/features/schedules.py +110 -0
- polos/features/tracing.py +605 -0
- polos/features/wait.py +82 -0
- polos/llm/__init__.py +9 -0
- polos/llm/generate.py +152 -0
- polos/llm/providers/__init__.py +5 -0
- polos/llm/providers/anthropic.py +615 -0
- polos/llm/providers/azure.py +42 -0
- polos/llm/providers/base.py +196 -0
- polos/llm/providers/fireworks.py +41 -0
- polos/llm/providers/gemini.py +40 -0
- polos/llm/providers/groq.py +40 -0
- polos/llm/providers/openai.py +1021 -0
- polos/llm/providers/together.py +40 -0
- polos/llm/stream.py +183 -0
- polos/middleware/__init__.py +0 -0
- polos/middleware/guardrail.py +148 -0
- polos/middleware/guardrail_executor.py +253 -0
- polos/middleware/hook.py +164 -0
- polos/middleware/hook_executor.py +104 -0
- polos/runtime/__init__.py +0 -0
- polos/runtime/batch.py +87 -0
- polos/runtime/client.py +841 -0
- polos/runtime/queue.py +42 -0
- polos/runtime/worker.py +1365 -0
- polos/runtime/worker_server.py +249 -0
- polos/tools/__init__.py +0 -0
- polos/tools/tool.py +587 -0
- polos/types/__init__.py +23 -0
- polos/types/types.py +116 -0
- polos/utils/__init__.py +27 -0
- polos/utils/agent.py +27 -0
- polos/utils/client_context.py +41 -0
- polos/utils/config.py +12 -0
- polos/utils/output_schema.py +311 -0
- polos/utils/retry.py +47 -0
- polos/utils/serializer.py +167 -0
- polos/utils/tracing.py +27 -0
- polos/utils/worker_singleton.py +40 -0
- polos_sdk-0.1.0.dist-info/METADATA +650 -0
- polos_sdk-0.1.0.dist-info/RECORD +55 -0
- polos_sdk-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,615 @@ polos/llm/providers/anthropic.py

"""Anthropic provider implementation."""

import json
import os
import warnings
from typing import Any

from .base import LLMProvider, LLMResponse, register_provider


@register_provider("anthropic")
class AnthropicProvider(LLMProvider):
    """Anthropic provider for LLM calls using the Messages API."""

    def __init__(self, api_key: str | None = None):
        """
        Initialize the Anthropic provider.

        Args:
            api_key: Anthropic API key. If not provided, uses the ANTHROPIC_API_KEY env var.
        """
        # Import the Anthropic SDK only when this provider is used (lazy loading)
        try:
            from anthropic import AsyncAnthropic
        except ImportError:
            raise ImportError(
                "Anthropic SDK not installed. Install it with: pip install 'polos[anthropic]'"
            ) from None

        self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Anthropic API key not provided. Set the ANTHROPIC_API_KEY "
                "environment variable or pass the api_key parameter."
            )

        # Initialize the Anthropic async client
        self.client = AsyncAnthropic(api_key=self.api_key)

    async def generate(
        self,
        messages: list[dict[str, Any]],
        model: str,
        tools: list[dict[str, Any]] | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        agent_config: dict[str, Any] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        output_schema: dict[str, Any] | None = None,
        output_schema_name: str | None = None,
        **kwargs,
    ) -> LLMResponse:
        """
        Make a request to Anthropic using the Messages API.

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            model: Model identifier (e.g., "claude-sonnet-4-5-20250929", "claude-3-opus-20240229")
            tools: Optional list of tool schemas for function calling
            temperature: Optional temperature parameter (0-1)
            max_tokens: Max tokens parameter; required by Anthropic (from kwargs, or defaults to 4096)
            top_p: Optional nucleus sampling parameter
            agent_config: Optional AgentConfig dict containing system_prompt and other config
            tool_results: Optional list of tool results in OpenAI format to add to messages
            output_schema: Optional structured output schema (no native Anthropic support;
                emulated via system-prompt instructions)
            output_schema_name: Optional schema name (not yet supported by Anthropic)
            **kwargs: Additional Anthropic-specific parameters

        Returns:
            LLMResponse with content, usage, and tool_calls
        """
        # Prepare messages - copy to avoid mutating the input
        processed_messages = messages.copy() if messages else []

        # Pull the system prompt from agent_config if present.
        # Anthropic takes it as a top-level "system" parameter, not a message.
        if agent_config and agent_config.get("system_prompt"):
            system_prompt = agent_config.get("system_prompt")
        else:
            system_prompt = None

        # Convert tool_results from OpenAI format to Anthropic format.
        # Tool results are appended as a new message with role "user"
        # containing tool_result blocks.
        if tool_results:
            tool_result_blocks = []
            for tool_result in tool_results:
                if tool_result.get("type") == "function_call_output":
                    call_id = tool_result.get("call_id")
                    output = tool_result.get("output")
                    # content can be a string or an array; use a string for now
                    tool_result_block = {
                        "type": "tool_result",
                        "tool_use_id": call_id,
                        "content": output if isinstance(output, str) else json.dumps(output),
                        "is_error": False,
                    }
                    tool_result_blocks.append(tool_result_block)

            if tool_result_blocks:
                # Add tool results as a new user message
                processed_messages.append({"role": "user", "content": tool_result_blocks})

        # Add structured output instructions to the system prompt if output_schema is provided
        if output_schema:
            schema_json = json.dumps(output_schema, indent=2)
            structured_output_instruction = (
                f"\n\nIMPORTANT: You must structure your response as valid JSON "
                f"that strictly conforms to this schema:\n\n{schema_json}\n\n"
                f"Return ONLY valid JSON that matches this schema. Do not include "
                f"any text, explanation, markdown code fences (```json or ```), or "
                f"formatting outside of the JSON structure. Return only the raw JSON "
                f"without any markdown formatting."
            )
            if system_prompt:
                system_prompt = system_prompt + structured_output_instruction
            else:
                system_prompt = structured_output_instruction

        # Prepare request parameters for the Messages API
        request_params: dict[str, Any] = {
            "model": model,
            "messages": processed_messages,
        }

        # Add the system prompt if present
        if system_prompt:
            request_params["system"] = system_prompt

        # max_tokens is required for Anthropic
        if max_tokens is not None:
            request_params["max_tokens"] = max_tokens
        elif "max_tokens" not in kwargs:
            # Default to 4096 if not provided
            request_params["max_tokens"] = 4096

        if temperature is not None:
            request_params["temperature"] = temperature
        if top_p is not None:
            request_params["top_p"] = top_p

        if tools:
            # Anthropic expects tools in this format:
            # [{"name": "...", "description": "...", "input_schema": {...}}]
            validated_tools = _validate_tools(tools)
            if validated_tools:
                request_params["tools"] = validated_tools

        # Add any additional kwargs
        request_params.update(kwargs)

        try:
            # Use the SDK's Messages API
            response = await self.client.messages.create(**request_params)
            if not response:
                raise RuntimeError("Anthropic API returned no response")

            # Extract content from the response.
            # response.content is a list of content blocks.
            content_parts = []
            raw_output = []
            for content_block in response.content:
                raw_output.append(
                    content_block.model_dump(exclude_none=True, mode="json")
                    if hasattr(content_block, "model_dump")
                    else json.dumps(content_block)
                )
                if content_block.type == "text":
                    content_parts.append(content_block.text)

            content = "".join(content_parts)

            # Extract tool calls from the response
            tool_calls = []
            for content_block in response.content:
                if content_block.type == "tool_use":
                    # Anthropic returns tool_use blocks with input as a dict
                    input_data = content_block.input
                    if hasattr(input_data, "model_dump"):
                        # Pydantic model: convert to a dict, then JSON
                        arguments = json.dumps(input_data.model_dump(mode="json"))
                    elif isinstance(input_data, dict):
                        # Already a dict: convert to a JSON string
                        arguments = json.dumps(input_data)
                    else:
                        # Fall back to the string representation
                        arguments = str(input_data)

                    tool_call_data = {
                        "call_id": content_block.id,
                        "id": "",
                        "type": "function",
                        "function": {"name": content_block.name, "arguments": arguments},
                    }
                    tool_calls.append(tool_call_data)

            # Extract usage information
            usage_data = response.usage
            usage = {
                "input_tokens": usage_data.input_tokens if usage_data else 0,
                "output_tokens": usage_data.output_tokens if usage_data else 0,
                "total_tokens": (usage_data.input_tokens + usage_data.output_tokens)
                if usage_data
                else 0,
            }

            # Extract model and stop_reason from the response
            response_model = getattr(response, "model", None) or model
            response_stop_reason = getattr(response, "stop_reason", None)

            processed_messages.append(
                {
                    "role": "assistant",
                    "content": raw_output,
                }
            )

            return LLMResponse(
                content=content,
                usage=usage,
                tool_calls=tool_calls,
                raw_output=processed_messages,
                model=response_model,
                stop_reason=response_stop_reason,
            )

        except Exception as e:
            # Re-raise with more context
            raise RuntimeError(f"Anthropic Messages API call failed: {str(e)}") from e

    async def stream(
        self,
        messages: list[dict[str, Any]],
        model: str,
        tools: list[dict[str, Any]] | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
        top_p: float | None = None,
        agent_config: dict[str, Any] | None = None,
        tool_results: list[dict[str, Any]] | None = None,
        output_schema: dict[str, Any] | None = None,
        output_schema_name: str | None = None,
        **kwargs,
    ):
        """
        Stream responses from Anthropic using the Messages API.

        Args:
            messages: List of message dicts with 'role' and 'content' keys
            model: Model identifier (e.g., "claude-sonnet-4-5-20250929")
            tools: Optional list of tool schemas for function calling
            temperature: Optional temperature parameter (0-1)
            max_tokens: Max tokens parameter; required by Anthropic (from kwargs, or defaults to 64000)
            top_p: Optional nucleus sampling parameter
            agent_config: Optional AgentConfig dict containing system_prompt and other config
            tool_results: Optional list of tool results in OpenAI format to add to messages
            output_schema: Optional structured output schema (emulated via system-prompt instructions)
            output_schema_name: Optional schema name (not yet supported)
            **kwargs: Additional Anthropic-specific parameters

        Yields:
            Dict with event information:
            - type: "text_delta", "tool_call", "done", or "error"
            - data: Event-specific data
        """
        # Prepare messages - copy to avoid mutating the input
        processed_messages = messages.copy() if messages else []

        # Pull the system prompt from agent_config if present.
        # Anthropic takes it as a top-level "system" parameter, not a message.
        if agent_config and agent_config.get("system_prompt"):
            system_prompt = agent_config.get("system_prompt")
        else:
            system_prompt = None

        # Convert tool_results from OpenAI format to Anthropic format.
        # Tool results are appended as a new message with role "user"
        # containing tool_result blocks.
        if tool_results:
            tool_result_blocks = []
            for tool_result in tool_results:
                if tool_result.get("type") == "function_call_output":
                    call_id = tool_result.get("call_id")
                    output = tool_result.get("output")
                    # content can be a string or an array; use a string for now
                    tool_result_block = {
                        "type": "tool_result",
                        "tool_use_id": call_id,
                        "content": output if isinstance(output, str) else json.dumps(output),
                        "is_error": False,
                    }
                    tool_result_blocks.append(tool_result_block)

            if tool_result_blocks:
                # Add tool results as a new user message
                processed_messages.append({"role": "user", "content": tool_result_blocks})

        # Add structured output instructions to the system prompt if output_schema is provided
        if output_schema:
            schema_json = json.dumps(output_schema, indent=2)
            structured_output_instruction = (
                f"\n\nIMPORTANT: You must structure your response as valid JSON "
                f"that strictly conforms to this schema:\n\n{schema_json}\n\n"
                f"Return ONLY valid JSON that matches this schema. Do not include "
                f"any text, explanation, markdown code fences (```json or ```), or "
                f"formatting outside of the JSON structure. Return only the raw JSON "
                f"without any markdown formatting."
            )
            if system_prompt:
                system_prompt = system_prompt + structured_output_instruction
            else:
                system_prompt = structured_output_instruction

        # Prepare request parameters for the Messages API
        request_params: dict[str, Any] = {
            "model": model,
            "messages": processed_messages,
            "stream": True,  # Enable streaming
        }

        # Add the system prompt if present
        if system_prompt:
            request_params["system"] = system_prompt

        # max_tokens is required for Anthropic
        if max_tokens is not None:
            request_params["max_tokens"] = max_tokens
        elif "max_tokens" not in kwargs:
            # Default to 64000 if not provided for streaming
            request_params["max_tokens"] = 64000

        if temperature is not None:
            request_params["temperature"] = temperature
        if top_p is not None:
            request_params["top_p"] = top_p

        if tools:
            # Anthropic expects tools in this format:
            # [{"name": "...", "description": "...", "input_schema": {...}}]
            validated_tools = _validate_tools(tools)
            if validated_tools:
                request_params["tools"] = validated_tools

        # Add any additional kwargs
        request_params.update(kwargs)

        try:
            # Use the SDK's Messages API with streaming
            stream = await self.client.messages.create(**request_params)

            tool_calls = []
            usage = {
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0,
            }
            stop_reason = None
            response_model = model

            # Track the current tool_use block state
            current_tool_use = None  # {id, name, partial_json}
            accumulated_partial_json = ""
            accumulated_content_blocks = []
            accumulated_text = ""
            accumulated_thinking = ""
            accumulated_signature = ""

            async for event in stream:
                # Event types: message_start, content_block_start, content_block_delta,
                # content_block_stop, message_delta, message_stop
                event_type = event.type
                event = (
                    event.model_dump(mode="json")
                    if hasattr(event, "model_dump")
                    else json.dumps(event)
                )

                if event_type == "content_block_start":
                    # A content block is starting - could be text, tool_use, or thinking
                    if event.get("content_block"):
                        content_block = event.get("content_block")
                        if content_block.get("type") == "text":
                            content_text = content_block.get("text")
                            if content_text:
                                accumulated_text += content_text
                                yield {
                                    "type": "text_delta",
                                    "data": {
                                        "content": content_text,
                                    },
                                }

                        elif content_block.get("type") == "tool_use":
                            # Start tracking a tool_use block
                            current_tool_use = {
                                "id": content_block.get("id"),
                                "name": content_block.get("name"),
                                "partial_json": "",
                            }
                            accumulated_partial_json = ""

                        elif content_block.get("type") == "thinking":
                            accumulated_thinking = content_block.get("thinking", "")
                            accumulated_signature = content_block.get("signature", "")

                elif event_type == "content_block_delta":
                    # Content delta - could be text_delta, input_json_delta,
                    # thinking_delta, or signature_delta
                    if event.get("delta"):
                        delta = event.get("delta")
                        if delta.get("type") == "text_delta":
                            # Incremental text chunk
                            delta_text = delta.get("text")
                            if delta_text:
                                accumulated_text += delta_text
                                yield {
                                    "type": "text_delta",
                                    "data": {
                                        "content": delta_text,
                                    },
                                }
                        elif delta.get("type") == "input_json_delta" and delta.get("partial_json"):
                            # Tool use input JSON delta - accumulate partial_json
                            accumulated_partial_json += delta.get("partial_json")
                            if current_tool_use:
                                current_tool_use["partial_json"] = accumulated_partial_json

                        elif delta.get("type") == "thinking_delta":
                            accumulated_thinking += delta.get("thinking") or ""

                        elif delta.get("type") == "signature_delta":
                            accumulated_signature += delta.get("signature") or ""

                elif event_type == "content_block_stop":
                    # Content block complete
                    if current_tool_use and accumulated_partial_json:
                        # Parse the accumulated JSON
                        try:
                            arguments_json = json.loads(accumulated_partial_json)
                            arguments = json.dumps(arguments_json)
                        except json.JSONDecodeError:
                            raise RuntimeError(
                                f"Failed to parse tool use input JSON: {accumulated_partial_json}"
                            ) from None

                        tool_call_data = {
                            "call_id": current_tool_use.get("id"),
                            "id": "",
                            "type": "function",
                            "function": {
                                "name": current_tool_use.get("name"),
                                "arguments": arguments,
                            },
                        }
                        tool_calls.append(tool_call_data)
                        yield {
                            "type": "tool_call",
                            "data": {
                                "tool_call": tool_call_data,
                            },
                        }

                        accumulated_content_blocks.append(
                            {
                                "type": "tool_use",
                                "id": current_tool_use.get("id"),
                                "name": current_tool_use.get("name"),
                                "input": arguments_json,
                            }
                        )

                        # Reset tool_use tracking
                        current_tool_use = None
                        accumulated_partial_json = ""

                    elif accumulated_text:
                        accumulated_content_blocks.append(
                            {
                                "type": "text",
                                "text": accumulated_text,
                            }
                        )
                        accumulated_text = ""

                    elif accumulated_thinking:
                        accumulated_content_blocks.append(
                            {
                                "type": "thinking",
                                "thinking": accumulated_thinking,
                                "signature": accumulated_signature,
                            }
                        )
                        accumulated_thinking = ""
                        accumulated_signature = ""

                elif event_type in ["message_start", "message_delta"]:
                    # Message start/delta - may carry model, stop_reason, and usage
                    if event_type == "message_start":
                        message = event.get("message")
                    else:
                        message = event.get("delta")

                    if message:
                        response_model = message.get("model") or response_model  # Update if present
                        stop_reason = message.get("stop_reason") or stop_reason  # Update if present

                        usage_data = message.get("usage")
                        if usage_data:
                            if usage_data.get("input_tokens"):
                                usage["input_tokens"] = usage_data.get("input_tokens")
                            if usage_data.get("output_tokens"):
                                usage["output_tokens"] = usage_data.get("output_tokens")

                elif event_type == "message_stop":
                    # Stream complete - final event
                    usage["total_tokens"] = usage.get("input_tokens", 0) + usage.get(
                        "output_tokens", 0
                    )
                    processed_messages.append(
                        {
                            "role": "assistant",
                            "content": accumulated_content_blocks,
                        }
                    )
                    yield {
                        "type": "done",
                        "data": {
                            "usage": usage,
                            "model": response_model,
                            "stop_reason": stop_reason,
                            "raw_output": processed_messages,
                        },
                    }

                elif event_type == "error":
                    # Stream error
                    error_msg = "Stream failed"
                    if event.get("error"):
                        error_obj = event.get("error")
                        if error_obj.get("message"):
                            error_msg = error_obj.get("message")
                        if error_obj.get("type"):
                            error_msg = f"{error_obj.get('type')}: {error_msg}"

                    yield {
                        "type": "error",
                        "data": {
                            "error": error_msg,
                        },
                    }
                    break

        except Exception as e:
            # Yield an error event and re-raise
            yield {
                "type": "error",
                "data": {
                    "error": str(e),
                },
            }
            raise RuntimeError(f"Anthropic Messages API streaming failed: {str(e)}") from e


def _validate_tools(tools: list[dict[str, Any]] | None) -> list[dict[str, Any]]:
    """
    Validate and convert tools to Anthropic format.

    Anthropic expects:
    [{
        "name": "...",
        "description": "...",
        "input_schema": {
            "type": "object",
            "properties": {...},
            "required": [...]
        }
    }]
    """
    validated_tools = []
    for tool in tools or []:
        # Convert an OpenAI-style tool to Anthropic format
        if tool.get("type") == "function" or "type" not in tool:
            # Extract the function name, description, and parameters
            function_data = tool.get("function", tool)
            name = function_data.get("name") or tool.get("name")
            description = function_data.get("description") or tool.get("description", "")
            parameters = (
                function_data.get("parameters")
                or tool.get("parameters")
                or tool.get("input_schema")
            )

            if name and parameters:
                anthropic_tool = {
                    "name": name,
                    "description": description,
                    # Anthropic uses input_schema instead of parameters
                    "input_schema": parameters,
                }
                validated_tools.append(anthropic_tool)
            else:
                # Missing name or parameters: skip
                warnings.warn(
                    f"Skipping invalid tool (missing name or parameters): {tool}",
                    stacklevel=2,
                )
                continue
        else:
            # Not an OpenAI-style tool: assume it is already in Anthropic format
            validated_tools.append(tool)

    return validated_tools
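For orientation, a minimal sketch of how this provider might be exercised directly. This is not part of the package: the model id and prompts are placeholders, LLMResponse is assumed to expose content/usage/tool_calls as attributes (it is defined in base.py, which this diff view does not show), and real callers likely go through polos's higher-level agent APIs rather than the provider class.

import asyncio

from polos.llm.providers.anthropic import AnthropicProvider


async def demo():
    provider = AnthropicProvider()  # falls back to the ANTHROPIC_API_KEY env var

    # One-shot generation (model id is a placeholder)
    response = await provider.generate(
        messages=[{"role": "user", "content": "Say hello in one word."}],
        model="claude-sonnet-4-5-20250929",
        max_tokens=64,
    )
    print(response.content, response.usage)

    # Streaming: consume typed event dicts until a "done" (or "error") event arrives
    async for event in provider.stream(
        messages=[{"role": "user", "content": "Count to three."}],
        model="claude-sonnet-4-5-20250929",
    ):
        if event["type"] == "text_delta":
            print(event["data"]["content"], end="")
        elif event["type"] == "done":
            print("\n", event["data"]["usage"])


asyncio.run(demo())

Note on tools: per _validate_tools above, an OpenAI-style entry such as {"type": "function", "function": {"name": "get_weather", "parameters": {...}}} is rewritten to {"name": "get_weather", "description": "...", "input_schema": {...}} before being sent.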

@@ -0,0 +1,42 @@ polos/llm/providers/azure.py

"""Azure OpenAI provider - routes to the OpenAI provider with a custom base_url."""

from .base import register_provider
from .openai import OpenAIProvider


@register_provider("azure")
class AzureProvider(OpenAIProvider):
    """Azure OpenAI provider using the OpenAI provider with an Azure base URL."""

    def __init__(self, api_key=None, base_url=None):
        """
        Initialize the Azure OpenAI provider.

        Args:
            api_key: Azure OpenAI API key. If not provided, uses the AZURE_OPENAI_API_KEY env var.
            base_url: Azure OpenAI endpoint base URL (required).
                Format: https://<resource-name>.openai.azure.com/
        """
        import os

        azure_api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
        if not azure_api_key:
            raise ValueError(
                "Azure OpenAI API key not provided. Set the AZURE_OPENAI_API_KEY "
                "environment variable or pass the api_key parameter."
            )

        if not base_url:
            raise ValueError(
                "base_url is required for the Azure OpenAI provider. Provide the Azure endpoint URL."
            )

        try:
            from openai import AsyncOpenAI  # noqa: F401
        except ImportError:
            raise ImportError(
                "OpenAI SDK not installed. Install it with: pip install 'polos[openai]'"
            ) from None

        # Initialize with Azure's base URL
        super().__init__(api_key=azure_api_key, base_url=base_url)
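A similarly hedged sketch for the Azure subclass. The endpoint and model name below are placeholders, and generate itself is inherited from OpenAIProvider (polos/llm/providers/openai.py, +1021 lines, not shown in this hunk), so the exact keyword arguments follow that class.

import asyncio

from polos.llm.providers.azure import AzureProvider


async def demo():
    provider = AzureProvider(
        base_url="https://my-resource.openai.azure.com/",  # placeholder endpoint
    )  # api_key falls back to the AZURE_OPENAI_API_KEY env var

    response = await provider.generate(
        messages=[{"role": "user", "content": "Hello"}],
        model="my-gpt-deployment",  # placeholder Azure deployment name
    )
    print(response.content)


asyncio.run(demo())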