openai-agents 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of openai-agents might be problematic; see the package's page on the registry for more details.
- agents/_run_impl.py +2 -0
- agents/extensions/memory/__init__.py +37 -1
- agents/extensions/memory/dapr_session.py +423 -0
- agents/extensions/memory/sqlalchemy_session.py +13 -0
- agents/extensions/models/litellm_model.py +54 -15
- agents/items.py +3 -0
- agents/lifecycle.py +4 -4
- agents/models/chatcmpl_converter.py +8 -4
- agents/models/chatcmpl_stream_handler.py +13 -7
- agents/realtime/config.py +3 -0
- agents/realtime/events.py +7 -0
- agents/realtime/model.py +7 -0
- agents/realtime/model_inputs.py +3 -0
- agents/realtime/openai_realtime.py +69 -33
- agents/realtime/session.py +62 -9
- agents/run.py +63 -1
- agents/stream_events.py +1 -0
- agents/tool.py +15 -1
- agents/usage.py +65 -0
- agents/voice/models/openai_stt.py +2 -1
- {openai_agents-0.4.0.dist-info → openai_agents-0.5.0.dist-info}/METADATA +14 -4
- {openai_agents-0.4.0.dist-info → openai_agents-0.5.0.dist-info}/RECORD +24 -23
- {openai_agents-0.4.0.dist-info → openai_agents-0.5.0.dist-info}/WHEEL +0 -0
- {openai_agents-0.4.0.dist-info → openai_agents-0.5.0.dist-info}/licenses/LICENSE +0 -0
agents/_run_impl.py
CHANGED
|
@@ -1172,6 +1172,8 @@ class RunImpl:
|
|
|
1172
1172
|
event = RunItemStreamEvent(item=item, name="reasoning_item_created")
|
|
1173
1173
|
elif isinstance(item, MCPApprovalRequestItem):
|
|
1174
1174
|
event = RunItemStreamEvent(item=item, name="mcp_approval_requested")
|
|
1175
|
+
elif isinstance(item, MCPApprovalResponseItem):
|
|
1176
|
+
event = RunItemStreamEvent(item=item, name="mcp_approval_response")
|
|
1175
1177
|
elif isinstance(item, MCPListToolsItem):
|
|
1176
1178
|
event = RunItemStreamEvent(item=item, name="mcp_list_tools")
|
|
1177
1179
|
|
|
@@ -11,10 +11,13 @@ from __future__ import annotations
|
|
|
11
11
|
from typing import Any
|
|
12
12
|
|
|
13
13
|
# Public names exported by this package. Optional backends such as DaprSession
# (and its DAPR_CONSISTENCY_* constants) are imported lazily by the module-level
# __getattr__, so their extra dependencies load only on first access.
__all__: list[str] = [
    "AdvancedSQLiteSession",
    "DAPR_CONSISTENCY_EVENTUAL",
    "DAPR_CONSISTENCY_STRONG",
    "DaprSession",
    "EncryptedSession",
    "RedisSession",
    "SQLAlchemySession",
]
|
|
19
22
|
|
|
20
23
|
|
|
@@ -60,4 +63,37 @@ def __getattr__(name: str) -> Any:
|
|
|
60
63
|
except ModuleNotFoundError as e:
|
|
61
64
|
raise ImportError(f"Failed to import AdvancedSQLiteSession: {e}") from e
|
|
62
65
|
|
|
66
|
+
if name == "DaprSession":
|
|
67
|
+
try:
|
|
68
|
+
from .dapr_session import DaprSession # noqa: F401
|
|
69
|
+
|
|
70
|
+
return DaprSession
|
|
71
|
+
except ModuleNotFoundError as e:
|
|
72
|
+
raise ImportError(
|
|
73
|
+
"DaprSession requires the 'dapr' extra. "
|
|
74
|
+
"Install it with: pip install openai-agents[dapr]"
|
|
75
|
+
) from e
|
|
76
|
+
|
|
77
|
+
if name == "DAPR_CONSISTENCY_EVENTUAL":
|
|
78
|
+
try:
|
|
79
|
+
from .dapr_session import DAPR_CONSISTENCY_EVENTUAL # noqa: F401
|
|
80
|
+
|
|
81
|
+
return DAPR_CONSISTENCY_EVENTUAL
|
|
82
|
+
except ModuleNotFoundError as e:
|
|
83
|
+
raise ImportError(
|
|
84
|
+
"DAPR_CONSISTENCY_EVENTUAL requires the 'dapr' extra. "
|
|
85
|
+
"Install it with: pip install openai-agents[dapr]"
|
|
86
|
+
) from e
|
|
87
|
+
|
|
88
|
+
if name == "DAPR_CONSISTENCY_STRONG":
|
|
89
|
+
try:
|
|
90
|
+
from .dapr_session import DAPR_CONSISTENCY_STRONG # noqa: F401
|
|
91
|
+
|
|
92
|
+
return DAPR_CONSISTENCY_STRONG
|
|
93
|
+
except ModuleNotFoundError as e:
|
|
94
|
+
raise ImportError(
|
|
95
|
+
"DAPR_CONSISTENCY_STRONG requires the 'dapr' extra. "
|
|
96
|
+
"Install it with: pip install openai-agents[dapr]"
|
|
97
|
+
) from e
|
|
98
|
+
|
|
63
99
|
raise AttributeError(f"module {__name__} has no attribute {name}")
|
|
@@ -0,0 +1,423 @@
|
|
|
1
|
+
"""Dapr State Store-powered Session backend.
|
|
2
|
+
|
|
3
|
+
Usage::
|
|
4
|
+
|
|
5
|
+
from agents.extensions.memory import DaprSession
|
|
6
|
+
|
|
7
|
+
# Create from Dapr sidecar address
|
|
8
|
+
session = DaprSession.from_address(
|
|
9
|
+
session_id="user-123",
|
|
10
|
+
state_store_name="statestore",
|
|
11
|
+
dapr_address="localhost:50001",
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Or pass an existing Dapr client that your application already manages
|
|
15
|
+
session = DaprSession(
|
|
16
|
+
session_id="user-123",
|
|
17
|
+
state_store_name="statestore",
|
|
18
|
+
dapr_client=my_dapr_client,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
await Runner.run(agent, "Hello", session=session)
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
import asyncio
|
|
27
|
+
import json
|
|
28
|
+
import random
|
|
29
|
+
import time
|
|
30
|
+
from typing import Any, Final, Literal
|
|
31
|
+
|
|
32
|
+
try:
|
|
33
|
+
from dapr.aio.clients import DaprClient
|
|
34
|
+
from dapr.clients.grpc._state import Concurrency, Consistency, StateOptions
|
|
35
|
+
except ImportError as e:
|
|
36
|
+
raise ImportError(
|
|
37
|
+
"DaprSession requires the 'dapr' package. Install it with: pip install dapr"
|
|
38
|
+
) from e
|
|
39
|
+
|
|
40
|
+
from ...items import TResponseInputItem
|
|
41
|
+
from ...logger import logger
|
|
42
|
+
from ...memory.session import SessionABC
|
|
43
|
+
|
|
44
|
+
# The two consistency levels DaprSession accepts for its ``consistency`` option.
ConsistencyLevel = Literal["eventual", "strong"]

# Named constants for the levels above, exported for callers.
DAPR_CONSISTENCY_EVENTUAL: ConsistencyLevel = "eventual"
DAPR_CONSISTENCY_STRONG: ConsistencyLevel = "strong"

# Retry policy for optimistic-concurrency conflicts on message writes:
# at most this many write attempts in total per operation.
_MAX_WRITE_ATTEMPTS: Final[int] = 5
# Backoff before attempt n+1 is base * 2**(n-1) seconds, capped at the maximum
# (a small random jitter is added on top by the session).
_RETRY_BASE_DELAY_SECONDS: Final[float] = 0.05
_RETRY_MAX_DELAY_SECONDS: Final[float] = 1.0
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class DaprSession(SessionABC):
    """Dapr State Store implementation of :py:class:`agents.memory.session.Session`.

    Storage layout, per session:

    * ``"<session_id>:messages"`` holds a JSON array whose elements are the
      items produced by :py:meth:`_serialize_item`.
    * ``"<session_id>:metadata"`` holds a JSON object with ``session_id``,
      ``created_at`` and ``updated_at`` (Unix timestamps in seconds, as strings).

    Writes to the messages key are read-modify-write cycles. Each write is
    guarded by the etag returned from the preceding read, with first-write-wins
    concurrency. When the store reports a conflict, the cycle is retried with
    jittered exponential backoff, up to ``_MAX_WRITE_ATTEMPTS`` attempts.
    """

    def __init__(
        self,
        session_id: str,
        *,
        state_store_name: str,
        dapr_client: DaprClient,
        ttl: int | None = None,
        consistency: ConsistencyLevel = DAPR_CONSISTENCY_EVENTUAL,
    ):
        """Initializes a new DaprSession.

        Args:
            session_id (str): Unique identifier for the conversation.
            state_store_name (str): Name of the Dapr state store component.
            dapr_client (DaprClient): A pre-configured Dapr client.
            ttl (int | None, optional): Time-to-live in seconds for session data.
                If None, data persists indefinitely. Note that TTL support depends on
                the underlying state store implementation. Defaults to None.
            consistency (ConsistencyLevel, optional): Consistency level for state operations.
                Use DAPR_CONSISTENCY_EVENTUAL or DAPR_CONSISTENCY_STRONG constants.
                Defaults to DAPR_CONSISTENCY_EVENTUAL.
        """
        self.session_id = session_id
        self._dapr_client = dapr_client
        self._state_store_name = state_store_name
        self._ttl = ttl
        self._consistency = consistency
        # Serializes operations issued through this instance only. Writers in
        # other processes are handled by the etag/retry logic instead.
        self._lock = asyncio.Lock()
        # Set to True only by from_address(); close() then closes the client.
        self._owns_client = False

        # State keys
        self._messages_key = f"{self.session_id}:messages"
        self._metadata_key = f"{self.session_id}:metadata"

    @classmethod
    def from_address(
        cls,
        session_id: str,
        *,
        state_store_name: str,
        dapr_address: str = "localhost:50001",
        **kwargs: Any,
    ) -> DaprSession:
        """Create a session from a Dapr sidecar address.

        Args:
            session_id (str): Conversation ID.
            state_store_name (str): Name of the Dapr state store component.
            dapr_address (str): Dapr sidecar gRPC address. Defaults to "localhost:50001".
            **kwargs: Additional keyword arguments forwarded to the main constructor
                (e.g., ttl, consistency).

        Returns:
            DaprSession: An instance of DaprSession connected to the specified Dapr sidecar.

        Note:
            The Dapr Python SDK performs health checks on the HTTP endpoint
            (default: http://localhost:3500). Ensure the Dapr sidecar is started with
            --dapr-http-port 3500. Alternatively, set one of these environment
            variables: DAPR_HTTP_ENDPOINT (e.g., "http://localhost:3500") or
            DAPR_HTTP_PORT (e.g., "3500") to avoid connection errors.
        """
        dapr_client = DaprClient(address=dapr_address)
        session = cls(
            session_id, state_store_name=state_store_name, dapr_client=dapr_client, **kwargs
        )
        session._owns_client = True  # We created the client, so we own it
        return session

    def _get_read_metadata(self) -> dict[str, str]:
        """Get metadata for read operations including consistency.

        The consistency level is passed through state_metadata as per Dapr's state API.
        """
        metadata: dict[str, str] = {}
        # Add consistency level to metadata for read operations
        if self._consistency:
            metadata["consistency"] = self._consistency
        return metadata

    def _get_state_options(self, *, concurrency: Concurrency | None = None) -> StateOptions | None:
        """Get StateOptions configured with consistency and optional concurrency."""
        options_kwargs: dict[str, Any] = {}
        if self._consistency == DAPR_CONSISTENCY_STRONG:
            options_kwargs["consistency"] = Consistency.strong
        elif self._consistency == DAPR_CONSISTENCY_EVENTUAL:
            options_kwargs["consistency"] = Consistency.eventual
        if concurrency is not None:
            options_kwargs["concurrency"] = concurrency
        if options_kwargs:
            return StateOptions(**options_kwargs)
        return None

    def _get_metadata(self) -> dict[str, str]:
        """Get metadata for write operations, including the TTL if configured."""
        metadata: dict[str, str] = {}
        if self._ttl is not None:
            metadata["ttlInSeconds"] = str(self._ttl)
        return metadata

    async def _serialize_item(self, item: TResponseInputItem) -> str:
        """Serialize an item to JSON string. Can be overridden by subclasses."""
        return json.dumps(item, separators=(",", ":"))

    async def _deserialize_item(self, item: str) -> TResponseInputItem:
        """Deserialize a JSON string to an item. Can be overridden by subclasses."""
        return json.loads(item)  # type: ignore[no-any-return]

    @staticmethod
    def _extract_etag(response: Any) -> str | None:
        """Return the etag of a get_state response, or None if it is missing or empty.

        Normalizing "" to None keeps add_items and pop_item consistent: neither
        sends an empty etag back to the store.
        """
        return getattr(response, "etag", None) or None

    def _decode_messages(self, data: bytes | None) -> list[Any]:
        """Decode the stored messages array; returns [] when absent or unreadable.

        A warning is logged for unreadable data. A subsequent write would
        replace it with only the new items, so this should not pass silently.
        """
        if not data:
            return []
        try:
            messages = json.loads(data.decode("utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError):
            logger.warning(
                "Could not decode stored messages for Dapr session %s; treating as empty",
                self.session_id,
            )
            return []
        if isinstance(messages, list):
            return list(messages)
        logger.warning(
            "Stored messages for Dapr session %s are not a JSON array; treating as empty",
            self.session_id,
        )
        return []

    async def _read_session_metadata(self) -> dict[str, Any]:
        """Read the metadata object for this session; returns {} when absent or unreadable."""
        response = await self._dapr_client.get_state(
            store_name=self._state_store_name,
            key=self._metadata_key,
            state_metadata=self._get_read_metadata(),
        )
        if not response.data:
            return {}
        try:
            decoded = json.loads(response.data.decode("utf-8"))
        except (json.JSONDecodeError, UnicodeDecodeError):
            return {}
        return decoded if isinstance(decoded, dict) else {}

    def _calculate_retry_delay(self, attempt: int) -> float:
        """Return the backoff before retrying after `attempt` (1-based), with up to 10% jitter."""
        base: float = _RETRY_BASE_DELAY_SECONDS * (2 ** max(0, attempt - 1))
        delay: float = min(base, _RETRY_MAX_DELAY_SECONDS)
        # Add jitter (10%) similar to tracing processors to avoid thundering herd.
        return delay + random.uniform(0, 0.1 * delay)

    def _is_concurrency_conflict(self, error: Exception) -> bool:
        """Heuristically decide whether `error` is an etag/concurrency conflict.

        Checks a gRPC-style status code (ABORTED / FAILED_PRECONDITION) when the
        error exposes a callable `code`, then falls back to message substrings.
        """
        code_attr = getattr(error, "code", None)
        if callable(code_attr):
            try:
                status_code = code_attr()
            except Exception:
                status_code = None
            if status_code is not None:
                status_name = getattr(status_code, "name", str(status_code))
                if status_name in {"ABORTED", "FAILED_PRECONDITION"}:
                    return True
        message = str(error).lower()
        conflict_markers = (
            "etag mismatch",
            "etag does not match",
            "precondition failed",
            "concurrency conflict",
            "invalid etag",
            "failed to set key",  # Redis state store Lua script error during conditional write
            "user_script",  # Redis script failure hint
        )
        return any(marker in message for marker in conflict_markers)

    async def _handle_concurrency_conflict(self, error: Exception, attempt: int) -> bool:
        """Sleep and return True if the write should be retried, else return False."""
        if not self._is_concurrency_conflict(error):
            return False
        if attempt >= _MAX_WRITE_ATTEMPTS:
            return False
        delay = self._calculate_retry_delay(attempt)
        if delay > 0:
            await asyncio.sleep(delay)
        return True

    # ------------------------------------------------------------------
    # Session protocol implementation
    # ------------------------------------------------------------------

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Retrieve the conversation history for this session.

        Args:
            limit: Maximum number of items to retrieve. If None, retrieves all items.
                When specified, returns the latest N items in chronological order.

        Returns:
            List of input items representing the conversation history. Stored
            entries that cannot be deserialized are skipped, and a warning is logged.
        """
        async with self._lock:
            # Get messages from state store with consistency level
            response = await self._dapr_client.get_state(
                store_name=self._state_store_name,
                key=self._messages_key,
                state_metadata=self._get_read_metadata(),
            )

            messages = self._decode_messages(response.data)
            if not messages:
                return []
            if limit is not None:
                if limit <= 0:
                    return []
                messages = messages[-limit:]
            items: list[TResponseInputItem] = []
            for msg in messages:
                try:
                    if isinstance(msg, str):
                        item = await self._deserialize_item(msg)
                    else:
                        item = msg
                    items.append(item)
                except (json.JSONDecodeError, TypeError):
                    logger.warning(
                        "Skipping undecodable item in Dapr session %s", self.session_id
                    )
                    continue
            return items

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Add new items to the conversation history.

        Appends with an etag-guarded read-modify-write that is retried on
        concurrency conflicts. Afterwards the metadata key is rewritten; an
        existing ``created_at`` is preserved and ``updated_at`` is set to now.

        Args:
            items: List of input items to add to the history
        """
        if not items:
            return

        async with self._lock:
            # Serialize once; only the read-modify-write below is retried.
            serialized_items: list[str] = [await self._serialize_item(item) for item in items]
            attempt = 0
            while True:
                attempt += 1
                response = await self._dapr_client.get_state(
                    store_name=self._state_store_name,
                    key=self._messages_key,
                    state_metadata=self._get_read_metadata(),
                )
                existing_messages = self._decode_messages(response.data)
                updated_messages = existing_messages + serialized_items
                messages_json = json.dumps(updated_messages, separators=(",", ":"))
                etag = self._extract_etag(response)
                try:
                    await self._dapr_client.save_state(
                        store_name=self._state_store_name,
                        key=self._messages_key,
                        value=messages_json,
                        etag=etag,
                        state_metadata=self._get_metadata(),
                        options=self._get_state_options(concurrency=Concurrency.first_write),
                    )
                    break
                except Exception as error:
                    should_retry = await self._handle_concurrency_conflict(error, attempt)
                    if should_retry:
                        continue
                    raise

            # Update metadata; keep the original creation time instead of
            # resetting it on every append.
            now = str(int(time.time()))
            existing_metadata = await self._read_session_metadata()
            metadata = {
                "session_id": self.session_id,
                "created_at": str(existing_metadata.get("created_at") or now),
                "updated_at": now,
            }
            await self._dapr_client.save_state(
                store_name=self._state_store_name,
                key=self._metadata_key,
                value=json.dumps(metadata),
                state_metadata=self._get_metadata(),
                options=self._get_state_options(),
            )

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the most recent item from the session.

        Returns:
            The most recent item if it exists, None if the session is empty. If
            the removed entry cannot be deserialized, it stays removed and None is
            returned.
        """
        async with self._lock:
            attempt = 0
            while True:
                attempt += 1
                response = await self._dapr_client.get_state(
                    store_name=self._state_store_name,
                    key=self._messages_key,
                    state_metadata=self._get_read_metadata(),
                )
                messages = self._decode_messages(response.data)
                if not messages:
                    return None
                last_item = messages.pop()
                messages_json = json.dumps(messages, separators=(",", ":"))
                etag = self._extract_etag(response)
                try:
                    await self._dapr_client.save_state(
                        store_name=self._state_store_name,
                        key=self._messages_key,
                        value=messages_json,
                        etag=etag,
                        state_metadata=self._get_metadata(),
                        options=self._get_state_options(concurrency=Concurrency.first_write),
                    )
                    break
                except Exception as error:
                    should_retry = await self._handle_concurrency_conflict(error, attempt)
                    if should_retry:
                        continue
                    raise
            try:
                if isinstance(last_item, str):
                    return await self._deserialize_item(last_item)
                return last_item  # type: ignore[no-any-return]
            except (json.JSONDecodeError, TypeError):
                logger.warning(
                    "Popped an undecodable item from Dapr session %s; returning None",
                    self.session_id,
                )
                return None

    async def clear_session(self) -> None:
        """Clear all items for this session."""
        async with self._lock:
            # Delete messages and metadata keys
            await self._dapr_client.delete_state(
                store_name=self._state_store_name,
                key=self._messages_key,
                options=self._get_state_options(),
            )

            await self._dapr_client.delete_state(
                store_name=self._state_store_name,
                key=self._metadata_key,
                options=self._get_state_options(),
            )

    async def close(self) -> None:
        """Close the Dapr client connection.

        Only closes the connection if this session owns the Dapr client
        (i.e., created via from_address). If the client was injected externally,
        the caller is responsible for managing its lifecycle.
        """
        if self._owns_client:
            await self._dapr_client.close()

    async def __aenter__(self) -> DaprSession:
        """Enter async context manager."""
        return self

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Exit async context manager and close the connection."""
        await self.close()

    async def ping(self) -> bool:
        """Test Dapr connectivity by checking metadata.

        Reads a ``__ping__`` key. If the read fails, it writes that key once and
        reads it again, since some stores create their backing table on first write.

        Returns:
            True if Dapr is reachable, False otherwise.
        """
        try:
            # First attempt a read; some stores may not be initialized yet.
            await self._dapr_client.get_state(
                store_name=self._state_store_name,
                key="__ping__",
                state_metadata=self._get_read_metadata(),
            )
            return True
        except Exception as initial_error:
            # If relation/table is missing or store isn't initialized,
            # attempt a write to initialize it, then read again.
            try:
                await self._dapr_client.save_state(
                    store_name=self._state_store_name,
                    key="__ping__",
                    value="ok",
                    state_metadata=self._get_metadata(),
                    options=self._get_state_options(),
                )
                # Read again after write.
                await self._dapr_client.get_state(
                    store_name=self._state_store_name,
                    key="__ping__",
                    state_metadata=self._get_read_metadata(),
                )
                return True
            except Exception:
                logger.error("Dapr connection failed: %s", initial_error)
                return False
|
|
@@ -319,3 +319,16 @@ class SQLAlchemySession(SessionABC):
|
|
|
319
319
|
await sess.execute(
|
|
320
320
|
delete(self._sessions).where(self._sessions.c.session_id == self.session_id)
|
|
321
321
|
)
|
|
322
|
+
|
|
323
|
+
@property
def engine(self) -> AsyncEngine:
    """Return the SQLAlchemy ``AsyncEngine`` this session was built on.

    Intended for advanced use: inspecting the connection pool, adjusting
    engine settings, or disposing of the engine explicitly.

    Returns:
        AsyncEngine: The engine held by this session in ``self._engine``.
    """
    async_engine = self._engine
    return async_engine
|
|
@@ -44,6 +44,7 @@ from ...models.chatcmpl_helpers import HEADERS, HEADERS_OVERRIDE
|
|
|
44
44
|
from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler
|
|
45
45
|
from ...models.fake_id import FAKE_RESPONSES_ID
|
|
46
46
|
from ...models.interface import Model, ModelTracing
|
|
47
|
+
from ...models.openai_responses import Converter as OpenAIResponsesConverter
|
|
47
48
|
from ...tool import Tool
|
|
48
49
|
from ...tracing import generation_span
|
|
49
50
|
from ...tracing.span_data import GenerationSpanData
|
|
@@ -109,18 +110,26 @@ class LitellmModel(Model):
|
|
|
109
110
|
prompt=prompt,
|
|
110
111
|
)
|
|
111
112
|
|
|
112
|
-
|
|
113
|
+
message: litellm.types.utils.Message | None = None
|
|
114
|
+
first_choice: litellm.types.utils.Choices | None = None
|
|
115
|
+
if response.choices and len(response.choices) > 0:
|
|
116
|
+
choice = response.choices[0]
|
|
117
|
+
if isinstance(choice, litellm.types.utils.Choices):
|
|
118
|
+
first_choice = choice
|
|
119
|
+
message = first_choice.message
|
|
113
120
|
|
|
114
121
|
if _debug.DONT_LOG_MODEL_DATA:
|
|
115
122
|
logger.debug("Received model response")
|
|
116
123
|
else:
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
123
|
-
|
|
124
|
+
if message is not None:
|
|
125
|
+
logger.debug(
|
|
126
|
+
f"""LLM resp:\n{
|
|
127
|
+
json.dumps(message.model_dump(), indent=2, ensure_ascii=False)
|
|
128
|
+
}\n"""
|
|
129
|
+
)
|
|
130
|
+
else:
|
|
131
|
+
finish_reason = first_choice.finish_reason if first_choice else "-"
|
|
132
|
+
logger.debug(f"LLM resp had no message. finish_reason: {finish_reason}")
|
|
124
133
|
|
|
125
134
|
if hasattr(response, "usage"):
|
|
126
135
|
response_usage = response.usage
|
|
@@ -151,14 +160,20 @@ class LitellmModel(Model):
|
|
|
151
160
|
logger.warning("No usage information returned from Litellm")
|
|
152
161
|
|
|
153
162
|
if tracing.include_data():
|
|
154
|
-
span_generation.span_data.output =
|
|
163
|
+
span_generation.span_data.output = (
|
|
164
|
+
[message.model_dump()] if message is not None else []
|
|
165
|
+
)
|
|
155
166
|
span_generation.span_data.usage = {
|
|
156
167
|
"input_tokens": usage.input_tokens,
|
|
157
168
|
"output_tokens": usage.output_tokens,
|
|
158
169
|
}
|
|
159
170
|
|
|
160
|
-
items =
|
|
161
|
-
|
|
171
|
+
items = (
|
|
172
|
+
Converter.message_to_output_items(
|
|
173
|
+
LitellmConverter.convert_message_to_openai(message)
|
|
174
|
+
)
|
|
175
|
+
if message is not None
|
|
176
|
+
else []
|
|
162
177
|
)
|
|
163
178
|
|
|
164
179
|
return ModelResponse(
|
|
@@ -269,7 +284,7 @@ class LitellmModel(Model):
|
|
|
269
284
|
)
|
|
270
285
|
|
|
271
286
|
# Fix for interleaved thinking bug: reorder messages to ensure tool_use comes before tool_result # noqa: E501
|
|
272
|
-
if
|
|
287
|
+
if "anthropic" in self.model.lower() or "claude" in self.model.lower():
|
|
273
288
|
converted_messages = self._fix_tool_message_ordering(converted_messages)
|
|
274
289
|
|
|
275
290
|
if system_instructions:
|
|
@@ -325,6 +340,23 @@ class LitellmModel(Model):
|
|
|
325
340
|
)
|
|
326
341
|
|
|
327
342
|
reasoning_effort = model_settings.reasoning.effort if model_settings.reasoning else None
|
|
343
|
+
# Enable developers to pass non-OpenAI compatible reasoning_effort data like "none"
|
|
344
|
+
# Priority order:
|
|
345
|
+
# 1. model_settings.reasoning.effort
|
|
346
|
+
# 2. model_settings.extra_body["reasoning_effort"]
|
|
347
|
+
# 3. model_settings.extra_args["reasoning_effort"]
|
|
348
|
+
if (
|
|
349
|
+
reasoning_effort is None # Unset in model_settings
|
|
350
|
+
and isinstance(model_settings.extra_body, dict)
|
|
351
|
+
and "reasoning_effort" in model_settings.extra_body
|
|
352
|
+
):
|
|
353
|
+
reasoning_effort = model_settings.extra_body["reasoning_effort"]
|
|
354
|
+
if (
|
|
355
|
+
reasoning_effort is None # Unset in both model_settings and model_settings.extra_body
|
|
356
|
+
and model_settings.extra_args
|
|
357
|
+
and "reasoning_effort" in model_settings.extra_args
|
|
358
|
+
):
|
|
359
|
+
reasoning_effort = model_settings.extra_args["reasoning_effort"]
|
|
328
360
|
|
|
329
361
|
stream_options = None
|
|
330
362
|
if stream and model_settings.include_usage is not None:
|
|
@@ -342,6 +374,9 @@ class LitellmModel(Model):
|
|
|
342
374
|
if model_settings.extra_args:
|
|
343
375
|
extra_kwargs.update(model_settings.extra_args)
|
|
344
376
|
|
|
377
|
+
# Prevent duplicate reasoning_effort kwargs when it was promoted to a top-level argument.
|
|
378
|
+
extra_kwargs.pop("reasoning_effort", None)
|
|
379
|
+
|
|
345
380
|
ret = await litellm.acompletion(
|
|
346
381
|
model=self.model,
|
|
347
382
|
messages=converted_messages,
|
|
@@ -367,15 +402,19 @@ class LitellmModel(Model):
|
|
|
367
402
|
if isinstance(ret, litellm.types.utils.ModelResponse):
|
|
368
403
|
return ret
|
|
369
404
|
|
|
405
|
+
responses_tool_choice = OpenAIResponsesConverter.convert_tool_choice(
|
|
406
|
+
model_settings.tool_choice
|
|
407
|
+
)
|
|
408
|
+
if responses_tool_choice is None or responses_tool_choice is omit:
|
|
409
|
+
responses_tool_choice = "auto"
|
|
410
|
+
|
|
370
411
|
response = Response(
|
|
371
412
|
id=FAKE_RESPONSES_ID,
|
|
372
413
|
created_at=time.time(),
|
|
373
414
|
model=self.model,
|
|
374
415
|
object="response",
|
|
375
416
|
output=[],
|
|
376
|
-
tool_choice=
|
|
377
|
-
if tool_choice is not omit
|
|
378
|
-
else "auto",
|
|
417
|
+
tool_choice=responses_tool_choice, # type: ignore[arg-type]
|
|
379
418
|
top_p=model_settings.top_p,
|
|
380
419
|
temperature=model_settings.temperature,
|
|
381
420
|
tools=[],
|
agents/items.py
CHANGED
|
@@ -361,6 +361,9 @@ class ItemHelpers:
|
|
|
361
361
|
if isinstance(output, (ToolOutputText, ToolOutputImage, ToolOutputFileContent)):
|
|
362
362
|
return output
|
|
363
363
|
elif isinstance(output, dict):
|
|
364
|
+
# Require explicit 'type' field in dict to be considered a structured output
|
|
365
|
+
if "type" not in output:
|
|
366
|
+
return None
|
|
364
367
|
try:
|
|
365
368
|
return ValidToolOutputPydanticModelsTypeAdapter.validate_python(output)
|
|
366
369
|
except pydantic.ValidationError:
|
agents/lifecycle.py
CHANGED
|
@@ -62,7 +62,7 @@ class RunHooksBase(Generic[TContext, TAgent]):
|
|
|
62
62
|
agent: TAgent,
|
|
63
63
|
tool: Tool,
|
|
64
64
|
) -> None:
|
|
65
|
-
"""Called
|
|
65
|
+
"""Called immediately before a local tool is invoked."""
|
|
66
66
|
pass
|
|
67
67
|
|
|
68
68
|
async def on_tool_end(
|
|
@@ -72,7 +72,7 @@ class RunHooksBase(Generic[TContext, TAgent]):
|
|
|
72
72
|
tool: Tool,
|
|
73
73
|
result: str,
|
|
74
74
|
) -> None:
|
|
75
|
-
"""Called after a tool is invoked."""
|
|
75
|
+
"""Called immediately after a local tool is invoked."""
|
|
76
76
|
pass
|
|
77
77
|
|
|
78
78
|
|
|
@@ -113,7 +113,7 @@ class AgentHooksBase(Generic[TContext, TAgent]):
|
|
|
113
113
|
agent: TAgent,
|
|
114
114
|
tool: Tool,
|
|
115
115
|
) -> None:
|
|
116
|
-
"""Called
|
|
116
|
+
"""Called immediately before a local tool is invoked."""
|
|
117
117
|
pass
|
|
118
118
|
|
|
119
119
|
async def on_tool_end(
|
|
@@ -123,7 +123,7 @@ class AgentHooksBase(Generic[TContext, TAgent]):
|
|
|
123
123
|
tool: Tool,
|
|
124
124
|
result: str,
|
|
125
125
|
) -> None:
|
|
126
|
-
"""Called after a tool is invoked."""
|
|
126
|
+
"""Called immediately after a local tool is invoked."""
|
|
127
127
|
pass
|
|
128
128
|
|
|
129
129
|
async def on_llm_start(
|