letta-nightly 0.4.1.dev20241013104006__py3-none-any.whl → 0.5.0.dev20241015014828__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly might be problematic. Click here for more details.
- letta/__init__.py +2 -2
- letta/agent.py +51 -65
- letta/agent_store/db.py +18 -7
- letta/agent_store/lancedb.py +2 -2
- letta/agent_store/milvus.py +1 -1
- letta/agent_store/qdrant.py +1 -1
- letta/agent_store/storage.py +12 -10
- letta/cli/cli_load.py +1 -1
- letta/client/client.py +51 -0
- letta/data_sources/connectors.py +124 -124
- letta/data_sources/connectors_helper.py +97 -0
- letta/llm_api/mistral.py +47 -0
- letta/main.py +19 -9
- letta/metadata.py +58 -0
- letta/providers.py +44 -0
- letta/schemas/file.py +31 -0
- letta/schemas/job.py +1 -1
- letta/schemas/letta_request.py +3 -3
- letta/schemas/llm_config.py +1 -0
- letta/schemas/message.py +6 -2
- letta/schemas/passage.py +3 -3
- letta/schemas/source.py +2 -2
- letta/server/rest_api/routers/v1/agents.py +10 -16
- letta/server/rest_api/routers/v1/jobs.py +17 -1
- letta/server/rest_api/routers/v1/sources.py +7 -9
- letta/server/server.py +137 -24
- letta/server/static_files/assets/{index-9a9c449b.js → index-dc228d4a.js} +4 -4
- letta/server/static_files/index.html +1 -1
- {letta_nightly-0.4.1.dev20241013104006.dist-info → letta_nightly-0.5.0.dev20241015014828.dist-info}/METADATA +1 -1
- {letta_nightly-0.4.1.dev20241013104006.dist-info → letta_nightly-0.5.0.dev20241015014828.dist-info}/RECORD +33 -31
- letta/schemas/document.py +0 -21
- {letta_nightly-0.4.1.dev20241013104006.dist-info → letta_nightly-0.5.0.dev20241015014828.dist-info}/LICENSE +0 -0
- {letta_nightly-0.4.1.dev20241013104006.dist-info → letta_nightly-0.5.0.dev20241015014828.dist-info}/WHEEL +0 -0
- {letta_nightly-0.4.1.dev20241013104006.dist-info → letta_nightly-0.5.0.dev20241015014828.dist-info}/entry_points.txt +0 -0
letta/schemas/file.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from pydantic import Field
|
|
5
|
+
|
|
6
|
+
from letta.schemas.letta_base import LettaBase
|
|
7
|
+
from letta.utils import get_utc_time
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class FileMetadataBase(LettaBase):
|
|
11
|
+
"""Base class for FileMetadata schemas"""
|
|
12
|
+
|
|
13
|
+
__id_prefix__ = "file"
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class FileMetadata(FileMetadataBase):
|
|
17
|
+
"""Representation of a single FileMetadata"""
|
|
18
|
+
|
|
19
|
+
id: str = FileMetadataBase.generate_id_field()
|
|
20
|
+
user_id: str = Field(description="The unique identifier of the user associated with the document.")
|
|
21
|
+
source_id: str = Field(..., description="The unique identifier of the source associated with the document.")
|
|
22
|
+
file_name: Optional[str] = Field(None, description="The name of the file.")
|
|
23
|
+
file_path: Optional[str] = Field(None, description="The path to the file.")
|
|
24
|
+
file_type: Optional[str] = Field(None, description="The type of the file (MIME type).")
|
|
25
|
+
file_size: Optional[int] = Field(None, description="The size of the file in bytes.")
|
|
26
|
+
file_creation_date: Optional[str] = Field(None, description="The creation date of the file.")
|
|
27
|
+
file_last_modified_date: Optional[str] = Field(None, description="The last modified date of the file.")
|
|
28
|
+
created_at: datetime = Field(default_factory=get_utc_time, description="The creation date of this file metadata object.")
|
|
29
|
+
|
|
30
|
+
class Config:
|
|
31
|
+
extra = "allow"
|
letta/schemas/job.py
CHANGED
|
@@ -15,7 +15,7 @@ class JobBase(LettaBase):
|
|
|
15
15
|
|
|
16
16
|
class Job(JobBase):
|
|
17
17
|
"""
|
|
18
|
-
Representation of offline jobs, used for tracking status of data loading tasks (involving parsing and embedding
|
|
18
|
+
Representation of offline jobs, used for tracking status of data loading tasks (involving parsing and embedding files).
|
|
19
19
|
|
|
20
20
|
Parameters:
|
|
21
21
|
id (str): The unique identifier of the job.
|
letta/schemas/letta_request.py
CHANGED
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
from typing import List
|
|
1
|
+
from typing import List, Union
|
|
2
2
|
|
|
3
3
|
from pydantic import BaseModel, Field
|
|
4
4
|
|
|
5
5
|
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
|
6
|
-
from letta.schemas.message import MessageCreate
|
|
6
|
+
from letta.schemas.message import Message, MessageCreate
|
|
7
7
|
|
|
8
8
|
|
|
9
9
|
class LettaRequest(BaseModel):
|
|
10
|
-
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
|
|
10
|
+
messages: Union[List[MessageCreate], List[Message]] = Field(..., description="The messages to be sent to the agent.")
|
|
11
11
|
run_async: bool = Field(default=False, description="Whether to asynchronously send the messages to the agent.") # TODO: implement
|
|
12
12
|
|
|
13
13
|
stream_steps: bool = Field(
|
letta/schemas/llm_config.py
CHANGED
|
@@ -33,6 +33,7 @@ class LLMConfig(BaseModel):
|
|
|
33
33
|
"koboldcpp",
|
|
34
34
|
"vllm",
|
|
35
35
|
"hugging-face",
|
|
36
|
+
"mistral",
|
|
36
37
|
] = Field(..., description="The endpoint type for the model.")
|
|
37
38
|
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
|
|
38
39
|
model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.")
|
letta/schemas/message.py
CHANGED
|
@@ -2,7 +2,7 @@ import copy
|
|
|
2
2
|
import json
|
|
3
3
|
import warnings
|
|
4
4
|
from datetime import datetime, timezone
|
|
5
|
-
from typing import List, Optional
|
|
5
|
+
from typing import List, Literal, Optional
|
|
6
6
|
|
|
7
7
|
from pydantic import Field, field_validator
|
|
8
8
|
|
|
@@ -57,7 +57,11 @@ class BaseMessage(LettaBase):
|
|
|
57
57
|
class MessageCreate(BaseMessage):
|
|
58
58
|
"""Request to create a message"""
|
|
59
59
|
|
|
60
|
-
|
|
60
|
+
# In the simplified format, only allow simple roles
|
|
61
|
+
role: Literal[
|
|
62
|
+
MessageRole.user,
|
|
63
|
+
MessageRole.system,
|
|
64
|
+
] = Field(..., description="The role of the participant.")
|
|
61
65
|
text: str = Field(..., description="The text of the message.")
|
|
62
66
|
name: Optional[str] = Field(None, description="The name of the participant.")
|
|
63
67
|
|
letta/schemas/passage.py
CHANGED
|
@@ -19,8 +19,8 @@ class PassageBase(LettaBase):
|
|
|
19
19
|
# origin data source
|
|
20
20
|
source_id: Optional[str] = Field(None, description="The data source of the passage.")
|
|
21
21
|
|
|
22
|
-
#
|
|
23
|
-
|
|
22
|
+
# file association
|
|
23
|
+
file_id: Optional[str] = Field(None, description="The unique identifier of the file associated with the passage.")
|
|
24
24
|
metadata_: Optional[Dict] = Field({}, description="The metadata of the passage.")
|
|
25
25
|
|
|
26
26
|
|
|
@@ -36,7 +36,7 @@ class Passage(PassageBase):
|
|
|
36
36
|
user_id (str): The unique identifier of the user associated with the passage.
|
|
37
37
|
agent_id (str): The unique identifier of the agent associated with the passage.
|
|
38
38
|
source_id (str): The data source of the passage.
|
|
39
|
-
|
|
39
|
+
file_id (str): The unique identifier of the file associated with the passage.
|
|
40
40
|
"""
|
|
41
41
|
|
|
42
42
|
id: str = PassageBase.generate_id_field()
|
letta/schemas/source.py
CHANGED
|
@@ -28,7 +28,7 @@ class SourceCreate(BaseSource):
|
|
|
28
28
|
|
|
29
29
|
class Source(BaseSource):
|
|
30
30
|
"""
|
|
31
|
-
Representation of a source, which is a collection of
|
|
31
|
+
Representation of a source, which is a collection of files and passages.
|
|
32
32
|
|
|
33
33
|
Parameters:
|
|
34
34
|
id (str): The ID of the source
|
|
@@ -59,4 +59,4 @@ class UploadFileToSourceRequest(BaseModel):
|
|
|
59
59
|
class UploadFileToSourceResponse(BaseModel):
|
|
60
60
|
source: Source = Field(..., description="The source the file was uploaded to.")
|
|
61
61
|
added_passages: int = Field(..., description="The number of passages added to the source.")
|
|
62
|
-
added_documents: int = Field(..., description="The number of
|
|
62
|
+
added_documents: int = Field(..., description="The number of files added to the source.")
|
|
@@ -8,7 +8,7 @@ from starlette.responses import StreamingResponse
|
|
|
8
8
|
|
|
9
9
|
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
|
10
10
|
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgentState
|
|
11
|
-
from letta.schemas.enums import
|
|
11
|
+
from letta.schemas.enums import MessageStreamStatus
|
|
12
12
|
from letta.schemas.letta_message import (
|
|
13
13
|
LegacyLettaMessage,
|
|
14
14
|
LettaMessage,
|
|
@@ -23,7 +23,7 @@ from letta.schemas.memory import (
|
|
|
23
23
|
Memory,
|
|
24
24
|
RecallMemorySummary,
|
|
25
25
|
)
|
|
26
|
-
from letta.schemas.message import Message, UpdateMessage
|
|
26
|
+
from letta.schemas.message import Message, MessageCreate, UpdateMessage
|
|
27
27
|
from letta.schemas.passage import Passage
|
|
28
28
|
from letta.schemas.source import Source
|
|
29
29
|
from letta.server.rest_api.interface import StreamingServerInterface
|
|
@@ -326,14 +326,15 @@ async def send_message(
|
|
|
326
326
|
|
|
327
327
|
# TODO(charles): support sending multiple messages
|
|
328
328
|
assert len(request.messages) == 1, f"Multiple messages not supported: {request.messages}"
|
|
329
|
-
|
|
329
|
+
request.messages[0]
|
|
330
330
|
|
|
331
331
|
return await send_message_to_agent(
|
|
332
332
|
server=server,
|
|
333
333
|
agent_id=agent_id,
|
|
334
334
|
user_id=actor.id,
|
|
335
|
-
role=message.role,
|
|
336
|
-
message=message.text,
|
|
335
|
+
# role=message.role,
|
|
336
|
+
# message=message.text,
|
|
337
|
+
messages=request.messages,
|
|
337
338
|
stream_steps=request.stream_steps,
|
|
338
339
|
stream_tokens=request.stream_tokens,
|
|
339
340
|
return_message_object=request.return_message_object,
|
|
@@ -349,8 +350,8 @@ async def send_message_to_agent(
|
|
|
349
350
|
server: SyncServer,
|
|
350
351
|
agent_id: str,
|
|
351
352
|
user_id: str,
|
|
352
|
-
role: MessageRole,
|
|
353
|
-
|
|
353
|
+
# role: MessageRole,
|
|
354
|
+
messages: Union[List[Message], List[MessageCreate]],
|
|
354
355
|
stream_steps: bool,
|
|
355
356
|
stream_tokens: bool,
|
|
356
357
|
# related to whether or not we return `LettaMessage`s or `Message`s
|
|
@@ -367,14 +368,6 @@ async def send_message_to_agent(
|
|
|
367
368
|
# TODO: @charles is this the correct way to handle?
|
|
368
369
|
include_final_message = True
|
|
369
370
|
|
|
370
|
-
# determine role
|
|
371
|
-
if role == MessageRole.user:
|
|
372
|
-
message_func = server.user_message
|
|
373
|
-
elif role == MessageRole.system:
|
|
374
|
-
message_func = server.system_message
|
|
375
|
-
else:
|
|
376
|
-
raise HTTPException(status_code=500, detail=f"Bad role {role}")
|
|
377
|
-
|
|
378
371
|
if not stream_steps and stream_tokens:
|
|
379
372
|
raise HTTPException(status_code=400, detail="stream_steps must be 'true' if stream_tokens is 'true'")
|
|
380
373
|
|
|
@@ -413,7 +406,8 @@ async def send_message_to_agent(
|
|
|
413
406
|
# Offload the synchronous message_func to a separate thread
|
|
414
407
|
streaming_interface.stream_start()
|
|
415
408
|
task = asyncio.create_task(
|
|
416
|
-
asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp)
|
|
409
|
+
# asyncio.to_thread(message_func, user_id=user_id, agent_id=agent_id, message=message, timestamp=timestamp)
|
|
410
|
+
asyncio.to_thread(server.send_messages, user_id=user_id, agent_id=agent_id, messages=messages)
|
|
417
411
|
)
|
|
418
412
|
|
|
419
413
|
if stream_steps:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
from typing import List, Optional
|
|
2
2
|
|
|
3
|
-
from fastapi import APIRouter, Depends, Header, Query
|
|
3
|
+
from fastapi import APIRouter, Depends, Header, HTTPException, Query
|
|
4
4
|
|
|
5
5
|
from letta.schemas.job import Job
|
|
6
6
|
from letta.server.rest_api.utils import get_letta_server
|
|
@@ -54,3 +54,19 @@ def get_job(
|
|
|
54
54
|
"""
|
|
55
55
|
|
|
56
56
|
return server.get_job(job_id=job_id)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@router.delete("/{job_id}", response_model=Job, operation_id="delete_job")
|
|
60
|
+
def delete_job(
|
|
61
|
+
job_id: str,
|
|
62
|
+
server: "SyncServer" = Depends(get_letta_server),
|
|
63
|
+
):
|
|
64
|
+
"""
|
|
65
|
+
Delete a job by its job_id.
|
|
66
|
+
"""
|
|
67
|
+
job = server.get_job(job_id=job_id)
|
|
68
|
+
if not job:
|
|
69
|
+
raise HTTPException(status_code=404, detail="Job not found")
|
|
70
|
+
|
|
71
|
+
server.delete_job(job_id=job_id)
|
|
72
|
+
return job
|
|
@@ -4,7 +4,7 @@ from typing import List, Optional
|
|
|
4
4
|
|
|
5
5
|
from fastapi import APIRouter, BackgroundTasks, Depends, Header, Query, UploadFile
|
|
6
6
|
|
|
7
|
-
from letta.schemas.
|
|
7
|
+
from letta.schemas.file import FileMetadata
|
|
8
8
|
from letta.schemas.job import Job
|
|
9
9
|
from letta.schemas.passage import Passage
|
|
10
10
|
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
|
@@ -186,19 +186,17 @@ def list_passages(
|
|
|
186
186
|
return passages
|
|
187
187
|
|
|
188
188
|
|
|
189
|
-
@router.get("/{source_id}/
|
|
190
|
-
def
|
|
189
|
+
@router.get("/{source_id}/files", response_model=List[FileMetadata], operation_id="list_files_from_source")
|
|
190
|
+
def list_files_from_source(
|
|
191
191
|
source_id: str,
|
|
192
|
+
limit: int = Query(1000, description="Number of files to return"),
|
|
193
|
+
cursor: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
|
|
192
194
|
server: "SyncServer" = Depends(get_letta_server),
|
|
193
|
-
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
|
194
195
|
):
|
|
195
196
|
"""
|
|
196
|
-
List
|
|
197
|
+
List paginated files associated with a data source.
|
|
197
198
|
"""
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
documents = server.list_data_source_documents(user_id=actor.id, source_id=source_id)
|
|
201
|
-
return documents
|
|
199
|
+
return server.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
|
|
202
200
|
|
|
203
201
|
|
|
204
202
|
def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes):
|
letta/server/server.py
CHANGED
|
@@ -47,6 +47,7 @@ from letta.providers import (
|
|
|
47
47
|
AnthropicProvider,
|
|
48
48
|
AzureProvider,
|
|
49
49
|
GoogleAIProvider,
|
|
50
|
+
GroqProvider,
|
|
50
51
|
LettaProvider,
|
|
51
52
|
OllamaProvider,
|
|
52
53
|
OpenAIProvider,
|
|
@@ -63,16 +64,16 @@ from letta.schemas.block import (
|
|
|
63
64
|
CreatePersona,
|
|
64
65
|
UpdateBlock,
|
|
65
66
|
)
|
|
66
|
-
from letta.schemas.document import Document
|
|
67
67
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
68
68
|
|
|
69
69
|
# openai schemas
|
|
70
70
|
from letta.schemas.enums import JobStatus
|
|
71
|
+
from letta.schemas.file import FileMetadata
|
|
71
72
|
from letta.schemas.job import Job
|
|
72
73
|
from letta.schemas.letta_message import LettaMessage
|
|
73
74
|
from letta.schemas.llm_config import LLMConfig
|
|
74
75
|
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
|
|
75
|
-
from letta.schemas.message import Message, UpdateMessage
|
|
76
|
+
from letta.schemas.message import Message, MessageCreate, MessageRole, UpdateMessage
|
|
76
77
|
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
|
77
78
|
from letta.schemas.organization import Organization, OrganizationCreate
|
|
78
79
|
from letta.schemas.passage import Passage
|
|
@@ -141,6 +142,11 @@ class Server(object):
|
|
|
141
142
|
"""Process a message from the system, internally calls step"""
|
|
142
143
|
raise NotImplementedError
|
|
143
144
|
|
|
145
|
+
@abstractmethod
|
|
146
|
+
def send_messages(self, user_id: str, agent_id: str, messages: Union[MessageCreate, List[Message]]) -> None:
|
|
147
|
+
"""Send a list of messages to the agent"""
|
|
148
|
+
raise NotImplementedError
|
|
149
|
+
|
|
144
150
|
@abstractmethod
|
|
145
151
|
def run_command(self, user_id: str, agent_id: str, command: str) -> Union[str, None]:
|
|
146
152
|
"""Run a command on the agent, e.g. /memory
|
|
@@ -292,6 +298,12 @@ class SyncServer(Server):
|
|
|
292
298
|
base_url=model_settings.vllm_api_base,
|
|
293
299
|
)
|
|
294
300
|
)
|
|
301
|
+
if model_settings.groq_api_key:
|
|
302
|
+
self._enabled_providers.append(
|
|
303
|
+
GroqProvider(
|
|
304
|
+
api_key=model_settings.groq_api_key,
|
|
305
|
+
)
|
|
306
|
+
)
|
|
295
307
|
|
|
296
308
|
def save_agents(self):
|
|
297
309
|
"""Saves all the agents that are in the in-memory object store"""
|
|
@@ -383,9 +395,22 @@ class SyncServer(Server):
|
|
|
383
395
|
letta_agent = self._load_agent(user_id=user_id, agent_id=agent_id)
|
|
384
396
|
return letta_agent
|
|
385
397
|
|
|
386
|
-
def _step(
|
|
398
|
+
def _step(
|
|
399
|
+
self,
|
|
400
|
+
user_id: str,
|
|
401
|
+
agent_id: str,
|
|
402
|
+
input_messages: Union[Message, List[Message]],
|
|
403
|
+
# timestamp: Optional[datetime],
|
|
404
|
+
) -> LettaUsageStatistics:
|
|
387
405
|
"""Send the input message through the agent"""
|
|
388
|
-
|
|
406
|
+
|
|
407
|
+
# Input validation
|
|
408
|
+
if isinstance(input_messages, Message):
|
|
409
|
+
input_messages = [input_messages]
|
|
410
|
+
if not all(isinstance(m, Message) for m in input_messages):
|
|
411
|
+
raise ValueError(f"messages should be a Message or a list of Message, got {type(input_messages)}")
|
|
412
|
+
|
|
413
|
+
logger.debug(f"Got input messages: {input_messages}")
|
|
389
414
|
try:
|
|
390
415
|
|
|
391
416
|
# Get the agent object (loaded in memory)
|
|
@@ -398,18 +423,18 @@ class SyncServer(Server):
|
|
|
398
423
|
|
|
399
424
|
logger.debug(f"Starting agent step")
|
|
400
425
|
no_verify = True
|
|
401
|
-
next_input_message =
|
|
426
|
+
next_input_message = input_messages
|
|
402
427
|
counter = 0
|
|
403
428
|
total_usage = UsageStatistics()
|
|
404
429
|
step_count = 0
|
|
405
430
|
while True:
|
|
406
431
|
step_response = letta_agent.step(
|
|
407
|
-
next_input_message,
|
|
432
|
+
messages=next_input_message,
|
|
408
433
|
first_message=False,
|
|
409
434
|
skip_verify=no_verify,
|
|
410
435
|
return_dicts=False,
|
|
411
436
|
stream=token_streaming,
|
|
412
|
-
timestamp=timestamp,
|
|
437
|
+
# timestamp=timestamp,
|
|
413
438
|
ms=self.ms,
|
|
414
439
|
)
|
|
415
440
|
step_response.messages
|
|
@@ -436,13 +461,40 @@ class SyncServer(Server):
|
|
|
436
461
|
break
|
|
437
462
|
# Chain handlers
|
|
438
463
|
elif token_warning:
|
|
439
|
-
|
|
464
|
+
assert letta_agent.agent_state.user_id is not None
|
|
465
|
+
next_input_message = Message.dict_to_message(
|
|
466
|
+
agent_id=letta_agent.agent_state.id,
|
|
467
|
+
user_id=letta_agent.agent_state.user_id,
|
|
468
|
+
model=letta_agent.model,
|
|
469
|
+
openai_message_dict={
|
|
470
|
+
"role": "user", # TODO: change to system?
|
|
471
|
+
"content": system.get_token_limit_warning(),
|
|
472
|
+
},
|
|
473
|
+
)
|
|
440
474
|
continue # always chain
|
|
441
475
|
elif function_failed:
|
|
442
|
-
|
|
476
|
+
assert letta_agent.agent_state.user_id is not None
|
|
477
|
+
next_input_message = Message.dict_to_message(
|
|
478
|
+
agent_id=letta_agent.agent_state.id,
|
|
479
|
+
user_id=letta_agent.agent_state.user_id,
|
|
480
|
+
model=letta_agent.model,
|
|
481
|
+
openai_message_dict={
|
|
482
|
+
"role": "user", # TODO: change to system?
|
|
483
|
+
"content": system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE),
|
|
484
|
+
},
|
|
485
|
+
)
|
|
443
486
|
continue # always chain
|
|
444
487
|
elif heartbeat_request:
|
|
445
|
-
|
|
488
|
+
assert letta_agent.agent_state.user_id is not None
|
|
489
|
+
next_input_message = Message.dict_to_message(
|
|
490
|
+
agent_id=letta_agent.agent_state.id,
|
|
491
|
+
user_id=letta_agent.agent_state.user_id,
|
|
492
|
+
model=letta_agent.model,
|
|
493
|
+
openai_message_dict={
|
|
494
|
+
"role": "user", # TODO: change to system?
|
|
495
|
+
"content": system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE),
|
|
496
|
+
},
|
|
497
|
+
)
|
|
446
498
|
continue # always chain
|
|
447
499
|
# Letta no-op / yield
|
|
448
500
|
else:
|
|
@@ -621,7 +673,7 @@ class SyncServer(Server):
|
|
|
621
673
|
)
|
|
622
674
|
|
|
623
675
|
# Run the agent state forward
|
|
624
|
-
usage = self._step(user_id=user_id, agent_id=agent_id,
|
|
676
|
+
usage = self._step(user_id=user_id, agent_id=agent_id, input_messages=message)
|
|
625
677
|
return usage
|
|
626
678
|
|
|
627
679
|
def system_message(
|
|
@@ -669,7 +721,7 @@ class SyncServer(Server):
|
|
|
669
721
|
|
|
670
722
|
if isinstance(message, Message):
|
|
671
723
|
# Can't have a null text field
|
|
672
|
-
if
|
|
724
|
+
if message.text is None or len(message.text) == 0:
|
|
673
725
|
raise ValueError(f"Invalid input: '{message.text}'")
|
|
674
726
|
# If the input begins with a command prefix, reject
|
|
675
727
|
elif message.text.startswith("/"):
|
|
@@ -683,7 +735,69 @@ class SyncServer(Server):
|
|
|
683
735
|
message.created_at = timestamp
|
|
684
736
|
|
|
685
737
|
# Run the agent state forward
|
|
686
|
-
return self._step(user_id=user_id, agent_id=agent_id,
|
|
738
|
+
return self._step(user_id=user_id, agent_id=agent_id, input_messages=message)
|
|
739
|
+
|
|
740
|
+
def send_messages(
|
|
741
|
+
self,
|
|
742
|
+
user_id: str,
|
|
743
|
+
agent_id: str,
|
|
744
|
+
messages: Union[List[MessageCreate], List[Message]],
|
|
745
|
+
# whether or not to wrap user and system message as MemGPT-style stringified JSON
|
|
746
|
+
wrap_user_message: bool = True,
|
|
747
|
+
wrap_system_message: bool = True,
|
|
748
|
+
) -> LettaUsageStatistics:
|
|
749
|
+
"""Send a list of messages to the agent
|
|
750
|
+
|
|
751
|
+
If the messages are of type MessageCreate, we need to turn them into
|
|
752
|
+
Message objects first before sending them through step.
|
|
753
|
+
|
|
754
|
+
Otherwise, we can pass them in directly.
|
|
755
|
+
"""
|
|
756
|
+
if self.ms.get_user(user_id=user_id) is None:
|
|
757
|
+
raise ValueError(f"User user_id={user_id} does not exist")
|
|
758
|
+
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
|
759
|
+
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
|
760
|
+
|
|
761
|
+
message_objects: List[Message] = []
|
|
762
|
+
|
|
763
|
+
if all(isinstance(m, MessageCreate) for m in messages):
|
|
764
|
+
for message in messages:
|
|
765
|
+
assert isinstance(message, MessageCreate)
|
|
766
|
+
|
|
767
|
+
# If wrapping is eanbled, wrap with metadata before placing content inside the Message object
|
|
768
|
+
if message.role == MessageRole.user and wrap_user_message:
|
|
769
|
+
message.text = system.package_user_message(user_message=message.text)
|
|
770
|
+
elif message.role == MessageRole.system and wrap_system_message:
|
|
771
|
+
message.text = system.package_system_message(system_message=message.text)
|
|
772
|
+
else:
|
|
773
|
+
raise ValueError(f"Invalid message role: {message.role}")
|
|
774
|
+
|
|
775
|
+
# Create the Message object
|
|
776
|
+
message_objects.append(
|
|
777
|
+
Message(
|
|
778
|
+
user_id=user_id,
|
|
779
|
+
agent_id=agent_id,
|
|
780
|
+
role=message.role,
|
|
781
|
+
text=message.text,
|
|
782
|
+
name=message.name,
|
|
783
|
+
# assigned later?
|
|
784
|
+
model=None,
|
|
785
|
+
# irrelevant
|
|
786
|
+
tool_calls=None,
|
|
787
|
+
tool_call_id=None,
|
|
788
|
+
)
|
|
789
|
+
)
|
|
790
|
+
|
|
791
|
+
elif all(isinstance(m, Message) for m in messages):
|
|
792
|
+
for message in messages:
|
|
793
|
+
assert isinstance(message, Message)
|
|
794
|
+
message_objects.append(message)
|
|
795
|
+
|
|
796
|
+
else:
|
|
797
|
+
raise ValueError(f"All messages must be of type Message or MessageCreate, got {type(messages)}")
|
|
798
|
+
|
|
799
|
+
# Run the agent state forward
|
|
800
|
+
return self._step(user_id=user_id, agent_id=agent_id, input_messages=message_objects)
|
|
687
801
|
|
|
688
802
|
# @LockingServer.agent_lock_decorator
|
|
689
803
|
def run_command(self, user_id: str, agent_id: str, command: str) -> LettaUsageStatistics:
|
|
@@ -1556,7 +1670,7 @@ class SyncServer(Server):
|
|
|
1556
1670
|
# job.status = JobStatus.failed
|
|
1557
1671
|
# job.metadata_["error"] = error
|
|
1558
1672
|
# self.ms.update_job(job)
|
|
1559
|
-
# # TODO: delete any associated passages/
|
|
1673
|
+
# # TODO: delete any associated passages/files?
|
|
1560
1674
|
|
|
1561
1675
|
# # return failed job
|
|
1562
1676
|
# return job
|
|
@@ -1585,11 +1699,10 @@ class SyncServer(Server):
|
|
|
1585
1699
|
|
|
1586
1700
|
# get the data connectors
|
|
1587
1701
|
passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
|
1588
|
-
|
|
1589
|
-
document_store = None # StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id)
|
|
1702
|
+
file_store = StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id)
|
|
1590
1703
|
|
|
1591
1704
|
# load data into the document store
|
|
1592
|
-
passage_count, document_count = load_data(connector, source, passage_store,
|
|
1705
|
+
passage_count, document_count = load_data(connector, source, passage_store, file_store)
|
|
1593
1706
|
return passage_count, document_count
|
|
1594
1707
|
|
|
1595
1708
|
def attach_source_to_agent(
|
|
@@ -1646,14 +1759,14 @@ class SyncServer(Server):
|
|
|
1646
1759
|
# list all attached sources to an agent
|
|
1647
1760
|
return self.ms.list_attached_sources(agent_id)
|
|
1648
1761
|
|
|
1762
|
+
def list_files_from_source(self, source_id: str, limit: int = 1000, cursor: Optional[str] = None) -> List[FileMetadata]:
|
|
1763
|
+
# list all attached sources to an agent
|
|
1764
|
+
return self.ms.list_files_from_source(source_id=source_id, limit=limit, cursor=cursor)
|
|
1765
|
+
|
|
1649
1766
|
def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
|
|
1650
1767
|
warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning)
|
|
1651
1768
|
return []
|
|
1652
1769
|
|
|
1653
|
-
def list_data_source_documents(self, user_id: str, source_id: str) -> List[Document]:
|
|
1654
|
-
warnings.warn("list_data_source_documents is not yet implemented, returning empty list.", category=UserWarning)
|
|
1655
|
-
return []
|
|
1656
|
-
|
|
1657
1770
|
def list_all_sources(self, user_id: str) -> List[Source]:
|
|
1658
1771
|
"""List all sources (w/ extra metadata) belonging to a user"""
|
|
1659
1772
|
|
|
@@ -1667,9 +1780,9 @@ class SyncServer(Server):
|
|
|
1667
1780
|
passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
|
1668
1781
|
num_passages = passage_conn.size({"source_id": source.id})
|
|
1669
1782
|
|
|
1670
|
-
# TODO: add when
|
|
1671
|
-
## count number of
|
|
1672
|
-
# document_conn = StorageConnector.get_storage_connector(TableType.
|
|
1783
|
+
# TODO: add when files table implemented
|
|
1784
|
+
## count number of files
|
|
1785
|
+
# document_conn = StorageConnector.get_storage_connector(TableType.FILES, self.config, user_id=user_id)
|
|
1673
1786
|
# num_documents = document_conn.size({"data_source": source.name})
|
|
1674
1787
|
num_documents = 0
|
|
1675
1788
|
|