letta-nightly 0.4.1.dev20241004104123__py3-none-any.whl → 0.4.1.dev20241005104008__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release: this version of letta-nightly might be problematic.
- letta/cli/cli.py +30 -365
- letta/cli/cli_config.py +70 -27
- letta/client/client.py +103 -11
- letta/config.py +80 -80
- letta/constants.py +6 -0
- letta/credentials.py +10 -1
- letta/errors.py +63 -5
- letta/llm_api/llm_api_tools.py +110 -52
- letta/local_llm/chat_completion_proxy.py +0 -3
- letta/main.py +1 -2
- letta/metadata.py +12 -0
- letta/providers.py +232 -0
- letta/schemas/block.py +1 -1
- letta/schemas/letta_request.py +17 -0
- letta/schemas/letta_response.py +11 -0
- letta/schemas/llm_config.py +18 -2
- letta/schemas/message.py +40 -13
- letta/server/rest_api/app.py +5 -0
- letta/server/rest_api/interface.py +115 -24
- letta/server/rest_api/routers/v1/agents.py +36 -3
- letta/server/rest_api/routers/v1/llms.py +6 -2
- letta/server/server.py +60 -87
- letta/server/static_files/assets/index-3ab03d5b.css +1 -0
- letta/server/static_files/assets/{index-4d08d8a3.js → index-9a9c449b.js} +69 -69
- letta/server/static_files/index.html +2 -2
- letta/settings.py +144 -114
- letta/utils.py +6 -1
- {letta_nightly-0.4.1.dev20241004104123.dist-info → letta_nightly-0.4.1.dev20241005104008.dist-info}/METADATA +1 -1
- {letta_nightly-0.4.1.dev20241004104123.dist-info → letta_nightly-0.4.1.dev20241005104008.dist-info}/RECORD +32 -32
- letta/local_llm/groq/api.py +0 -97
- letta/server/static_files/assets/index-156816da.css +0 -1
- {letta_nightly-0.4.1.dev20241004104123.dist-info → letta_nightly-0.4.1.dev20241005104008.dist-info}/LICENSE +0 -0
- {letta_nightly-0.4.1.dev20241004104123.dist-info → letta_nightly-0.4.1.dev20241005104008.dist-info}/WHEEL +0 -0
- {letta_nightly-0.4.1.dev20241004104123.dist-info → letta_nightly-0.4.1.dev20241005104008.dist-info}/entry_points.txt +0 -0
letta/schemas/message.py
CHANGED

```diff
@@ -6,11 +6,16 @@ from typing import List, Optional
 
 from pydantic import Field, field_validator
 
-from letta.constants import
+from letta.constants import (
+    DEFAULT_MESSAGE_TOOL,
+    DEFAULT_MESSAGE_TOOL_KWARG,
+    TOOL_CALL_ID_MAX_LEN,
+)
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG
 from letta.schemas.enums import MessageRole
 from letta.schemas.letta_base import LettaBase
 from letta.schemas.letta_message import (
+    AssistantMessage,
     FunctionCall,
     FunctionCallMessage,
     FunctionReturn,
@@ -122,7 +127,12 @@ class Message(BaseMessage):
         json_message["created_at"] = self.created_at.isoformat()
         return json_message
 
-    def to_letta_message(
+    def to_letta_message(
+        self,
+        assistant_message: bool = False,
+        assistant_message_function_name: str = DEFAULT_MESSAGE_TOOL,
+        assistant_message_function_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
+    ) -> List[LettaMessage]:
         """Convert message object (in DB format) to the style used by the original Letta API"""
 
         messages = []
@@ -140,16 +150,33 @@ class Message(BaseMessage):
         if self.tool_calls is not None:
             # This is type FunctionCall
             for tool_call in self.tool_calls:
+                # If we're supporting using assistant message,
+                # then we want to treat certain function calls as a special case
+                if assistant_message and tool_call.function.name == assistant_message_function_name:
+                    # We need to unpack the actual message contents from the function call
+                    try:
+                        func_args = json.loads(tool_call.function.arguments)
+                        message_string = func_args[DEFAULT_MESSAGE_TOOL_KWARG]
+                    except KeyError:
+                        raise ValueError(f"Function call {tool_call.function.name} missing {DEFAULT_MESSAGE_TOOL_KWARG} argument")
+                    messages.append(
+                        AssistantMessage(
+                            id=self.id,
+                            date=self.created_at,
+                            assistant_message=message_string,
+                        )
+                    )
+                else:
+                    messages.append(
+                        FunctionCallMessage(
+                            id=self.id,
+                            date=self.created_at,
+                            function_call=FunctionCall(
+                                name=tool_call.function.name,
+                                arguments=tool_call.function.arguments,
+                            ),
+                        )
                     )
-                )
         elif self.role == MessageRole.tool:
             # This is type FunctionReturn
             # Try to interpret the function return, recall that this is how we packaged:
@@ -560,8 +587,8 @@ class Message(BaseMessage):
         if self.tool_calls is not None:
             # NOTE: implied support for multiple calls
             for tool_call in self.tool_calls:
-                function_name = tool_call.function
-                function_args = tool_call.function
+                function_name = tool_call.function.name
+                function_args = tool_call.function.arguments
                 try:
                     # NOTE: Google AI wants actual JSON objects, not strings
                     function_args = json.loads(function_args)
```
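The main change here is that `Message.to_letta_message()` can now surface a designated tool call as an `AssistantMessage` instead of a generic `FunctionCallMessage`. The sketch below illustrates that conversion rule in isolation; the dataclass and dict shapes are illustrative stand-ins for the letta schemas, and the defaults assume `DEFAULT_MESSAGE_TOOL = "send_message"` and `DEFAULT_MESSAGE_TOOL_KWARG = "message"`.

```python
# Minimal sketch of the new conversion rule; FakeFunction and the returned dicts
# are stand-ins for the letta schemas, not the real classes.
import json
from dataclasses import dataclass


@dataclass
class FakeFunction:
    name: str
    arguments: str  # JSON-encoded string, as in OpenAI-style tool calls


def convert_tool_call(func: FakeFunction,
                      assistant_message: bool = False,
                      tool_name: str = "send_message",   # assumed DEFAULT_MESSAGE_TOOL
                      tool_kwarg: str = "message") -> dict:  # assumed DEFAULT_MESSAGE_TOOL_KWARG
    """Mirror the branch in to_letta_message(): designated tool -> assistant message,
    everything else -> function-call passthrough."""
    if assistant_message and func.name == tool_name:
        args = json.loads(func.arguments)
        if tool_kwarg not in args:
            raise ValueError(f"Function call {func.name} missing {tool_kwarg} argument")
        return {"type": "assistant_message", "assistant_message": args[tool_kwarg]}
    return {"type": "function_call", "name": func.name, "arguments": func.arguments}


call = FakeFunction(name="send_message", arguments=json.dumps({"message": "Hello!"}))
print(convert_tool_call(call, assistant_message=True))   # surfaced as assistant text
print(convert_tool_call(call, assistant_message=False))  # raw function-call passthrough
```

With `assistant_message=False` (the default), behavior is unchanged and every tool call is passed through as a function call.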
letta/server/rest_api/app.py
CHANGED

```diff
@@ -1,5 +1,6 @@
 import json
 import logging
+import sys
 from pathlib import Path
 from typing import Optional
 
@@ -71,6 +72,10 @@ def create_application() -> "FastAPI":
         summary="Create LLM agents with long-term memory and custom tools 📚🦙",
         version="1.0.0",  # TODO wire this up to the version in the package
     )
+
+    if "--ade" in sys.argv:
+        settings.cors_origins.append("https://app.letta.com")
+
     app.add_middleware(
         CORSMiddleware,
         allow_origins=settings.cors_origins,
```
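The `--ade` flag whitelists the hosted Agent Development Environment origin before the CORS middleware is installed. A minimal standalone sketch of the same pattern; the base origin list below is a placeholder, whereas letta reads it from `settings.cors_origins`.

```python
# Sketch: opt-in CORS origin controlled by a CLI flag, as in create_application().
import sys

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

cors_origins = ["http://localhost:3000"]  # assumed default, for illustration only
if "--ade" in sys.argv:
    # Only allow the hosted ADE to call this server when explicitly requested.
    cors_origins.append("https://app.letta.com")

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
```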
letta/server/rest_api/interface.py
CHANGED

```diff
@@ -1,10 +1,12 @@
 import asyncio
 import json
 import queue
+import warnings
 from collections import deque
 from datetime import datetime
 from typing import AsyncGenerator, Literal, Optional, Union
 
+from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.interface import AgentInterface
 from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.letta_message import (
@@ -249,7 +251,7 @@ class QueuingInterface(AgentInterface):
 class FunctionArgumentsStreamHandler:
     """State machine that can process a stream of"""
 
-    def __init__(self, json_key=
+    def __init__(self, json_key=DEFAULT_MESSAGE_TOOL_KWARG):
         self.json_key = json_key
         self.reset()
 
@@ -311,7 +313,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
     should maintain multiple generators and index them with the request ID
     """
 
-    def __init__(
+    def __init__(
+        self,
+        multi_step=True,
+        use_assistant_message=False,
+        assistant_message_function_name=DEFAULT_MESSAGE_TOOL,
+        assistant_message_function_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
+    ):
         # If streaming mode, ignores base interface calls like .assistant_message, etc
         self.streaming_mode = False
         # NOTE: flag for supporting legacy 'stream' flag where send_message is treated specially
@@ -321,7 +329,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
         self.streaming_chat_completion_mode_function_name = None  # NOTE: sadly need to track state during stream
         # If chat completion mode, we need a special stream reader to
         # turn function argument to send_message into a normal text stream
-        self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler()
+        self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_function_kwarg)
 
         self._chunks = deque()
         self._event = asyncio.Event()  # Use an event to notify when chunks are available
@@ -333,6 +341,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
         self.multi_step_indicator = MessageStreamStatus.done_step
         self.multi_step_gen_indicator = MessageStreamStatus.done_generation
 
+        # Support for AssistantMessage
+        self.use_assistant_message = use_assistant_message
+        self.assistant_message_function_name = assistant_message_function_name
+        self.assistant_message_function_kwarg = assistant_message_function_kwarg
+
         # extra prints
         self.debug = False
         self.timeout = 30
@@ -441,7 +454,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
 
     def _process_chunk_to_letta_style(
         self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime
-    ) -> Optional[Union[InternalMonologue, FunctionCallMessage]]:
+    ) -> Optional[Union[InternalMonologue, FunctionCallMessage, AssistantMessage]]:
         """
         Example data from non-streaming response looks like:
 
@@ -461,23 +474,83 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
                 date=message_date,
                 internal_monologue=message_delta.content,
             )
+
+        # tool calls
         elif message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
             tool_call = message_delta.tool_calls[0]
 
-            if tool_call.
+            # special case for trapping `send_message`
+            if self.use_assistant_message and tool_call.function:
+
+                # If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard FunctionCallMessage passthrough mode
+
+                # Track the function name while streaming
+                # If we were previously on a 'send_message', we need to 'toggle' into 'content' mode
                 if tool_call.function.name:
+                    if self.streaming_chat_completion_mode_function_name is None:
+                        self.streaming_chat_completion_mode_function_name = tool_call.function.name
+                    else:
+                        self.streaming_chat_completion_mode_function_name += tool_call.function.name
+
+                    # If we get a "hit" on the special keyword we're looking for, we want to skip to the next chunk
+                    # TODO I don't think this handles the function name in multi-pieces problem. Instead, we should probably reset the streaming_chat_completion_mode_function_name when we make this hit?
+                    # if self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name:
+                    if tool_call.function.name == self.assistant_message_function_name:
+                        self.streaming_chat_completion_json_reader.reset()
+                        # early exit to turn into content mode
+                        return None
+
+                # if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks
+                if (
+                    tool_call.function.arguments
+                    and self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name
+                ):
+                    # Strip out any extras tokens
+                    cleaned_func_args = self.streaming_chat_completion_json_reader.process_json_chunk(tool_call.function.arguments)
+                    # In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk
+                    if cleaned_func_args is None:
+                        return None
+                    else:
+                        processed_chunk = AssistantMessage(
+                            id=message_id,
+                            date=message_date,
+                            assistant_message=cleaned_func_args,
+                        )
+
+                # otherwise we just do a regular passthrough of a FunctionCallDelta via a FunctionCallMessage
+                else:
+                    tool_call_delta = {}
+                    if tool_call.id:
+                        tool_call_delta["id"] = tool_call.id
+                    if tool_call.function:
+                        if tool_call.function.arguments:
+                            tool_call_delta["arguments"] = tool_call.function.arguments
+                        if tool_call.function.name:
+                            tool_call_delta["name"] = tool_call.function.name
+
+                    processed_chunk = FunctionCallMessage(
+                        id=message_id,
+                        date=message_date,
+                        function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
+                    )
+
+            else:
+
+                tool_call_delta = {}
+                if tool_call.id:
+                    tool_call_delta["id"] = tool_call.id
+                if tool_call.function:
+                    if tool_call.function.arguments:
+                        tool_call_delta["arguments"] = tool_call.function.arguments
+                    if tool_call.function.name:
+                        tool_call_delta["name"] = tool_call.function.name
+
+                processed_chunk = FunctionCallMessage(
+                    id=message_id,
+                    date=message_date,
+                    function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
                )
 
-            processed_chunk = FunctionCallMessage(
-                id=message_id,
-                date=message_date,
-                function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
-            )
         elif choice.finish_reason is not None:
             # skip if there's a finish
             return None
@@ -663,14 +736,32 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
 
         else:
 
-            function_call
+            try:
+                func_args = json.loads(function_call.function.arguments)
+            except:
+                warnings.warn(f"Failed to parse function arguments: {function_call.function.arguments}")
+                func_args = {}
+
+            if (
+                self.use_assistant_message
+                and function_call.function.name == self.assistant_message_function_name
+                and self.assistant_message_function_kwarg in func_args
+            ):
+                processed_chunk = AssistantMessage(
+                    id=msg_obj.id,
+                    date=msg_obj.created_at,
+                    assistant_message=func_args[self.assistant_message_function_kwarg],
+                )
+            else:
+                processed_chunk = FunctionCallMessage(
+                    id=msg_obj.id,
+                    date=msg_obj.created_at,
+                    function_call=FunctionCall(
+                        name=function_call.function.name,
+                        arguments=function_call.function.arguments,
+                    ),
+                )
+
             # processed_chunk = {
             #     "function_call": {
             #         "name": function_call.function.name,
```
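In streaming mode the interface now routes tool-call deltas two ways: when `use_assistant_message` is set and the tracked function name matches the designated message tool, argument chunks are fed through the JSON stream reader and emitted as `AssistantMessage` text; everything else passes through as a `FunctionCallMessage` delta. Below is a heavily simplified sketch of that per-chunk decision; it omits the name-accumulation and reader-reset handling shown in the diff, and `extract_message_text` is a stand-in for `FunctionArgumentsStreamHandler.process_json_chunk`.

```python
# Simplified sketch of the routing added in _process_chunk_to_letta_style (assumptions noted above).
from typing import Callable, Optional


def route_tool_call_delta(
    name_delta: Optional[str],
    args_delta: Optional[str],
    tracked_function_name: Optional[str],            # name accumulated from earlier chunks
    use_assistant_message: bool,
    assistant_tool_name: str = "send_message",       # assumed default (DEFAULT_MESSAGE_TOOL)
    extract_message_text: Callable[[str], Optional[str]] = lambda s: s or None,  # stub for the JSON reader
) -> Optional[dict]:
    """Decide what one streamed tool-call delta becomes in the client-facing stream."""
    if use_assistant_message:
        if name_delta == assistant_tool_name:
            # The chunk naming the designated tool emits nothing; it just switches modes.
            return None
        if args_delta and tracked_function_name == assistant_tool_name:
            # Inside the designated call: surface the extracted text as assistant-message content.
            text = extract_message_text(args_delta)
            return {"assistant_message": text} if text else None
    # Default passthrough: stream the raw function-call delta.
    return {"function_call": {"name": name_delta, "arguments": args_delta}}
```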
letta/server/rest_api/routers/v1/agents.py
CHANGED

```diff
@@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
 from fastapi.responses import JSONResponse, StreamingResponse
 from starlette.responses import StreamingResponse
 
+from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
 from letta.schemas.enums import MessageRole, MessageStreamStatus
 from letta.schemas.letta_message import (
@@ -254,6 +255,19 @@ def get_agent_messages(
     before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
     limit: int = Query(10, description="Maximum number of messages to retrieve."),
     msg_object: bool = Query(False, description="If true, returns Message objects. If false, return LettaMessage objects."),
+    # Flags to support the use of AssistantMessage message types
+    use_assistant_message: bool = Query(
+        False,
+        description="[Only applicable if msg_object is False] If true, returns AssistantMessage objects when the agent calls a designated message tool. If false, return FunctionCallMessage objects for all tool calls.",
+    ),
+    assistant_message_function_name: str = Query(
+        DEFAULT_MESSAGE_TOOL,
+        description="[Only applicable if use_assistant_message is True] The name of the designated message tool.",
+    ),
+    assistant_message_function_kwarg: str = Query(
+        DEFAULT_MESSAGE_TOOL_KWARG,
+        description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.",
+    ),
 ):
     """
     Retrieve message history for an agent.
@@ -267,6 +281,9 @@ def get_agent_messages(
         limit=limit,
         reverse=True,
         return_message_object=msg_object,
+        use_assistant_message=use_assistant_message,
+        assistant_message_function_name=assistant_message_function_name,
+        assistant_message_function_kwarg=assistant_message_function_kwarg,
     )
 
 
@@ -310,6 +327,10 @@ async def send_message(
         stream_steps=request.stream_steps,
         stream_tokens=request.stream_tokens,
         return_message_object=request.return_message_object,
+        # Support for AssistantMessage
+        use_assistant_message=request.use_assistant_message,
+        assistant_message_function_name=request.assistant_message_function_name,
+        assistant_message_function_kwarg=request.assistant_message_function_kwarg,
     )
 
 
@@ -322,12 +343,17 @@ async def send_message_to_agent(
     message: str,
     stream_steps: bool,
     stream_tokens: bool,
+    # related to whether or not we return `LettaMessage`s or `Message`s
     return_message_object: bool,  # Should be True for Python Client, False for REST API
-    chat_completion_mode:
+    chat_completion_mode: bool = False,
     timestamp: Optional[datetime] = None,
-    #
+    # Support for AssistantMessage
+    use_assistant_message: bool = False,
+    assistant_message_function_name: str = DEFAULT_MESSAGE_TOOL,
+    assistant_message_function_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
 ) -> Union[StreamingResponse, LettaResponse]:
     """Split off into a separate function so that it can be imported in the /chat/completion proxy."""
+
     # TODO: @charles is this the correct way to handle?
     include_final_message = True
 
@@ -356,7 +382,8 @@ async def send_message_to_agent(
 
     # Disable token streaming if not OpenAI
     # TODO: cleanup this logic
+    llm_config = letta_agent.agent_state.llm_config
+    if llm_config.model_endpoint_type != "openai" or "inference.memgpt.ai" in llm_config.model_endpoint:
         print("Warning: token streaming is only supported for OpenAI models. Setting to False.")
         stream_tokens = False
 
@@ -368,6 +395,11 @@ async def send_message_to_agent(
         # streaming_interface.allow_assistant_message = stream
         # streaming_interface.function_call_legacy_mode = stream
 
+        # Allow AssistantMessage is desired by client
+        streaming_interface.use_assistant_message = use_assistant_message
+        streaming_interface.assistant_message_function_name = assistant_message_function_name
+        streaming_interface.assistant_message_function_kwarg = assistant_message_function_kwarg
+
         # Offload the synchronous message_func to a separate thread
         streaming_interface.stream_start()
         task = asyncio.create_task(
@@ -408,6 +440,7 @@ async def send_message_to_agent(
             message_ids = [m.id for m in filtered_stream]
             message_ids = deduplicate(message_ids)
             message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids]
+            message_objs = [m for m in message_objs if m is not None]
            return LettaResponse(messages=message_objs, usage=usage)
         else:
             return LettaResponse(messages=filtered_stream, usage=usage)
```
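On the REST side these flags become query parameters on the message-history endpoint (and corresponding fields on the send-message request body, per the `letta_request.py` changes listed above). A usage sketch with `requests`; the base URL, port, and the `/v1/agents/{agent_id}/messages` path are assumptions based on where this router lives, and the agent ID is a placeholder.

```python
# Sketch: fetching history with the new AssistantMessage flags.
# Endpoint path, port, and any auth headers are assumptions for illustration.
import requests

BASE_URL = "http://localhost:8283"  # assumed local letta server
AGENT_ID = "agent-00000000-0000-0000-0000-000000000000"  # placeholder

resp = requests.get(
    f"{BASE_URL}/v1/agents/{AGENT_ID}/messages",
    params={
        "limit": 10,
        "msg_object": False,
        # new flags introduced in this release
        "use_assistant_message": True,
        "assistant_message_function_name": "send_message",
        "assistant_message_function_kwarg": "message",
    },
)
resp.raise_for_status()
for msg in resp.json():
    print(msg)
```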
letta/server/rest_api/routers/v1/llms.py
CHANGED

```diff
@@ -17,7 +17,9 @@ def list_llm_backends(
     server: "SyncServer" = Depends(get_letta_server),
 ):
 
+    models = server.list_llm_models()
+    print(models)
+    return models
 
 
 @router.get("/embedding", response_model=List[EmbeddingConfig], operation_id="list_embedding_models")
@@ -25,4 +27,6 @@ def list_embedding_backends(
     server: "SyncServer" = Depends(get_letta_server),
 ):
 
+    models = server.list_embedding_models()
+    print(models)
+    return models
```
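These endpoints now return whatever the server's enabled providers report rather than a single configured default; the aggregation itself lives in `SyncServer.list_llm_models()` / `list_embedding_models()` in the server.py hunks below. A minimal sketch of that provider-aggregation pattern, using stand-in classes rather than `letta.providers`.

```python
# Sketch of provider aggregation; the Provider protocol and OpenAILikeProvider
# are illustrative stand-ins for letta.providers, not the real implementations.
import os
from typing import List, Protocol


class Provider(Protocol):
    def list_llm_models(self) -> List[str]: ...
    def list_embedding_models(self) -> List[str]: ...


class OpenAILikeProvider:
    def __init__(self, api_key: str):
        self.api_key = api_key

    def list_llm_models(self) -> List[str]:
        # a real provider would query its backend's /models endpoint here
        return ["gpt-4", "gpt-4o-mini"]

    def list_embedding_models(self) -> List[str]:
        return ["text-embedding-ada-002"]


# Providers are only registered when their key or base URL is configured,
# mirroring the model_settings checks in SyncServer.__init__.
enabled_providers: List[Provider] = []
if os.getenv("OPENAI_API_KEY"):
    enabled_providers.append(OpenAILikeProvider(api_key=os.environ["OPENAI_API_KEY"]))


def list_llm_models() -> List[str]:
    models: List[str] = []
    for provider in enabled_providers:
        models.extend(provider.list_llm_models())
    return models


print(list_llm_models())
```

Because each provider is registered only when its credentials or endpoint are present in the environment, the output of these endpoints directly reflects how the server is configured.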
letta/server/server.py
CHANGED

```diff
@@ -15,7 +15,6 @@ import letta.server.utils as server_utils
 import letta.system as system
 from letta.agent import Agent, save_agent
 from letta.agent_store.storage import StorageConnector, TableType
-from letta.cli.cli_config import get_model_options
 from letta.config import LettaConfig
 from letta.credentials import LettaCredentials
 from letta.data_sources.connectors import DataConnector, load_data
@@ -44,6 +43,13 @@ from letta.log import get_logger
 from letta.memory import get_memory_functions
 from letta.metadata import MetadataStore
 from letta.prompts import gpt_system
+from letta.providers import (
+    AnthropicProvider,
+    GoogleAIProvider,
+    OllamaProvider,
+    OpenAIProvider,
+    VLLMProvider,
+)
 from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
 from letta.schemas.api_key import APIKey, APIKeyCreate
 from letta.schemas.block import (
@@ -158,7 +164,7 @@ from letta.metadata import (
     ToolModel,
     UserModel,
 )
-from letta.settings import settings
+from letta.settings import model_settings, settings
 
 config = LettaConfig.load()
 
@@ -234,51 +240,9 @@ class SyncServer(Server):
 
         # The default interface that will get assigned to agents ON LOAD
         self.default_interface_factory = default_interface_factory
-        # self.default_interface = default_interface
-        # self.default_interface = default_interface_cls()
-
-        # Initialize the connection to the DB
-        # try:
-        #     self.config = LettaConfig.load()
-        #     assert self.config.default_llm_config is not None, "default_llm_config must be set in the config"
-        #     assert self.config.default_embedding_config is not None, "default_embedding_config must be set in the config"
-        # except Exception as e:
-        #     # TODO: very hacky - need to improve model config for docker container
-        #     if os.getenv("OPENAI_API_KEY") is None:
-        #         logger.error("No OPENAI_API_KEY environment variable set and no ~/.letta/config")
-        #         raise e
-
-        # from letta.cli.cli import QuickstartChoice, quickstart
 
-        # quickstart(backend=QuickstartChoice.openai, debug=False, terminal=False, latest=False)
-        # self.config = LettaConfig.load()
-        # self.config.save()
-
-        # TODO figure out how to handle credentials for the server
         self.credentials = LettaCredentials.load()
 
-        # Generate default LLM/Embedding configs for the server
-        # TODO: we may also want to do the same thing with default persona/human/etc.
-        self.server_llm_config = settings.llm_config
-        self.server_embedding_config = settings.embedding_config
-        # self.server_llm_config = LLMConfig(
-        #     model=self.config.default_llm_config.model,
-        #     model_endpoint_type=self.config.default_llm_config.model_endpoint_type,
-        #     model_endpoint=self.config.default_llm_config.model_endpoint,
-        #     model_wrapper=self.config.default_llm_config.model_wrapper,
-        #     context_window=self.config.default_llm_config.context_window,
-        # )
-        # self.server_embedding_config = EmbeddingConfig(
-        #     embedding_endpoint_type=self.config.default_embedding_config.embedding_endpoint_type,
-        #     embedding_endpoint=self.config.default_embedding_config.embedding_endpoint,
-        #     embedding_dim=self.config.default_embedding_config.embedding_dim,
-        #     embedding_model=self.config.default_embedding_config.embedding_model,
-        #     embedding_chunk_size=self.config.default_embedding_config.embedding_chunk_size,
-        # )
-        assert self.server_embedding_config.embedding_model is not None, vars(self.server_embedding_config)
-
-        # Override config values with settings
-
         # Initialize the metadata store
         config = LettaConfig.load()
         if settings.letta_pg_uri_no_default:
@@ -286,8 +250,6 @@ class SyncServer(Server):
             config.recall_storage_uri = settings.letta_pg_uri_no_default
             config.archival_storage_type = "postgres"
             config.archival_storage_uri = settings.letta_pg_uri_no_default
-        config.default_llm_config = self.server_llm_config
-        config.default_embedding_config = self.server_embedding_config
         config.save()
         self.config = config
         self.ms = MetadataStore(self.config)
@@ -296,6 +258,19 @@ class SyncServer(Server):
         # add global default tools (for admin)
         self.add_default_tools(module_name="base")
 
+        # collect providers
+        self._enabled_providers = []
+        if model_settings.openai_api_key:
+            self._enabled_providers.append(OpenAIProvider(api_key=model_settings.openai_api_key))
+        if model_settings.anthropic_api_key:
+            self._enabled_providers.append(AnthropicProvider(api_key=model_settings.anthropic_api_key))
+        if model_settings.ollama_base_url:
+            self._enabled_providers.append(OllamaProvider(base_url=model_settings.ollama_base_url))
+        if model_settings.vllm_base_url:
+            self._enabled_providers.append(VLLMProvider(base_url=model_settings.vllm_base_url))
+        if model_settings.gemini_api_key:
+            self._enabled_providers.append(GoogleAIProvider(api_key=model_settings.gemini_api_key))
+
     def save_agents(self):
         """Saves all the agents that are in the in-memory object store"""
         for agent_d in self.active_agents:
@@ -456,7 +431,7 @@ class SyncServer(Server):
         logger.debug("Calling step_yield()")
         letta_agent.interface.step_yield()
 
-        return LettaUsageStatistics(**total_usage.
+        return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
 
     def _command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
         """Process a CLI command"""
@@ -766,8 +741,8 @@ class SyncServer(Server):
 
         try:
             # model configuration
-            llm_config = request.llm_config
-            embedding_config = request.embedding_config
+            llm_config = request.llm_config
+            embedding_config = request.embedding_config
 
             # get tools + make sure they exist
             tool_objs = []
@@ -1262,6 +1237,9 @@ class SyncServer(Server):
         order: Optional[str] = "asc",
         reverse: Optional[bool] = False,
         return_message_object: bool = True,
+        use_assistant_message: bool = False,
+        assistant_message_function_name: str = constants.DEFAULT_MESSAGE_TOOL,
+        assistant_message_function_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG,
     ) -> Union[List[Message], List[LettaMessage]]:
         if self.ms.get_user(user_id=user_id) is None:
             raise ValueError(f"User user_id={user_id} does not exist")
@@ -1281,9 +1259,25 @@ class SyncServer(Server):
         if not return_message_object:
             # If we're GETing messages in reverse, we need to reverse the inner list (generated by to_letta_message)
             if reverse:
-                records = [
+                records = [
+                    msg
+                    for m in records
+                    for msg in m.to_letta_message(
+                        assistant_message=use_assistant_message,
+                        assistant_message_function_name=assistant_message_function_name,
+                        assistant_message_function_kwarg=assistant_message_function_kwarg,
+                    )[::-1]
+                ]
             else:
-                records = [
+                records = [
+                    msg
+                    for m in records
+                    for msg in m.to_letta_message(
+                        assistant_message=use_assistant_message,
+                        assistant_message_function_name=assistant_message_function_name,
+                        assistant_message_function_kwarg=assistant_message_function_kwarg,
+                    )
+                ]
 
         return records
 
@@ -1320,39 +1314,15 @@ class SyncServer(Server):
         base_config = vars(self.config)
         clean_base_config = clean_keys(base_config)
 
-        clean_base_config_default_llm_config_dict = vars(clean_base_config["default_llm_config"])
-        clean_base_config_default_embedding_config_dict = vars(clean_base_config["default_embedding_config"])
-
-        clean_base_config["default_llm_config"] = clean_base_config_default_llm_config_dict
-        clean_base_config["default_embedding_config"] = clean_base_config_default_embedding_config_dict
         response = {"config": clean_base_config}
 
         if include_defaults:
             default_config = vars(LettaConfig())
             clean_default_config = clean_keys(default_config)
-            clean_default_config["default_llm_config"] = clean_base_config_default_llm_config_dict
-            clean_default_config["default_embedding_config"] = clean_base_config_default_embedding_config_dict
             response["defaults"] = clean_default_config
 
         return response
 
-    def get_available_models(self) -> List[LLMConfig]:
-        """Poll the LLM endpoint for a list of available models"""
-
-        credentials = LettaCredentials().load()
-
-        try:
-            model_options = get_model_options(
-                credentials=credentials,
-                model_endpoint_type=self.config.default_llm_config.model_endpoint_type,
-                model_endpoint=self.config.default_llm_config.model_endpoint,
-            )
-            return model_options
-
-        except Exception as e:
-            logger.exception(f"Failed to get list of available models from LLM endpoint:\n{str(e)}")
-            raise
-
     def update_agent_core_memory(self, user_id: str, agent_id: str, new_memory_contents: dict) -> Memory:
         """Update the agents core memory block, return the new state"""
         if self.ms.get_user(user_id=user_id) is None:
@@ -1472,7 +1442,7 @@ class SyncServer(Server):
         source = Source(
             name=request.name,
             user_id=user_id,
-            embedding_config=self.
+            embedding_config=self.list_embedding_models()[0],  # TODO: require providing this
         )
         self.ms.create_source(source)
         assert self.ms.get_source(source_name=request.name, user_id=user_id) is not None, f"Failed to create source {request.name}"
@@ -1970,20 +1940,23 @@ class SyncServer(Server):
 
         return self.get_default_user()
 
-    def
+    def list_llm_models(self) -> List[LLMConfig]:
         """List available models"""
 
-        # model_endpoint=settings.llm_endpoint_type
-        # )
-        return [settings.llm_config]
+        llm_models = []
+        for provider in self._enabled_providers:
+            llm_models.extend(provider.list_llm_models())
+        return llm_models
 
     def list_embedding_models(self) -> List[EmbeddingConfig]:
         """List available embedding models"""
+        embedding_models = []
+        for provider in self._enabled_providers:
+            embedding_models.extend(provider.list_embedding_models())
+        return embedding_models
+
+    def add_llm_model(self, request: LLMConfig) -> LLMConfig:
+        """Add a new LLM model"""
 
+    def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
+        """Add a new embedding model"""
```