letta-nightly 0.6.49.dev20250408104230__py3-none-any.whl → 0.6.50.dev20250409043626__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic.
- letta/__init__.py +1 -1
- letta/agent.py +8 -1
- letta/functions/function_sets/base.py +4 -1
- letta/functions/helpers.py +16 -2
- letta/jobs/__init__.py +0 -0
- letta/jobs/helpers.py +25 -0
- letta/jobs/llm_batch_job_polling.py +204 -0
- letta/jobs/scheduler.py +28 -0
- letta/jobs/types.py +10 -0
- letta/llm_api/anthropic.py +8 -3
- letta/llm_api/anthropic_client.py +5 -4
- letta/llm_api/llm_api_tools.py +2 -0
- letta/llm_api/openai_client.py +3 -1
- letta/memory.py +20 -4
- letta/orm/message.py +21 -5
- letta/schemas/enums.py +1 -0
- letta/schemas/llm_config.py +8 -4
- letta/schemas/message.py +8 -7
- letta/server/rest_api/app.py +11 -0
- letta/server/rest_api/chat_completions_interface.py +1 -0
- letta/server/rest_api/routers/v1/agents.py +16 -3
- letta/server/server.py +5 -1
- letta/services/agent_manager.py +34 -28
- letta/services/helpers/agent_manager_helper.py +3 -1
- letta/services/llm_batch_manager.py +97 -6
- letta/services/tool_sandbox/local_sandbox.py +2 -1
- letta/settings.py +4 -0
- letta/streaming_interface.py +2 -0
- {letta_nightly-0.6.49.dev20250408104230.dist-info → letta_nightly-0.6.50.dev20250409043626.dist-info}/METADATA +5 -4
- {letta_nightly-0.6.49.dev20250408104230.dist-info → letta_nightly-0.6.50.dev20250409043626.dist-info}/RECORD +33 -28
- {letta_nightly-0.6.49.dev20250408104230.dist-info → letta_nightly-0.6.50.dev20250409043626.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.49.dev20250408104230.dist-info → letta_nightly-0.6.50.dev20250409043626.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.49.dev20250408104230.dist-info → letta_nightly-0.6.50.dev20250409043626.dist-info}/entry_points.txt +0 -0
letta/__init__.py
CHANGED
letta/agent.py
CHANGED
@@ -376,7 +376,6 @@ class Agent(BaseAgent):
                 else:
                     raise ValueError(f"Bad finish reason from API: {response.choices[0].finish_reason}")
                 log_telemetry(self.logger, "_handle_ai_response finish")
-                return response

             except ValueError as ve:
                 if attempt >= empty_response_retry_limit:
@@ -393,6 +392,14 @@ class Agent(BaseAgent):
                 log_telemetry(self.logger, "_handle_ai_response finish generic Exception")
                 raise e

+            # check if we are going over the context window: this allows for articifial constraints
+            if response.usage.total_tokens > self.agent_state.llm_config.context_window:
+                # trigger summarization
+                log_telemetry(self.logger, "_get_ai_reply summarize_messages_inplace")
+                self.summarize_messages_inplace()
+            # return the response
+            return response
+
         log_telemetry(self.logger, "_handle_ai_response finish catch-all exception")
         raise Exception("Retries exhausted and no valid response received.")

letta/functions/function_sets/base.py
CHANGED

@@ -225,7 +225,10 @@ def core_memory_insert(agent_state: "AgentState", target_block_label: str, new_m
     current_value_list = current_value.split("\n")
     if line_number is None:
         line_number = len(current_value_list)
-
+    if replace:
+        current_value_list[line_number] = new_memory
+    else:
+        current_value_list.insert(line_number, new_memory)
     new_value = "\n".join(current_value_list)
     agent_state.memory.update_block_value(label=target_block_label, value=new_value)
     return None
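For context, a minimal standalone sketch (not letta code; edit_lines is a hypothetical helper) of the insert-vs-replace semantics the new replace branch introduces:

def edit_lines(current_value: str, new_memory: str, line_number=None, replace=False) -> str:
    # Mirrors the logic above: split into lines, then either overwrite or insert at line_number.
    lines = current_value.split("\n")
    if line_number is None:
        line_number = len(lines)
    if replace:
        lines[line_number] = new_memory
    else:
        lines.insert(line_number, new_memory)
    return "\n".join(lines)

print(edit_lines("a\nb\nc", "X", line_number=1))                # "a\nX\nb\nc"
print(edit_lines("a\nb\nc", "X", line_number=1, replace=True))  # "a\nX\nc"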
letta/functions/helpers.py
CHANGED
@@ -629,8 +629,22 @@ def _get_field_type(field_schema: Dict[str, Any], nested_models: Dict[str, Type[
         if nested_models and ref_type in nested_models:
             return nested_models[ref_type]
     elif "additionalProperties" in field_schema:
-
-
+        # TODO: This is totally GPT generated and I'm not sure it works
+        # TODO: This is done to quickly patch some tests, we should nuke this whole pathway asap
+        ap = field_schema["additionalProperties"]
+
+        if ap is True:
+            return dict
+        elif ap is False:
+            raise ValueError("additionalProperties=false is not supported.")
+        else:
+            # Try resolving nested type
+            nested_type = _get_field_type(ap, nested_models)
+            # If nested_type is Any, fall back to `dict`, or raise, depending on how strict you want to be
+            if nested_type == Any:
+                return dict
+            return Dict[str, nested_type]
+
         return dict
     elif field_schema.get("$ref") is not None:
         ref_type = field_schema["$ref"].split("/")[-1]
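A rough illustration (a plain-Python stand-in for this one branch, not the letta helper itself) of how an additionalProperties schema fragment is expected to resolve after this change:

from typing import Any, Dict

def resolve_additional_properties(field_schema: Dict[str, Any]):
    # Simplified stand-in for the branch above: True -> dict, False -> error,
    # otherwise look at the value schema and build Dict[str, <value type>].
    ap = field_schema["additionalProperties"]
    if ap is True:
        return dict
    if ap is False:
        raise ValueError("additionalProperties=false is not supported.")
    value_type = str if ap.get("type") == "string" else Any
    return dict if value_type is Any else Dict[str, value_type]

print(resolve_additional_properties({"additionalProperties": {"type": "string"}}))  # typing.Dict[str, str]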
letta/jobs/__init__.py
ADDED
File without changes
letta/jobs/helpers.py
ADDED
@@ -0,0 +1,25 @@
+from anthropic.types.beta.messages import (
+    BetaMessageBatchCanceledResult,
+    BetaMessageBatchIndividualResponse,
+    BetaMessageBatchSucceededResult,
+)
+
+from letta.schemas.enums import JobStatus
+
+
+def map_anthropic_batch_job_status_to_job_status(anthropic_status: str) -> JobStatus:
+    mapping = {
+        "in_progress": JobStatus.running,
+        "canceling": JobStatus.cancelled,
+        "ended": JobStatus.completed,
+    }
+    return mapping.get(anthropic_status, JobStatus.pending)  # fallback just in case
+
+
+def map_anthropic_individual_batch_item_status_to_job_status(individual_item: BetaMessageBatchIndividualResponse) -> JobStatus:
+    if isinstance(individual_item.result, BetaMessageBatchSucceededResult):
+        return JobStatus.completed
+    elif isinstance(individual_item.result, BetaMessageBatchCanceledResult):
+        return JobStatus.cancelled
+    else:
+        return JobStatus.failed
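Assuming the new module is importable as letta.jobs.helpers, a quick sanity check of the batch-level mapping; note that unknown provider statuses fall back to pending rather than raising:

from letta.jobs.helpers import map_anthropic_batch_job_status_to_job_status
from letta.schemas.enums import JobStatus

assert map_anthropic_batch_job_status_to_job_status("in_progress") == JobStatus.running
assert map_anthropic_batch_job_status_to_job_status("canceling") == JobStatus.cancelled
assert map_anthropic_batch_job_status_to_job_status("ended") == JobStatus.completed
assert map_anthropic_batch_job_status_to_job_status("unexpected") == JobStatus.pending  # fallback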
letta/jobs/llm_batch_job_polling.py
ADDED

@@ -0,0 +1,204 @@
+import asyncio
+import datetime
+from typing import List
+
+from letta.jobs.helpers import map_anthropic_batch_job_status_to_job_status, map_anthropic_individual_batch_item_status_to_job_status
+from letta.jobs.types import BatchId, BatchPollingResult, ItemUpdateInfo
+from letta.log import get_logger
+from letta.schemas.enums import JobStatus, ProviderType
+from letta.schemas.llm_batch_job import LLMBatchJob
+from letta.server.server import SyncServer
+
+logger = get_logger(__name__)
+
+
+class BatchPollingMetrics:
+    """Class to track metrics for batch polling operations."""
+
+    def __init__(self):
+        self.start_time = datetime.datetime.now()
+        self.total_batches = 0
+        self.anthropic_batches = 0
+        self.running_count = 0
+        self.completed_count = 0
+        self.updated_items_count = 0
+
+    def log_summary(self):
+        """Log a summary of the metrics collected during polling."""
+        elapsed = (datetime.datetime.now() - self.start_time).total_seconds()
+        logger.info(f"[Poll BatchJob] Finished poll_running_llm_batches job in {elapsed:.2f}s")
+        logger.info(f"[Poll BatchJob] Found {self.total_batches} running batches total.")
+        logger.info(f"[Poll BatchJob] Found {self.anthropic_batches} Anthropic batch(es) to poll.")
+        logger.info(f"[Poll BatchJob] Final results: {self.completed_count} completed, {self.running_count} still running.")
+        logger.info(f"[Poll BatchJob] Updated {self.updated_items_count} items for newly completed batch(es).")
+
+
+async def fetch_batch_status(server: SyncServer, batch_job: LLMBatchJob) -> BatchPollingResult:
+    """
+    Fetch the current status of a single batch job from the provider.
+
+    Args:
+        server: The SyncServer instance
+        batch_job: The batch job to check status for
+
+    Returns:
+        A tuple containing (batch_id, new_status, polling_response)
+    """
+    batch_id_str = batch_job.create_batch_response.id
+    try:
+        response = await server.anthropic_async_client.beta.messages.batches.retrieve(batch_id_str)
+        new_status = map_anthropic_batch_job_status_to_job_status(response.processing_status)
+        logger.debug(f"[Poll BatchJob] Batch {batch_job.id}: provider={response.processing_status} → internal={new_status}")
+        return (batch_job.id, new_status, response)
+    except Exception as e:
+        logger.warning(f"[Poll BatchJob] Batch {batch_job.id}: failed to retrieve {batch_id_str}: {e}")
+        # We treat a retrieval error as still running to try again next cycle
+        return (batch_job.id, JobStatus.running, None)
+
+
+async def fetch_batch_items(server: SyncServer, batch_id: BatchId, batch_resp_id: str) -> List[ItemUpdateInfo]:
+    """
+    Fetch individual item results for a completed batch.
+
+    Args:
+        server: The SyncServer instance
+        batch_id: The internal batch ID
+        batch_resp_id: The provider's batch response ID
+
+    Returns:
+        A list of item update information tuples
+    """
+    updates = []
+    try:
+        async for item_result in server.anthropic_async_client.beta.messages.batches.results(batch_resp_id):
+            # Here, custom_id should be the agent_id
+            item_status = map_anthropic_individual_batch_item_status_to_job_status(item_result)
+            updates.append((batch_id, item_result.custom_id, item_status, item_result))
+        logger.info(f"[Poll BatchJob] Fetched {len(updates)} item updates for batch {batch_id}.")
+    except Exception as e:
+        logger.error(f"[Poll BatchJob] Error fetching item updates for batch {batch_id}: {e}")
+
+    return updates
+
+
+async def poll_batch_updates(server: SyncServer, batch_jobs: List[LLMBatchJob], metrics: BatchPollingMetrics) -> List[BatchPollingResult]:
+    """
+    Poll for updates to multiple batch jobs concurrently.
+
+    Args:
+        server: The SyncServer instance
+        batch_jobs: List of batch jobs to poll
+        metrics: Metrics collection object
+
+    Returns:
+        List of batch polling results
+    """
+    if not batch_jobs:
+        logger.info("[Poll BatchJob] No Anthropic batches to update; job complete.")
+        return []
+
+    # Create polling tasks for all batch jobs
+    coros = [fetch_batch_status(server, b) for b in batch_jobs]
+    results: List[BatchPollingResult] = await asyncio.gather(*coros)
+
+    # Update the server with batch status changes
+    server.batch_manager.bulk_update_batch_statuses(updates=results)
+    logger.info(f"[Poll BatchJob] Bulk-updated {len(results)} LLM batch(es) in the DB at job level.")
+
+    return results
+
+
+async def process_completed_batches(
+    server: SyncServer, batch_results: List[BatchPollingResult], metrics: BatchPollingMetrics
+) -> List[ItemUpdateInfo]:
+    """
+    Process batches that have completed and fetch their item results.
+
+    Args:
+        server: The SyncServer instance
+        batch_results: Results from polling batch statuses
+        metrics: Metrics collection object
+
+    Returns:
+        List of item updates to apply
+    """
+    item_update_tasks = []
+
+    # Process each top-level polling result
+    for batch_id, new_status, maybe_batch_resp in batch_results:
+        if not maybe_batch_resp:
+            if new_status == JobStatus.running:
+                metrics.running_count += 1
+            logger.warning(f"[Poll BatchJob] Batch {batch_id}: JobStatus was {new_status} and no batch response was found.")
+            continue

+        if new_status == JobStatus.completed:
+            metrics.completed_count += 1
+            batch_resp_id = maybe_batch_resp.id  # The Anthropic-assigned batch ID
+            # Queue an async call to fetch item results for this batch
+            item_update_tasks.append(fetch_batch_items(server, batch_id, batch_resp_id))
+        elif new_status == JobStatus.running:
+            metrics.running_count += 1
+
+    # Launch all item update tasks concurrently
+    concurrent_results = await asyncio.gather(*item_update_tasks, return_exceptions=True)
+
+    # Flatten and filter the results
+    item_updates = []
+    for result in concurrent_results:
+        if isinstance(result, Exception):
+            logger.error(f"[Poll BatchJob] A fetch_batch_items task failed with: {result}")
+        elif isinstance(result, list):
+            item_updates.extend(result)
+
+    logger.info(f"[Poll BatchJob] Collected a total of {len(item_updates)} item update(s) from completed batches.")
+
+    return item_updates
+
+
+async def poll_running_llm_batches(server: "SyncServer") -> None:
+    """
+    Cron job to poll all running LLM batch jobs and update their polling responses in bulk.
+
+    Steps:
+      1. Fetch currently running batch jobs
+      2. Filter Anthropic only
+      3. Retrieve updated top-level polling info concurrently
+      4. Bulk update LLMBatchJob statuses
+      5. For each completed batch, call .results(...) to get item-level results
+      6. Bulk update all matching LLMBatchItem records by (batch_id, agent_id)
+      7. Log telemetry about success/fail
+    """
+    # Initialize metrics tracking
+    metrics = BatchPollingMetrics()
+
+    logger.info("[Poll BatchJob] Starting poll_running_llm_batches job")
+
+    try:
+        # 1. Retrieve running batch jobs
+        batches = server.batch_manager.list_running_batches()
+        metrics.total_batches = len(batches)
+
+        # TODO: Expand to more providers
+        # 2. Filter for Anthropic jobs only
+        anthropic_batch_jobs = [b for b in batches if b.llm_provider == ProviderType.anthropic]
+        metrics.anthropic_batches = len(anthropic_batch_jobs)
+
+        # 3-4. Poll for batch updates and bulk update statuses
+        batch_results = await poll_batch_updates(server, anthropic_batch_jobs, metrics)
+
+        # 5. Process completed batches and fetch item results
+        item_updates = await process_completed_batches(server, batch_results, metrics)
+
+        # 6. Bulk update all items for newly completed batch(es)
+        if item_updates:
+            metrics.updated_items_count = len(item_updates)
+            server.batch_manager.bulk_update_batch_items_by_agent(item_updates)
+        else:
+            logger.info("[Poll BatchJob] No item-level updates needed.")
+
+    except Exception as e:
+        logger.exception("[Poll BatchJob] Unhandled error in poll_running_llm_batches", exc_info=e)
+    finally:
+        # 7. Log metrics summary
+        metrics.log_summary()
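A minimal sketch of running one polling pass by hand, outside the scheduler (assumes a SyncServer can be constructed with defaults in your deployment):

import asyncio

from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
from letta.server.server import SyncServer

server = SyncServer()  # assumption: default construction is sufficient here
asyncio.run(poll_running_llm_batches(server))  # one fetch -> update -> log cycle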
letta/jobs/scheduler.py
ADDED
@@ -0,0 +1,28 @@
+import datetime
+
+from apscheduler.schedulers.asyncio import AsyncIOScheduler
+from apscheduler.triggers.interval import IntervalTrigger
+
+from letta.jobs.llm_batch_job_polling import poll_running_llm_batches
+from letta.server.server import SyncServer
+from letta.settings import settings
+
+scheduler = AsyncIOScheduler()
+
+
+def start_cron_jobs(server: SyncServer):
+    """Initialize cron jobs"""
+    scheduler.add_job(
+        poll_running_llm_batches,
+        args=[server],
+        trigger=IntervalTrigger(seconds=settings.poll_running_llm_batches_interval_seconds),
+        next_run_time=datetime.datetime.now(datetime.UTC),
+        id="poll_llm_batches",
+        name="Poll LLM API batch jobs and update status",
+        replace_existing=True,
+    )
+    scheduler.start()
+
+
+def shutdown_cron_scheduler():
+    scheduler.shutdown()
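A hedged sketch of how these hooks could be wired into application startup and shutdown; the lifespan wrapper and default SyncServer construction are assumptions, only the letta.jobs.scheduler imports come from the diff:

from contextlib import asynccontextmanager

from fastapi import FastAPI

from letta.jobs.scheduler import shutdown_cron_scheduler, start_cron_jobs
from letta.server.server import SyncServer

@asynccontextmanager
async def lifespan(app: FastAPI):
    server = SyncServer()    # assumption: default construction
    start_cron_jobs(server)  # AsyncIOScheduler wants a running event loop, so start inside the app
    yield
    shutdown_cron_scheduler()

app = FastAPI(lifespan=lifespan)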
letta/jobs/types.py
ADDED
@@ -0,0 +1,10 @@
+from typing import Optional, Tuple
+
+from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse
+
+from letta.schemas.enums import JobStatus
+
+BatchId = str
+AgentId = str
+BatchPollingResult = Tuple[BatchId, JobStatus, Optional[BetaMessageBatch]]
+ItemUpdateInfo = Tuple[BatchId, AgentId, JobStatus, BetaMessageBatchIndividualResponse]
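Illustrative values showing the tuple shapes the polling job passes around (the IDs below are made up):

from letta.jobs.types import BatchPollingResult
from letta.schemas.enums import JobStatus

# A batch that is still running and returned no provider payload yet:
still_running: BatchPollingResult = ("batch-job-id", JobStatus.running, None)
# An ItemUpdateInfo would look like ("batch-job-id", "agent-id", JobStatus.completed, <BetaMessageBatchIndividualResponse>).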
letta/llm_api/anthropic.py
CHANGED
@@ -25,6 +25,7 @@ from letta.llm_api.aws_bedrock import get_bedrock_client
 from letta.llm_api.helpers import add_inner_thoughts_to_functions
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION
 from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
+from letta.log import get_logger
 from letta.schemas.message import Message as _Message
 from letta.schemas.message import MessageRole as _MessageRole
 from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, Tool
@@ -44,6 +45,8 @@ from letta.settings import model_settings
 from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
 from letta.tracing import log_event

+logger = get_logger(__name__)
+
 BASE_URL = "https://api.anthropic.com/v1"


@@ -620,9 +623,9 @@ def _prepare_anthropic_request(
     data: ChatCompletionRequest,
     inner_thoughts_xml_tag: Optional[str] = "thinking",
     # if true, prefix fill the generation with the thinking tag
-    prefix_fill: bool =
+    prefix_fill: bool = False,
     # if true, put COT inside the tool calls instead of inside the content
-    put_inner_thoughts_in_kwargs: bool =
+    put_inner_thoughts_in_kwargs: bool = True,
     bedrock: bool = False,
     # extended thinking related fields
     # https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking
@@ -634,7 +637,9 @@ def _prepare_anthropic_request(
         assert (
             max_reasoning_tokens is not None and max_reasoning_tokens < data.max_tokens
         ), "max tokens must be greater than thinking budget"
-
+        if put_inner_thoughts_in_kwargs:
+            logger.warning("Extended thinking not compatible with put_inner_thoughts_in_kwargs")
+            put_inner_thoughts_in_kwargs = False
         # assert not prefix_fill, "extended thinking not compatible with prefix_fill"
         # Silently disable prefix_fill for now
         prefix_fill = False
letta/llm_api/anthropic_client.py
CHANGED

@@ -90,7 +90,7 @@ class AnthropicClient(LLMClientBase):
     def build_request_data(
         self,
         messages: List[PydanticMessage],
-        tools: List[dict],
+        tools: Optional[List[dict]] = None,
         force_tool_call: Optional[str] = None,
     ) -> dict:
         # TODO: This needs to get cleaned up. The logic here is pretty confusing.
@@ -146,11 +146,12 @@ class AnthropicClient(LLMClientBase):
         tools_for_request = [Tool(function=f) for f in tools] if tools is not None else None

         # Add tool choice
-
+        if tool_choice:
+            data["tool_choice"] = tool_choice

         # Add inner thoughts kwarg
         # TODO: Can probably make this more efficient
-        if len(tools_for_request) > 0 and self.llm_config.put_inner_thoughts_in_kwargs:
+        if tools_for_request and len(tools_for_request) > 0 and self.llm_config.put_inner_thoughts_in_kwargs:
             tools_with_inner_thoughts = add_inner_thoughts_to_functions(
                 functions=[t.function.model_dump() for t in tools_for_request],
                 inner_thoughts_key=INNER_THOUGHTS_KWARG,
@@ -158,7 +159,7 @@ class AnthropicClient(LLMClientBase):
             )
             tools_for_request = [Tool(function=f) for f in tools_with_inner_thoughts]

-        if len(tools_for_request) > 0:
+        if tools_for_request and len(tools_for_request) > 0:
             # TODO eventually enable parallel tool use
             data["tools"] = convert_tools_to_anthropic_format(tools_for_request)

letta/llm_api/llm_api_tools.py
CHANGED
@@ -322,6 +322,7 @@ def create(

         # Force tool calling
         tool_call = None
+        llm_config.put_inner_thoughts_in_kwargs = True
         if functions is None:
             # Special case for summarization path
             tools = None
@@ -356,6 +357,7 @@ def create(
         if stream:  # Client requested token streaming
             assert isinstance(stream_interface, (AgentChunkStreamingInterface, AgentRefreshStreamingInterface)), type(stream_interface)

+            stream_interface.inner_thoughts_in_kwargs = True
             response = anthropic_chat_completions_process_stream(
                 chat_completion_request=chat_completion_request,
                 put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs,
letta/llm_api/openai_client.py
CHANGED
@@ -78,9 +78,11 @@ class OpenAIClient(LLMClientBase):
         # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
         # TODO(matt) move into LLMConfig
         # TODO: This vllm checking is very brittle and is a patch at most
+        tool_choice = None
         if self.llm_config.model_endpoint == "https://inference.memgpt.ai" or (self.llm_config.handle and "vllm" in self.llm_config.handle):
             tool_choice = "auto"  # TODO change to "required" once proxy supports it
-
+        elif tools:
+            # only set if tools is non-Null
             tool_choice = "required"

         if force_tool_call is not None:
letta/memory.py
CHANGED
@@ -2,6 +2,7 @@ from typing import Callable, Dict, List

 from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK
 from letta.llm_api.llm_api_tools import create
+from letta.llm_api.llm_client import LLMClient
 from letta.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
 from letta.schemas.agent import AgentState
 from letta.schemas.enums import MessageRole
@@ -9,6 +10,7 @@ from letta.schemas.letta_message_content import TextContent
 from letta.schemas.memory import Memory
 from letta.schemas.message import Message
 from letta.settings import summarizer_settings
+from letta.tracing import trace_method
 from letta.utils import count_tokens, printd


@@ -45,6 +47,7 @@ def _format_summary_history(message_history: List[Message]):
     return "\n".join([f"{m.role}: {get_message_text(m.content)}" for m in message_history])


+@trace_method
 def summarize_messages(
     agent_state: AgentState,
     message_sequence_to_summarize: List[Message],
@@ -74,12 +77,25 @@ def summarize_messages(
     # TODO: We need to eventually have a separate LLM config for the summarizer LLM
     llm_config_no_inner_thoughts = agent_state.llm_config.model_copy(deep=True)
     llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
-    response = create(
+
+    llm_client = LLMClient.create(
         llm_config=llm_config_no_inner_thoughts,
-        user_id=agent_state.created_by_id,
-        messages=message_sequence,
-        stream=False,
+        put_inner_thoughts_first=False,
     )
+    # try to use new client, otherwise fallback to old flow
+    # TODO: we can just directly call the LLM here?
+    if llm_client:
+        response = llm_client.send_llm_request(
+            messages=message_sequence,
+            stream=False,
+        )
+    else:
+        response = create(
+            llm_config=llm_config_no_inner_thoughts,
+            user_id=agent_state.created_by_id,
+            messages=message_sequence,
+            stream=False,
+        )

     printd(f"summarize_messages gpt reply: {response.choices[0]}")
     reply = response.choices[0].message.content
letta/orm/message.py
CHANGED
@@ -1,8 +1,8 @@
 from typing import List, Optional

 from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
-from sqlalchemy import BigInteger, ForeignKey, Index, Sequence
-from sqlalchemy.orm import Mapped, mapped_column, relationship
+from sqlalchemy import BigInteger, ForeignKey, Index, Sequence, event, text
+from sqlalchemy.orm import Mapped, Session, mapped_column, relationship

 from letta.orm.custom_columns import MessageContentColumn, ToolCallColumn, ToolReturnColumn
 from letta.orm.mixins import AgentMixin, OrganizationMixin
@@ -11,6 +11,7 @@ from letta.schemas.letta_message_content import MessageContent
 from letta.schemas.letta_message_content import TextContent as PydanticTextContent
 from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.message import ToolReturn
+from letta.settings import settings


 class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
@@ -42,9 +43,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
     group_id: Mapped[Optional[str]] = mapped_column(nullable=True, doc="The multi-agent group that the message was sent in")

     # Monotonically increasing sequence for efficient/correct listing
-    sequence_id
-        BigInteger, Sequence("message_seq_id"), unique=True, nullable=False, doc="Global monotonically increasing ID"
-    )
+    sequence_id = mapped_column(BigInteger, Sequence("message_seq_id"), unique=True, nullable=False)

     # Relationships
     agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
@@ -67,3 +66,20 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
         if self.text and not model.content:
             model.content = [PydanticTextContent(text=self.text)]
         return model
+
+# listener
+
+
+@event.listens_for(Message, "before_insert")
+def set_sequence_id_for_sqlite(mapper, connection, target):
+    # TODO: Kind of hacky, used to detect if we are using sqlite or not
+    if not settings.pg_uri:
+        session = Session.object_session(target)
+
+        if not hasattr(session, "_sequence_id_counter"):
+            # Initialize counter for this flush
+            max_seq = connection.scalar(text("SELECT MAX(sequence_id) FROM messages"))
+            session._sequence_id_counter = max_seq or 0
+
+        session._sequence_id_counter += 1
+        target.sequence_id = session._sequence_id_counter
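An isolated sketch (hypothetical names, not letta code) of the per-session counter logic the SQLite listener uses: the DB maximum is read once per session, then each insert increments the cached counter so rows stay monotonic within a flush:

class _FakeSession:
    """Stands in for a SQLAlchemy Session carrying the cached counter attribute."""

def next_sequence_id(session, current_db_max: int) -> int:
    if not hasattr(session, "_sequence_id_counter"):
        session._sequence_id_counter = current_db_max or 0  # seeded once, like the MAX(sequence_id) query
    session._sequence_id_counter += 1
    return session._sequence_id_counter

s = _FakeSession()
print(next_sequence_id(s, 41))  # 42
print(next_sequence_id(s, 41))  # 43 -- the DB max is not re-read for the same session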
letta/schemas/enums.py
CHANGED
letta/schemas/llm_config.py
CHANGED
@@ -2,6 +2,10 @@ from typing import Literal, Optional

 from pydantic import BaseModel, ConfigDict, Field, model_validator

+from letta.log import get_logger
+
+logger = get_logger(__name__)
+

 class LLMConfig(BaseModel):
     """
@@ -88,14 +92,14 @@ class LLMConfig(BaseModel):
         return values

     @model_validator(mode="after")
-    def
+    def issue_warning_for_reasoning_constraints(self) -> "LLMConfig":
         if self.enable_reasoner:
             if self.max_reasoning_tokens is None:
-
+                logger.warning("max_reasoning_tokens must be set when enable_reasoner is True")
             if self.max_tokens is not None and self.max_reasoning_tokens >= self.max_tokens:
-
+                logger.warning("max_tokens must be greater than max_reasoning_tokens (thinking budget)")
             if self.put_inner_thoughts_in_kwargs:
-
+                logger.warning("Extended thinking is not compatible with put_inner_thoughts_in_kwargs")
         return self

     @classmethod
letta/schemas/message.py
CHANGED
@@ -37,6 +37,7 @@ from letta.schemas.letta_message_content import (
     get_letta_message_content_union_str_json_schema,
 )
 from letta.system import unpack_message
+from letta.utils import parse_json


 def add_inner_thoughts_to_tool_call(
@@ -47,7 +48,7 @@ def add_inner_thoughts_to_tool_call(
     """Add inner thoughts (arg + value) to a tool call"""
     try:
         # load the args list
-        func_args =
+        func_args = parse_json(tool_call.function.arguments)
         # create new ordered dict with inner thoughts first
         ordered_args = OrderedDict({inner_thoughts_key: inner_thoughts})
         # update with remaining args
@@ -293,7 +294,7 @@ class Message(BaseMessage):
             if use_assistant_message and tool_call.function.name == assistant_message_tool_name:
                 # We need to unpack the actual message contents from the function call
                 try:
-                    func_args =
+                    func_args = parse_json(tool_call.function.arguments)
                     message_string = func_args[assistant_message_tool_kwarg]
                 except KeyError:
                     raise ValueError(f"Function call {tool_call.function.name} missing {assistant_message_tool_kwarg} argument")
@@ -336,7 +337,7 @@ class Message(BaseMessage):
             raise ValueError(f"Invalid tool return (no text object on message): {self.content}")

         try:
-            function_return =
+            function_return = parse_json(text_content)
             status = function_return["status"]
             if status == "OK":
                 status_enum = "success"
@@ -760,7 +761,7 @@ class Message(BaseMessage):
                     inner_thoughts_key=INNER_THOUGHTS_KWARG,
                 ).model_dump()
             else:
-                tool_call_input =
+                tool_call_input = parse_json(tool_call.function.arguments)

             content.append(
                 {
@@ -846,7 +847,7 @@ class Message(BaseMessage):
                 function_args = tool_call.function.arguments
                 try:
                     # NOTE: Google AI wants actual JSON objects, not strings
-                    function_args =
+                    function_args = parse_json(function_args)
                 except:
                     raise UserWarning(f"Failed to parse JSON function args: {function_args}")
                     function_args = {"args": function_args}
@@ -881,7 +882,7 @@ class Message(BaseMessage):

             # NOTE: Google AI API wants the function response as JSON only, no string
             try:
-                function_response =
+                function_response = parse_json(text_content)
             except:
                 function_response = {"function_response": text_content}

@@ -970,7 +971,7 @@ class Message(BaseMessage):
             ]
             for tc in self.tool_calls:
                 function_name = tc.function["name"]
-                function_args =
+                function_args = parse_json(tc.function["arguments"])
                 function_args_str = ",".join([f"{k}={v}" for k, v in function_args.items()])
                 function_call_text = f"{function_name}({function_args_str})"
                 cohere_message.append(