openai-agents 0.2.11__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries, and is provided for informational purposes only.
- agents/_debug.py +15 -4
- agents/_run_impl.py +34 -37
- agents/agent.py +18 -2
- agents/extensions/handoff_filters.py +2 -0
- agents/extensions/memory/__init__.py +42 -15
- agents/extensions/memory/encrypt_session.py +185 -0
- agents/extensions/models/litellm_model.py +62 -10
- agents/function_schema.py +45 -3
- agents/memory/__init__.py +2 -0
- agents/memory/openai_conversations_session.py +0 -3
- agents/memory/util.py +20 -0
- agents/models/chatcmpl_converter.py +74 -15
- agents/models/chatcmpl_helpers.py +6 -0
- agents/models/chatcmpl_stream_handler.py +29 -1
- agents/models/openai_chatcompletions.py +26 -4
- agents/models/openai_responses.py +30 -4
- agents/realtime/__init__.py +2 -0
- agents/realtime/_util.py +1 -1
- agents/realtime/agent.py +7 -0
- agents/realtime/audio_formats.py +29 -0
- agents/realtime/config.py +32 -4
- agents/realtime/items.py +17 -1
- agents/realtime/model_events.py +2 -0
- agents/realtime/model_inputs.py +15 -1
- agents/realtime/openai_realtime.py +421 -130
- agents/realtime/session.py +167 -14
- agents/result.py +47 -20
- agents/run.py +191 -106
- agents/tool.py +1 -1
- agents/tracing/processor_interface.py +84 -11
- agents/tracing/spans.py +88 -0
- agents/tracing/traces.py +99 -16
- agents/util/_json.py +19 -1
- agents/util/_transforms.py +12 -2
- agents/voice/input.py +5 -4
- agents/voice/models/openai_stt.py +15 -8
- {openai_agents-0.2.11.dist-info → openai_agents-0.3.1.dist-info}/METADATA +4 -2
- {openai_agents-0.2.11.dist-info → openai_agents-0.3.1.dist-info}/RECORD +40 -37
- {openai_agents-0.2.11.dist-info → openai_agents-0.3.1.dist-info}/WHEEL +0 -0
- {openai_agents-0.2.11.dist-info → openai_agents-0.3.1.dist-info}/licenses/LICENSE +0 -0
agents/_debug.py
CHANGED
@@ -1,17 +1,28 @@
 import os
 
 
-def _debug_flag_enabled(flag: str) -> bool:
+def _debug_flag_enabled(flag: str, default: bool = False) -> bool:
     flag_value = os.getenv(flag)
-    return flag_value is not None and (flag_value == "1" or flag_value.lower() == "true")
+    if flag_value is None:
+        return default
+    else:
+        return flag_value == "1" or flag_value.lower() == "true"
 
 
-DONT_LOG_MODEL_DATA = _debug_flag_enabled("OPENAI_AGENTS_DONT_LOG_MODEL_DATA")
+def _load_dont_log_model_data() -> bool:
+    return _debug_flag_enabled("OPENAI_AGENTS_DONT_LOG_MODEL_DATA", default=True)
+
+
+def _load_dont_log_tool_data() -> bool:
+    return _debug_flag_enabled("OPENAI_AGENTS_DONT_LOG_TOOL_DATA", default=True)
+
+
+DONT_LOG_MODEL_DATA = _load_dont_log_model_data()
 """By default we don't log LLM inputs/outputs, to prevent exposing sensitive information. Set this
 flag to enable logging them.
 """
 
-DONT_LOG_TOOL_DATA = _debug_flag_enabled("OPENAI_AGENTS_DONT_LOG_TOOL_DATA")
+DONT_LOG_TOOL_DATA = _load_dont_log_tool_data()
 """By default we don't log tool call inputs/outputs, to prevent exposing sensitive information. Set
 this flag to enable logging them.
 """
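The net effect of the new `default` parameter: with no environment variables set, both flags now resolve to True, so model and tool data stay out of the logs (previously an unset variable resolved to False). A minimal sketch of the new behavior, assuming the flag names from the diff; note the flags are read once at import time, so they must be set before importing:

```python
import os

os.environ["OPENAI_AGENTS_DONT_LOG_MODEL_DATA"] = "0"  # opt in to logging model data

from agents import _debug

print(_debug.DONT_LOG_MODEL_DATA)  # False -> model inputs/outputs may be logged
print(_debug.DONT_LOG_TOOL_DATA)   # True  -> unset variable falls back to default=True
```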
agents/_run_impl.py
CHANGED
@@ -330,43 +330,40 @@ class RunImpl:
             ItemHelpers.extract_last_text(message_items[-1].raw_item) if message_items else None
         )
 
-        [old lines 333-369: the previous final-output branching; its content is truncated in this diff view]
+        # Generate final output only when there are no pending tool calls or approval requests.
+        if not processed_response.has_tools_or_approvals_to_run():
+            if output_schema and not output_schema.is_plain_text() and potential_final_output_text:
+                final_output = output_schema.validate_json(potential_final_output_text)
+                return await cls.execute_final_output(
+                    agent=agent,
+                    original_input=original_input,
+                    new_response=new_response,
+                    pre_step_items=pre_step_items,
+                    new_step_items=new_step_items,
+                    final_output=final_output,
+                    hooks=hooks,
+                    context_wrapper=context_wrapper,
+                )
+            elif not output_schema or output_schema.is_plain_text():
+                return await cls.execute_final_output(
+                    agent=agent,
+                    original_input=original_input,
+                    new_response=new_response,
+                    pre_step_items=pre_step_items,
+                    new_step_items=new_step_items,
+                    final_output=potential_final_output_text or "",
+                    hooks=hooks,
+                    context_wrapper=context_wrapper,
+                )
+
+        # If there's no final output, we can just run again
+        return SingleStepResult(
+            original_input=original_input,
+            model_response=new_response,
+            pre_step_items=pre_step_items,
+            new_step_items=new_step_items,
+            next_step=NextStepRunAgain(),
+        )
 
     @classmethod
     def maybe_reset_tool_choice(
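The reshaped step logic gates every final-output path behind `has_tools_or_approvals_to_run()`. A hedged, toy restatement of the decision table above (simplified names, not the SDK's API):

```python
def decide_next_step(
    pending_tools_or_approvals: bool,
    plain_text_output: bool,
    final_text: str | None,
) -> str:
    """Mirror of the branching: only an idle run may produce a final output."""
    if not pending_tools_or_approvals:
        if not plain_text_output and final_text:
            return "final_output (text validated against the output schema)"
        if plain_text_output:
            return "final_output (the text, or an empty string)"
    return "run_again (NextStepRunAgain)"
```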
agents/agent.py
CHANGED
@@ -30,9 +30,11 @@ from .util import _transforms
 from .util._types import MaybeAwaitable
 
 if TYPE_CHECKING:
-    from .lifecycle import AgentHooks
+    from .lifecycle import AgentHooks, RunHooks
     from .mcp import MCPServer
+    from .memory.session import Session
     from .result import RunResult
+    from .run import RunConfig
 
 
 @dataclass
@@ -384,6 +386,12 @@ class Agent(AgentBase, Generic[TContext]):
         custom_output_extractor: Callable[[RunResult], Awaitable[str]] | None = None,
         is_enabled: bool
         | Callable[[RunContextWrapper[Any], AgentBase[Any]], MaybeAwaitable[bool]] = True,
+        run_config: RunConfig | None = None,
+        max_turns: int | None = None,
+        hooks: RunHooks[TContext] | None = None,
+        previous_response_id: str | None = None,
+        conversation_id: str | None = None,
+        session: Session | None = None,
     ) -> Tool:
         """Transform this agent into a tool, callable by other agents.
 
@@ -410,12 +418,20 @@ class Agent(AgentBase, Generic[TContext]):
             is_enabled=is_enabled,
         )
         async def run_agent(context: RunContextWrapper, input: str) -> str:
-            from .run import Runner
+            from .run import DEFAULT_MAX_TURNS, Runner
+
+            resolved_max_turns = max_turns if max_turns is not None else DEFAULT_MAX_TURNS
 
             output = await Runner.run(
                 starting_agent=self,
                 input=input,
                 context=context.context,
+                run_config=run_config,
+                max_turns=resolved_max_turns,
+                hooks=hooks,
+                previous_response_id=previous_response_id,
+                conversation_id=conversation_id,
+                session=session,
             )
             if custom_output_extractor:
                 return await custom_output_extractor(output)
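The new `as_tool()` parameters are forwarded straight to the nested `Runner.run()` call, so a sub-agent invoked as a tool can now share session memory, hooks, and run configuration with its caller. A sketch under those assumptions (the agent names and instructions are made up for illustration):

```python
from agents import Agent, RunConfig, SQLiteSession

weather_agent = Agent(name="Weather", instructions="Answer weather questions.")

weather_tool = weather_agent.as_tool(
    tool_name="ask_weather",
    tool_description="Delegate a weather question to the weather sub-agent.",
    max_turns=5,                        # None would fall back to DEFAULT_MAX_TURNS
    run_config=RunConfig(),             # forwarded to the nested Runner.run()
    session=SQLiteSession("user-123"),  # the sub-run reads/writes this history
)

triage_agent = Agent(
    name="Triage",
    instructions="Route questions to the right tool.",
    tools=[weather_tool],
)
```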
agents/extensions/handoff_filters.py
CHANGED
@@ -4,6 +4,7 @@ from ..handoffs import HandoffInputData
 from ..items import (
     HandoffCallItem,
     HandoffOutputItem,
+    ReasoningItem,
     RunItem,
     ToolCallItem,
     ToolCallOutputItem,
@@ -41,6 +42,7 @@ def _remove_tools_from_items(items: tuple[RunItem, ...]) -> tuple[RunItem, ...]:
             or isinstance(item, HandoffOutputItem)
             or isinstance(item, ToolCallItem)
             or isinstance(item, ToolCallOutputItem)
+            or isinstance(item, ReasoningItem)
         ):
             continue
         filtered_items.append(item)
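With `ReasoningItem` added to the filter, the built-in `remove_all_tools` input filter now also drops reasoning items before the handoff target sees the history. A short sketch of where that filter plugs in, assuming the public `handoff()` API (agent names are illustrative):

```python
from agents import Agent, handoff
from agents.extensions import handoff_filters

billing_agent = Agent(name="Billing", instructions="Handle billing questions.")

# Reasoning items are now stripped together with tool call/output items.
billing_handoff = handoff(
    agent=billing_agent,
    input_filter=handoff_filters.remove_all_tools,
)
```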
agents/extensions/memory/__init__.py
CHANGED
@@ -1,15 +1,42 @@
-[old lines 1-8 are truncated in this diff view]
-from __future__ import annotations
-
-from …
-
-__all__: list[str] = [
-    "…
-[old line 15 is truncated in this diff view]
+"""Session memory backends living in the extensions namespace.
+
+This package contains optional, production-grade session implementations that
+introduce extra third-party dependencies (database drivers, ORMs, etc.). They
+conform to the :class:`agents.memory.session.Session` protocol so they can be
+used as a drop-in replacement for :class:`agents.memory.session.SQLiteSession`.
+"""
+
+from __future__ import annotations
+
+from typing import Any
+
+__all__: list[str] = [
+    "EncryptedSession",
+    "SQLAlchemySession",
+]
+
+
+def __getattr__(name: str) -> Any:
+    if name == "EncryptedSession":
+        try:
+            from .encrypt_session import EncryptedSession  # noqa: F401
+
+            return EncryptedSession
+        except ModuleNotFoundError as e:
+            raise ImportError(
+                "EncryptedSession requires the 'cryptography' extra. "
+                "Install it with: pip install openai-agents[encrypt]"
+            ) from e
+
+    if name == "SQLAlchemySession":
+        try:
+            from .sqlalchemy_session import SQLAlchemySession  # noqa: F401
+
+            return SQLAlchemySession
+        except ModuleNotFoundError as e:
+            raise ImportError(
+                "SQLAlchemySession requires the 'sqlalchemy' extra. "
+                "Install it with: pip install openai-agents[sqlalchemy]"
+            ) from e
+
+    raise AttributeError(f"module {__name__} has no attribute {name}")
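Because the module-level `__getattr__` (PEP 562) defers the imports, `import agents.extensions.memory` no longer pulls in `cryptography` or SQLAlchemy; the cost, and the failure, move to first attribute access. A sketch of what that looks like when the extra is not installed:

```python
# Assuming the 'cryptography' extra is NOT installed:
try:
    from agents.extensions.memory import EncryptedSession
except ImportError as exc:
    print(exc)  # "EncryptedSession requires the 'cryptography' extra. Install it with: ..."
```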
agents/extensions/memory/encrypt_session.py
ADDED
@@ -0,0 +1,185 @@
+"""Encrypted Session wrapper for secure conversation storage.
+
+This module provides transparent encryption for session storage with automatic
+expiration of old data. When TTL expires, expired items are silently skipped.
+
+Usage::
+
+    from agents.extensions.memory import EncryptedSession, SQLAlchemySession
+
+    # Create underlying session (e.g. SQLAlchemySession)
+    underlying_session = SQLAlchemySession.from_url(
+        session_id="user-123",
+        url="postgresql+asyncpg://app:secret@db.example.com/agents",
+        create_tables=True,
+    )
+
+    # Wrap with encryption and TTL-based expiration
+    session = EncryptedSession(
+        session_id="user-123",
+        underlying_session=underlying_session,
+        encryption_key="your-encryption-key",
+        ttl=600,  # 10 minutes
+    )
+
+    await Runner.run(agent, "Hello", session=session)
+"""
+
+from __future__ import annotations
+
+import base64
+import json
+from typing import Any, cast
+
+from cryptography.fernet import Fernet, InvalidToken
+from cryptography.hazmat.primitives import hashes
+from cryptography.hazmat.primitives.kdf.hkdf import HKDF
+from typing_extensions import Literal, TypedDict, TypeGuard
+
+from ...items import TResponseInputItem
+from ...memory.session import SessionABC
+
+
+class EncryptedEnvelope(TypedDict):
+    """TypedDict for encrypted message envelopes stored in the underlying session."""
+
+    __enc__: Literal[1]
+    v: int
+    kid: str
+    payload: str
+
+
+def _ensure_fernet_key_bytes(master_key: str) -> bytes:
+    """
+    Accept either a Fernet key (urlsafe-b64, 32 bytes after decode) or a raw string.
+    Returns raw bytes suitable for HKDF input.
+    """
+    if not master_key:
+        raise ValueError("encryption_key not set; required for EncryptedSession.")
+    try:
+        key_bytes = base64.urlsafe_b64decode(master_key)
+        if len(key_bytes) == 32:
+            return key_bytes
+    except Exception:
+        pass
+    return master_key.encode("utf-8")
+
+
+def _derive_session_fernet_key(master_key_bytes: bytes, session_id: str) -> Fernet:
+    hkdf = HKDF(
+        algorithm=hashes.SHA256(),
+        length=32,
+        salt=session_id.encode("utf-8"),
+        info=b"agents.session-store.hkdf.v1",
+    )
+    derived = hkdf.derive(master_key_bytes)
+    return Fernet(base64.urlsafe_b64encode(derived))
+
+
+def _to_json_bytes(obj: Any) -> bytes:
+    return json.dumps(obj, ensure_ascii=False, separators=(",", ":"), default=str).encode("utf-8")
+
+
+def _from_json_bytes(data: bytes) -> Any:
+    return json.loads(data.decode("utf-8"))
+
+
+def _is_encrypted_envelope(item: object) -> TypeGuard[EncryptedEnvelope]:
+    """Type guard to check if an item is an encrypted envelope."""
+    return (
+        isinstance(item, dict)
+        and item.get("__enc__") == 1
+        and "payload" in item
+        and "kid" in item
+        and "v" in item
+    )
+
+
+class EncryptedSession(SessionABC):
+    """Encrypted wrapper for Session implementations with TTL-based expiration.
+
+    This class wraps any SessionABC implementation to provide transparent
+    encryption/decryption of stored items using Fernet encryption with
+    per-session key derivation and automatic expiration of old data.
+
+    When items expire (exceed TTL), they are silently skipped during retrieval.
+
+    Note: Expired tokens are rejected based on the system clock of the application server.
+    To avoid valid tokens being rejected due to clock drift, ensure all servers in
+    your environment are synchronized using NTP.
+    """
+
+    def __init__(
+        self,
+        session_id: str,
+        underlying_session: SessionABC,
+        encryption_key: str,
+        ttl: int = 600,
+    ):
+        """
+        Args:
+            session_id: ID for this session
+            underlying_session: The real session store (e.g. SQLiteSession, SQLAlchemySession)
+            encryption_key: Master key (Fernet key or raw secret)
+            ttl: Token time-to-live in seconds (default 10 min)
+        """
+        self.session_id = session_id
+        self.underlying_session = underlying_session
+        self.ttl = ttl
+
+        master = _ensure_fernet_key_bytes(encryption_key)
+        self.cipher = _derive_session_fernet_key(master, session_id)
+        self._kid = "hkdf-v1"
+        self._ver = 1
+
+    def __getattr__(self, name):
+        return getattr(self.underlying_session, name)
+
+    def _wrap(self, item: TResponseInputItem) -> EncryptedEnvelope:
+        if isinstance(item, dict):
+            payload = item
+        elif hasattr(item, "model_dump"):
+            payload = item.model_dump()
+        elif hasattr(item, "__dict__"):
+            payload = item.__dict__
+        else:
+            payload = dict(item)
+
+        token = self.cipher.encrypt(_to_json_bytes(payload)).decode("utf-8")
+        return {"__enc__": 1, "v": self._ver, "kid": self._kid, "payload": token}
+
+    def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInputItem | None:
+        if not _is_encrypted_envelope(item):
+            return cast(TResponseInputItem, item)
+
+        try:
+            token = item["payload"].encode("utf-8")
+            plaintext = self.cipher.decrypt(token, ttl=self.ttl)
+            return cast(TResponseInputItem, _from_json_bytes(plaintext))
+        except (InvalidToken, KeyError):
+            return None
+
+    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
+        encrypted_items = await self.underlying_session.get_items(limit)
+        valid_items: list[TResponseInputItem] = []
+        for enc in encrypted_items:
+            item = self._unwrap(enc)
+            if item is not None:
+                valid_items.append(item)
+        return valid_items
+
+    async def add_items(self, items: list[TResponseInputItem]) -> None:
+        wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
+        await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped))
+
+    async def pop_item(self) -> TResponseInputItem | None:
+        while True:
+            enc = await self.underlying_session.pop_item()
+            if not enc:
+                return None
+            item = self._unwrap(enc)
+            if item is not None:
+                return item
+
+    async def clear_session(self) -> None:
+        await self.underlying_session.clear_session()
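Since `_ensure_fernet_key_bytes` accepts either a proper Fernet key or an arbitrary passphrase, and both are run through HKDF with the session id as salt, generating a real Fernet key is the stricter option. A sketch wrapping the built-in SQLiteSession:

```python
from cryptography.fernet import Fernet

from agents import SQLiteSession
from agents.extensions.memory import EncryptedSession

master_key = Fernet.generate_key().decode()  # urlsafe base64, 32 raw bytes

session = EncryptedSession(
    session_id="user-123",
    underlying_session=SQLiteSession("user-123"),
    encryption_key=master_key,
    ttl=3600,  # items older than an hour are silently skipped on read
)
```

Note the per-session HKDF salt means the same master key yields a different cipher for each session id, so envelopes cannot be replayed across sessions.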
agents/extensions/models/litellm_model.py
CHANGED
@@ -39,7 +39,7 @@ from ...items import ModelResponse, TResponseInputItem, TResponseStreamEvent
 from ...logger import logger
 from ...model_settings import ModelSettings
 from ...models.chatcmpl_converter import Converter
-from ...models.chatcmpl_helpers import HEADERS
+from ...models.chatcmpl_helpers import HEADERS, USER_AGENT_OVERRIDE
 from ...models.chatcmpl_stream_handler import ChatCmplStreamHandler
 from ...models.fake_id import FAKE_RESPONSES_ID
 from ...models.interface import Model, ModelTracing
@@ -48,14 +48,16 @@ from ...tracing import generation_span
 from ...tracing.span_data import GenerationSpanData
 from ...tracing.spans import Span
 from ...usage import Usage
+from ...util._json import _to_dump_compatible
 
 
 class InternalChatCompletionMessage(ChatCompletionMessage):
     """
-    An internal subclass to carry reasoning_content without modifying the original model.
-    """
+    An internal subclass to carry reasoning_content and thinking_blocks without modifying the original model.
+    """  # noqa: E501
 
     reasoning_content: str
+    thinking_blocks: list[dict[str, Any]] | None = None
 
 
 class LitellmModel(Model):
@@ -255,7 +257,15 @@ class LitellmModel(Model):
         stream: bool = False,
         prompt: Any | None = None,
     ) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]:
-        converted_messages = Converter.items_to_messages(input)
+        # Preserve reasoning messages for tool calls when reasoning is on
+        # This is needed for models like Claude 4 Sonnet/Opus which support interleaved thinking
+        preserve_thinking_blocks = (
+            model_settings.reasoning is not None and model_settings.reasoning.effort is not None
+        )
+
+        converted_messages = Converter.items_to_messages(
+            input, preserve_thinking_blocks=preserve_thinking_blocks
+        )
 
         if system_instructions:
             converted_messages.insert(
@@ -265,6 +275,8 @@ class LitellmModel(Model):
                 "role": "system",
             },
         )
+        converted_messages = _to_dump_compatible(converted_messages)
+
         if tracing.include_data():
             span.span_data.input = converted_messages
 
@@ -283,13 +295,25 @@ class LitellmModel(Model):
         for handoff in handoffs:
             converted_tools.append(Converter.convert_handoff_tool(handoff))
 
+        converted_tools = _to_dump_compatible(converted_tools)
+
         if _debug.DONT_LOG_MODEL_DATA:
             logger.debug("Calling LLM")
         else:
+            messages_json = json.dumps(
+                converted_messages,
+                indent=2,
+                ensure_ascii=False,
+            )
+            tools_json = json.dumps(
+                converted_tools,
+                indent=2,
+                ensure_ascii=False,
+            )
             logger.debug(
                 f"Calling Litellm model: {self.model}\n"
-                f"{json.dumps(converted_messages, indent=2, ensure_ascii=False)}\n"
-                f"Tools:\n{json.dumps(converted_tools, indent=2, ensure_ascii=False)}\n"
+                f"{messages_json}\n"
+                f"Tools:\n{tools_json}\n"
                 f"Stream: {stream}\n"
                 f"Tool choice: {tool_choice}\n"
                 f"Response format: {response_format}\n"
@@ -329,7 +353,7 @@ class LitellmModel(Model):
             stream_options=stream_options,
             reasoning_effort=reasoning_effort,
             top_logprobs=model_settings.top_logprobs,
-            extra_headers={**HEADERS, **(model_settings.extra_headers or {})},
+            extra_headers=self._merge_headers(model_settings),
             api_key=self.api_key,
             base_url=self.base_url,
             **extra_kwargs,
@@ -360,6 +384,13 @@ class LitellmModel(Model):
             return None
         return value
 
+    def _merge_headers(self, model_settings: ModelSettings):
+        merged = {**HEADERS, **(model_settings.extra_headers or {})}
+        ua_ctx = USER_AGENT_OVERRIDE.get()
+        if ua_ctx is not None:
+            merged["User-Agent"] = ua_ctx
+        return merged
+
 
 class LitellmConverter:
     @classmethod
@@ -369,9 +400,9 @@ class LitellmConverter:
         if message.role != "assistant":
             raise ModelBehaviorError(f"Unsupported role: {message.role}")
 
-        tool_calls: list[
-            ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall
-        ] | None = (
+        tool_calls: (
+            list[ChatCompletionMessageFunctionToolCall | ChatCompletionMessageCustomToolCall] | None
+        ) = (
             [LitellmConverter.convert_tool_call_to_openai(tool) for tool in message.tool_calls]
             if message.tool_calls
             else None
@@ -386,6 +417,26 @@ class LitellmConverter:
         if hasattr(message, "reasoning_content") and message.reasoning_content:
             reasoning_content = message.reasoning_content
 
+        # Extract full thinking blocks including signatures (for Anthropic)
+        thinking_blocks: list[dict[str, Any]] | None = None
+        if hasattr(message, "thinking_blocks") and message.thinking_blocks:
+            # Convert thinking blocks to dict format for compatibility
+            thinking_blocks = []
+            for block in message.thinking_blocks:
+                if isinstance(block, dict):
+                    thinking_blocks.append(cast(dict[str, Any], block))
+                else:
+                    # Convert object to dict by accessing its attributes
+                    block_dict: dict[str, Any] = {}
+                    if hasattr(block, "__dict__"):
+                        block_dict = dict(block.__dict__.items())
+                    elif hasattr(block, "model_dump"):
+                        block_dict = block.model_dump()
+                    else:
+                        # Last resort: convert to string representation
+                        block_dict = {"thinking": str(block)}
+                    thinking_blocks.append(block_dict)
+
         return InternalChatCompletionMessage(
             content=message.content,
             refusal=refusal,
@@ -394,6 +445,7 @@ class LitellmConverter:
             audio=message.get("audio", None),  # litellm deletes audio if not present
             tool_calls=tool_calls,
             reasoning_content=reasoning_content,
+            thinking_blocks=thinking_blocks,
         )
 
     @classmethod
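`preserve_thinking_blocks` is keyed off `model_settings.reasoning.effort`, so interleaved-thinking support switches on as soon as a reasoning effort is set. A sketch of settings that would trigger it, assuming LitellmModel is pointed at an Anthropic model via LiteLLM (the model name is illustrative):

```python
from openai.types.shared import Reasoning

from agents import Agent, ModelSettings
from agents.extensions.models.litellm_model import LitellmModel

agent = Agent(
    name="Assistant",
    model=LitellmModel(model="anthropic/claude-sonnet-4-20250514"),
    # effort != None -> thinking blocks are preserved across tool-call turns
    model_settings=ModelSettings(reasoning=Reasoning(effort="medium")),
)
```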
agents/function_schema.py
CHANGED
@@ -5,7 +5,7 @@ import inspect
 import logging
 import re
 from dataclasses import dataclass
-from typing import Any, Callable, Literal, get_args, get_origin, get_type_hints
+from typing import Annotated, Any, Callable, Literal, get_args, get_origin, get_type_hints
 
 from griffe import Docstring, DocstringSectionKind
 from pydantic import BaseModel, Field, create_model
@@ -185,6 +185,31 @@ def generate_func_documentation(
     )
 
 
+def _strip_annotated(annotation: Any) -> tuple[Any, tuple[Any, ...]]:
+    """Returns the underlying annotation and any metadata from typing.Annotated."""
+
+    metadata: tuple[Any, ...] = ()
+    ann = annotation
+
+    while get_origin(ann) is Annotated:
+        args = get_args(ann)
+        if not args:
+            break
+        ann = args[0]
+        metadata = (*metadata, *args[1:])
+
+    return ann, metadata
+
+
+def _extract_description_from_metadata(metadata: tuple[Any, ...]) -> str | None:
+    """Extracts a human readable description from Annotated metadata if present."""
+
+    for item in metadata:
+        if isinstance(item, str):
+            return item
+    return None
+
+
 def function_schema(
     func: Callable[..., Any],
     docstring_style: DocstringStyle | None = None,
@@ -219,17 +244,34 @@
     # 1. Grab docstring info
     if use_docstring_info:
         doc_info = generate_func_documentation(func, docstring_style)
-        param_descs = doc_info.param_descriptions or {}
+        param_descs = dict(doc_info.param_descriptions or {})
     else:
         doc_info = None
         param_descs = {}
 
+    type_hints_with_extras = get_type_hints(func, include_extras=True)
+    type_hints: dict[str, Any] = {}
+    annotated_param_descs: dict[str, str] = {}
+
+    for name, annotation in type_hints_with_extras.items():
+        if name == "return":
+            continue
+
+        stripped_ann, metadata = _strip_annotated(annotation)
+        type_hints[name] = stripped_ann
+
+        description = _extract_description_from_metadata(metadata)
+        if description is not None:
+            annotated_param_descs[name] = description
+
+    for name, description in annotated_param_descs.items():
+        param_descs.setdefault(name, description)
+
     # Ensure name_override takes precedence even if docstring info is disabled.
     func_name = name_override or (doc_info.name if doc_info else func.__name__)
 
     # 2. Inspect function signature and get type hints
     sig = inspect.signature(func)
-    type_hints = get_type_hints(func)
     params = list(sig.parameters.items())
     takes_context = False
     filtered_params = []
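The new `Annotated` handling means a string in the annotation metadata becomes the parameter's description in the generated schema, with docstring descriptions still winning when both exist (`setdefault`). A sketch via the public `function_tool` decorator, which builds on `function_schema`:

```python
from typing import Annotated

from agents import function_tool

@function_tool
def get_weather(city: Annotated[str, "City name or IATA airport code"]) -> str:
    """Fetch a short weather summary."""
    return f"It is sunny in {city}."

# The tool's parameter schema now carries the Annotated description.
print(get_weather.params_json_schema["properties"]["city"]["description"])
```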
agents/memory/__init__.py
CHANGED
@@ -1,10 +1,12 @@
 from .openai_conversations_session import OpenAIConversationsSession
 from .session import Session, SessionABC
 from .sqlite_session import SQLiteSession
+from .util import SessionInputCallback
 
 __all__ = [
     "Session",
     "SessionABC",
+    "SessionInputCallback",
     "SQLiteSession",
     "OpenAIConversationsSession",
 ]