letta-nightly 0.6.9.dev20250115104021__py3-none-any.whl → 0.6.9.dev20250116195713__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of letta-nightly may be problematic; see the registry's advisory page for more details.
- letta/__init__.py +1 -0
- letta/agent.py +24 -0
- letta/client/client.py +274 -11
- letta/constants.py +5 -0
- letta/functions/function_sets/multi_agent.py +96 -0
- letta/functions/helpers.py +105 -1
- letta/functions/schema_generator.py +8 -0
- letta/llm_api/openai.py +18 -2
- letta/local_llm/utils.py +4 -0
- letta/orm/__init__.py +1 -0
- letta/orm/enums.py +6 -0
- letta/orm/job.py +24 -2
- letta/orm/job_messages.py +33 -0
- letta/orm/job_usage_statistics.py +30 -0
- letta/orm/message.py +10 -0
- letta/orm/sqlalchemy_base.py +28 -4
- letta/orm/tool.py +0 -3
- letta/schemas/agent.py +10 -4
- letta/schemas/job.py +2 -0
- letta/schemas/letta_base.py +6 -1
- letta/schemas/letta_request.py +6 -4
- letta/schemas/llm_config.py +1 -1
- letta/schemas/message.py +2 -4
- letta/schemas/providers.py +1 -1
- letta/schemas/run.py +61 -0
- letta/schemas/tool.py +9 -17
- letta/server/rest_api/interface.py +3 -0
- letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +6 -12
- letta/server/rest_api/routers/v1/__init__.py +4 -0
- letta/server/rest_api/routers/v1/agents.py +47 -151
- letta/server/rest_api/routers/v1/runs.py +137 -0
- letta/server/rest_api/routers/v1/tags.py +27 -0
- letta/server/rest_api/utils.py +5 -3
- letta/server/server.py +139 -2
- letta/services/agent_manager.py +101 -6
- letta/services/job_manager.py +274 -9
- letta/services/tool_execution_sandbox.py +1 -1
- letta/services/tool_manager.py +30 -25
- letta/utils.py +3 -4
- {letta_nightly-0.6.9.dev20250115104021.dist-info → letta_nightly-0.6.9.dev20250116195713.dist-info}/METADATA +4 -3
- {letta_nightly-0.6.9.dev20250115104021.dist-info → letta_nightly-0.6.9.dev20250116195713.dist-info}/RECORD +44 -38
- {letta_nightly-0.6.9.dev20250115104021.dist-info → letta_nightly-0.6.9.dev20250116195713.dist-info}/LICENSE +0 -0
- {letta_nightly-0.6.9.dev20250115104021.dist-info → letta_nightly-0.6.9.dev20250116195713.dist-info}/WHEEL +0 -0
- {letta_nightly-0.6.9.dev20250115104021.dist-info → letta_nightly-0.6.9.dev20250116195713.dist-info}/entry_points.txt +0 -0
letta/functions/helpers.py
CHANGED
|
@@ -1,10 +1,15 @@
|
|
|
1
|
+
import json
|
|
1
2
|
from typing import Any, Optional, Union
|
|
2
3
|
|
|
3
4
|
import humps
|
|
4
5
|
from composio.constants import DEFAULT_ENTITY_ID
|
|
5
6
|
from pydantic import BaseModel
|
|
6
7
|
|
|
7
|
-
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY
|
|
8
|
+
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
|
9
|
+
from letta.schemas.enums import MessageRole
|
|
10
|
+
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage
|
|
11
|
+
from letta.schemas.letta_response import LettaResponse
|
|
12
|
+
from letta.schemas.message import MessageCreate
|
|
8
13
|
|
|
9
14
|
|
|
10
15
|
def generate_composio_tool_wrapper(action_name: str) -> tuple[str, str]:
|
|
@@ -206,3 +211,102 @@ def generate_import_code(module_attr_map: Optional[dict]):
|
|
|
206
211
|
code_lines.append(f" # Access the {attr} from the module")
|
|
207
212
|
code_lines.append(f" {attr} = getattr({module_name}, '{attr}')")
|
|
208
213
|
return "\n".join(code_lines)
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def parse_letta_response_for_assistant_message(
    letta_response: LettaResponse,
    assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
    assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
) -> Optional[str]:
    """Extract the first assistant-facing text from a Letta response.

    Messages are scanned in order and the first match wins:

    * an ``AssistantMessage`` yields its ``assistant_message`` text;
    * a ``ToolCallMessage`` whose tool name equals ``assistant_message_tool_name``
      yields the ``assistant_message_tool_kwarg`` entry of its JSON-encoded
      arguments. Tool calls whose arguments are not valid JSON, do not decode
      to a mapping, or lack that key are skipped.

    All other message types (including reasoning messages) are ignored.

    Args:
        letta_response: The response whose ``messages`` are scanned.
        assistant_message_tool_name: Name of the tool that carries assistant text.
        assistant_message_tool_kwarg: Argument of that tool holding the text.

    Returns:
        The assistant text, or ``None`` if no message matched.
    """
    for m in letta_response.messages:
        if isinstance(m, AssistantMessage):
            return m.assistant_message
        if isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name:
            try:
                return json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg]
            except (ValueError, KeyError, TypeError):
                # ValueError covers json.JSONDecodeError (malformed arguments).
                # KeyError: the kwarg is missing.
                # TypeError: arguments is not a str/bytes, or the JSON decoded to a
                # non-mapping (list, string, number, null) that cannot be indexed by key.
                continue
    return None
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
import asyncio
|
|
238
|
+
from random import uniform
|
|
239
|
+
from typing import Optional
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
async def async_send_message_with_retries(
    server,
    sender_agent: "Agent",
    target_agent_id: str,
    message_text: str,
    max_retries: int,
    timeout: int,
    logging_prefix: Optional[str] = None,
) -> str:
    """
    Shared helper coroutine to send a message to an agent with retries and a timeout.

    Each attempt sends ``message_text`` as a user message, named after the sender
    agent, via ``server.send_message_to_agent``, and waits at most ``timeout``
    seconds for it. Timeouts and other exceptions are logged and retried after a
    randomized exponential backoff. After the final failed attempt, the last error
    message is returned instead of raised.

    Args:
        server: The Letta server instance (from get_letta_server()).
        sender_agent (Agent): The agent initiating the send action. Its ``user`` is
            the actor, its ``agent_state.name`` names the message, and its ``logger``
            receives log lines.
        target_agent_id (str): The ID of the agent to send the message to.
        message_text (str): The text to send as the user message.
        max_retries (int): Total number of attempts to make. Values below 1 make no
            attempt and return an error message.
        timeout (int): Maximum time to wait for a response on each attempt (in seconds).
        logging_prefix (str): A prefix prepended to log lines; defaults to
            "[async_send_message_with_retries]".

    Returns:
        str: The target agent's reply, a "no response" notice, or an error message.
    """
    logging_prefix = logging_prefix or "[async_send_message_with_retries]"

    # Without this guard the loop below never runs and the coroutine falls through,
    # implicitly returning None despite the ``-> str`` contract.
    if max_retries < 1:
        error_msg = f"(No attempt made to message agent {target_agent_id}: max_retries={max_retries} must be >= 1)"
        sender_agent.logger.error(f"{logging_prefix} - {error_msg}")
        return error_msg

    for attempt in range(1, max_retries + 1):
        try:
            messages = [MessageCreate(role=MessageRole.user, text=message_text, name=sender_agent.agent_state.name)]
            # Wrap in a timeout
            response = await asyncio.wait_for(
                server.send_message_to_agent(
                    agent_id=target_agent_id,
                    actor=sender_agent.user,
                    messages=messages,
                    stream_steps=False,
                    stream_tokens=False,
                    use_assistant_message=True,
                    assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
                    assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
                ),
                timeout=timeout,
            )

            # Extract assistant message
            assistant_message = parse_letta_response_for_assistant_message(
                response,
                assistant_message_tool_name=DEFAULT_MESSAGE_TOOL,
                assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
            )
            if assistant_message:
                msg = f"Agent {target_agent_id} said '{assistant_message}'"
                sender_agent.logger.info(f"{logging_prefix} - {msg}")
                return msg
            else:
                msg = f"(No response from agent {target_agent_id})"
                sender_agent.logger.info(f"{logging_prefix} - {msg}")
                return msg
        except asyncio.TimeoutError:
            error_msg = f"(Timeout on attempt {attempt}/{max_retries} for agent {target_agent_id})"
            sender_agent.logger.warning(f"{logging_prefix} - {error_msg}")
        except Exception as e:
            error_msg = f"(Error on attempt {attempt}/{max_retries} for agent {target_agent_id}: {e})"
            sender_agent.logger.warning(f"{logging_prefix} - {error_msg}")

        # Exponential backoff before retrying; the random factor spreads out
        # simultaneous retries from several agents.
        if attempt < max_retries:
            backoff = uniform(0.5, 2) * (2**attempt)
            sender_agent.logger.warning(f"{logging_prefix} - Retrying the agent to agent send_message...sleeping for {backoff}")
            await asyncio.sleep(backoff)
        else:
            sender_agent.logger.error(f"{logging_prefix} - Fatal error during agent to agent send_message: {error_msg}")
            return error_msg
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import inspect
|
|
2
|
+
import warnings
|
|
2
3
|
from typing import Any, Dict, List, Optional, Type, Union, get_args, get_origin
|
|
3
4
|
|
|
4
5
|
from docstring_parser import parse
|
|
@@ -44,6 +45,13 @@ def type_to_json_schema_type(py_type) -> dict:
|
|
|
44
45
|
origin = get_origin(py_type)
|
|
45
46
|
if py_type == list or origin in (list, List):
|
|
46
47
|
args = get_args(py_type)
|
|
48
|
+
if len(args) == 0:
|
|
49
|
+
# is this correct
|
|
50
|
+
warnings.warn("Defaulting to string type for untyped List")
|
|
51
|
+
return {
|
|
52
|
+
"type": "array",
|
|
53
|
+
"items": {"type": "string"},
|
|
54
|
+
}
|
|
47
55
|
|
|
48
56
|
if args and inspect.isclass(args[0]) and issubclass(args[0], BaseModel):
|
|
49
57
|
# If it's a list of Pydantic models, return an array with the model schema as items
|
letta/llm_api/openai.py
CHANGED
|
@@ -307,15 +307,31 @@ def openai_chat_completions_process_stream(
|
|
|
307
307
|
warnings.warn(
|
|
308
308
|
f"Tool call index out of range ({tool_call_delta.index})\ncurrent tool calls: {accum_message.tool_calls}\ncurrent delta: {tool_call_delta}"
|
|
309
309
|
)
|
|
310
|
+
# force index 0
|
|
311
|
+
# accum_message.tool_calls[0].id = tool_call_delta.id
|
|
310
312
|
else:
|
|
311
313
|
accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id
|
|
312
314
|
if tool_call_delta.function is not None:
|
|
313
315
|
if tool_call_delta.function.name is not None:
|
|
314
316
|
# TODO assert that we're not overwriting?
|
|
315
317
|
# TODO += instead of =?
|
|
316
|
-
|
|
318
|
+
if tool_call_delta.index not in range(len(accum_message.tool_calls)):
|
|
319
|
+
warnings.warn(
|
|
320
|
+
f"Tool call index out of range ({tool_call_delta.index})\ncurrent tool calls: {accum_message.tool_calls}\ncurrent delta: {tool_call_delta}"
|
|
321
|
+
)
|
|
322
|
+
# force index 0
|
|
323
|
+
# accum_message.tool_calls[0].function.name = tool_call_delta.function.name
|
|
324
|
+
else:
|
|
325
|
+
accum_message.tool_calls[tool_call_delta.index].function.name = tool_call_delta.function.name
|
|
317
326
|
if tool_call_delta.function.arguments is not None:
|
|
318
|
-
|
|
327
|
+
if tool_call_delta.index not in range(len(accum_message.tool_calls)):
|
|
328
|
+
warnings.warn(
|
|
329
|
+
f"Tool call index out of range ({tool_call_delta.index})\ncurrent tool calls: {accum_message.tool_calls}\ncurrent delta: {tool_call_delta}"
|
|
330
|
+
)
|
|
331
|
+
# force index 0
|
|
332
|
+
# accum_message.tool_calls[0].function.arguments += tool_call_delta.function.arguments
|
|
333
|
+
else:
|
|
334
|
+
accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments
|
|
319
335
|
|
|
320
336
|
if message_delta.function_call is not None:
|
|
321
337
|
raise NotImplementedError(f"Old function_call style not support with stream=True")
|
letta/local_llm/utils.py
CHANGED
|
@@ -122,6 +122,10 @@ def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"):
|
|
|
122
122
|
for o in v["enum"]:
|
|
123
123
|
function_tokens += 3
|
|
124
124
|
function_tokens += len(encoding.encode(o))
|
|
125
|
+
elif field == "items":
|
|
126
|
+
function_tokens += 2
|
|
127
|
+
if isinstance(v["items"], dict) and "type" in v["items"]:
|
|
128
|
+
function_tokens += len(encoding.encode(v["items"]["type"]))
|
|
125
129
|
else:
|
|
126
130
|
warnings.warn(f"num_tokens_from_functions: Unsupported field {field} in function {function}")
|
|
127
131
|
function_tokens += 11
|
letta/orm/__init__.py
CHANGED
|
@@ -5,6 +5,7 @@ from letta.orm.block import Block
|
|
|
5
5
|
from letta.orm.blocks_agents import BlocksAgents
|
|
6
6
|
from letta.orm.file import FileMetadata
|
|
7
7
|
from letta.orm.job import Job
|
|
8
|
+
from letta.orm.job_messages import JobMessage
|
|
8
9
|
from letta.orm.message import Message
|
|
9
10
|
from letta.orm.organization import Organization
|
|
10
11
|
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
|
letta/orm/enums.py
CHANGED
|
@@ -5,6 +5,12 @@ class ToolType(str, Enum):
|
|
|
5
5
|
CUSTOM = "custom"
|
|
6
6
|
LETTA_CORE = "letta_core"
|
|
7
7
|
LETTA_MEMORY_CORE = "letta_memory_core"
|
|
8
|
+
LETTA_MULTI_AGENT_CORE = "letta_multi_agent_core"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class JobType(str, Enum):
    """Distinguishes plain jobs from runs.

    Both kinds are persisted through the ``Job`` model, which stores the member in
    its ``job_type`` column. Members mix in ``str``, so each compares equal to its
    raw value.
    """

    JOB = "job"  # a generic job (the Job model's default)
    RUN = "run"  # a job that represents a run
|
|
8
14
|
|
|
9
15
|
|
|
10
16
|
class ToolSourceType(str, Enum):
|
letta/orm/job.py
CHANGED
|
@@ -1,15 +1,20 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
|
-
from typing import TYPE_CHECKING, Optional
|
|
2
|
+
from typing import TYPE_CHECKING, List, Optional
|
|
3
3
|
|
|
4
4
|
from sqlalchemy import JSON, String
|
|
5
5
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
6
6
|
|
|
7
|
+
from letta.orm.enums import JobType
|
|
7
8
|
from letta.orm.mixins import UserMixin
|
|
8
9
|
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
|
9
10
|
from letta.schemas.enums import JobStatus
|
|
10
11
|
from letta.schemas.job import Job as PydanticJob
|
|
12
|
+
from letta.schemas.letta_request import LettaRequestConfig
|
|
11
13
|
|
|
12
14
|
if TYPE_CHECKING:
|
|
15
|
+
from letta.orm.job_messages import JobMessage
|
|
16
|
+
from letta.orm.job_usage_statistics import JobUsageStatistics
|
|
17
|
+
from letta.orm.message import Message
|
|
13
18
|
from letta.orm.user import User
|
|
14
19
|
|
|
15
20
|
|
|
@@ -23,7 +28,24 @@ class Job(SqlalchemyBase, UserMixin):
|
|
|
23
28
|
|
|
24
29
|
status: Mapped[JobStatus] = mapped_column(String, default=JobStatus.created, doc="The current status of the job.")
|
|
25
30
|
completed_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, doc="The unix timestamp of when the job was completed.")
|
|
26
|
-
metadata_: Mapped[Optional[dict]] = mapped_column(JSON,
|
|
31
|
+
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, doc="The metadata of the job.")
|
|
32
|
+
job_type: Mapped[JobType] = mapped_column(
|
|
33
|
+
String,
|
|
34
|
+
default=JobType.JOB,
|
|
35
|
+
doc="The type of job. This affects whether or not we generate json_schema and source_code on the fly.",
|
|
36
|
+
)
|
|
37
|
+
request_config: Mapped[Optional[LettaRequestConfig]] = mapped_column(
|
|
38
|
+
JSON, nullable=True, doc="The request configuration for the job, stored as JSON."
|
|
39
|
+
)
|
|
27
40
|
|
|
28
41
|
# relationships
|
|
29
42
|
user: Mapped["User"] = relationship("User", back_populates="jobs")
|
|
43
|
+
job_messages: Mapped[List["JobMessage"]] = relationship("JobMessage", back_populates="job", cascade="all, delete-orphan")
|
|
44
|
+
usage_statistics: Mapped[list["JobUsageStatistics"]] = relationship(
|
|
45
|
+
"JobUsageStatistics", back_populates="job", cascade="all, delete-orphan"
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
@property
def messages(self) -> List["Message"]:
    """Return the ``Message`` objects linked to this job through ``job_messages``.

    The result preserves the order of ``self.job_messages``.
    """
    return [link.message for link in self.job_messages]
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import ForeignKey, UniqueConstraint
|
|
4
|
+
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
5
|
+
|
|
6
|
+
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from letta.orm.job import Job
|
|
10
|
+
from letta.orm.message import Message
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class JobMessage(SqlalchemyBase):
    """Tracks messages that were created during job execution.

    Association row linking one job (``job_id``) to one message (``message_id``).
    Both foreign keys use ``ondelete="CASCADE"``, so the database removes the link
    when either the job or the message is deleted. A given (job, message) pair may
    appear at most once (``unique_job_message``).

    NOTE(review): ``Message.job_message`` is declared with ``uselist=False``, which
    implies at most one job per message, but the unique constraint only covers the
    (job_id, message_id) pair. If that invariant must hold, consider a unique
    constraint on ``message_id`` alone — confirm the intended cardinality.
    """

    __tablename__ = "job_messages"
    # Prevents linking the same message to the same job twice.
    __table_args__ = (UniqueConstraint("job_id", "message_id", name="unique_job_message"),)

    # Integer primary key of the link row itself.
    id: Mapped[int] = mapped_column(primary_key=True, doc="Unique identifier for the job message")
    # Owning job; deleting the job deletes its link rows at the database level.
    job_id: Mapped[str] = mapped_column(
        ForeignKey("jobs.id", ondelete="CASCADE"),
        nullable=False,  # A job message must belong to a job
        doc="ID of the job that created the message",
    )
    # Linked message; deleting the message deletes its link row at the database level.
    message_id: Mapped[str] = mapped_column(
        ForeignKey("messages.id", ondelete="CASCADE"),
        nullable=False,  # A job message must have a message
        doc="ID of the message created by the job",
    )

    # Relationships
    # back_populates names must match the reverse relationships declared on Job and Message.
    job: Mapped["Job"] = relationship("Job", back_populates="job_messages")
    message: Mapped["Message"] = relationship("Message", back_populates="job_message")
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Optional
|
|
2
|
+
|
|
3
|
+
from sqlalchemy import ForeignKey
|
|
4
|
+
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
5
|
+
|
|
6
|
+
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from letta.orm.job import Job
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class JobUsageStatistics(SqlalchemyBase):
    """Tracks usage statistics for jobs, with future support for per-step tracking.

    Each row belongs to one job (``job_id``) and records token and step counters.
    The database removes the rows when their job is deleted (``ondelete="CASCADE"``).
    """

    __tablename__ = "job_usage_statistics"

    # Integer primary key of the statistics row.
    id: Mapped[int] = mapped_column(primary_key=True, doc="Unique identifier for the usage statistics entry")
    # Owning job (required).
    job_id: Mapped[str] = mapped_column(
        ForeignKey("jobs.id", ondelete="CASCADE"), nullable=False, doc="ID of the job these statistics belong to"
    )
    # Optional free-form step identifier; not a foreign key.
    step_id: Mapped[Optional[str]] = mapped_column(
        nullable=True, doc="ID of the specific step within the job (for future per-step tracking)"
    )
    # Counters below default to 0 on insert when not supplied.
    completion_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens generated by the agent")
    prompt_tokens: Mapped[int] = mapped_column(default=0, doc="Number of tokens in the prompt")
    total_tokens: Mapped[int] = mapped_column(default=0, doc="Total number of tokens processed by the agent")
    step_count: Mapped[int] = mapped_column(default=0, doc="Number of steps taken by the agent")

    # Relationship back to the job; pairs with Job.usage_statistics via back_populates.
    job: Mapped["Job"] = relationship("Job", back_populates="usage_statistics")
|
letta/orm/message.py
CHANGED
|
@@ -28,3 +28,13 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin):
|
|
|
28
28
|
# Relationships
|
|
29
29
|
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
|
|
30
30
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="messages", lazy="selectin")
|
|
31
|
+
|
|
32
|
+
# Job relationship
|
|
33
|
+
job_message: Mapped[Optional["JobMessage"]] = relationship(
|
|
34
|
+
"JobMessage", back_populates="message", uselist=False, cascade="all, delete-orphan", single_parent=True
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
@property
def job(self) -> Optional["Job"]:
    """Return the ``Job`` linked to this message via ``job_message``, or ``None`` when unlinked."""
    link = self.job_message
    if not link:
        return None
    return link.job
|
letta/orm/sqlalchemy_base.py
CHANGED
|
@@ -1,9 +1,9 @@
|
|
|
1
1
|
from datetime import datetime
|
|
2
2
|
from enum import Enum
|
|
3
3
|
from functools import wraps
|
|
4
|
-
from typing import TYPE_CHECKING, List, Literal, Optional
|
|
4
|
+
from typing import TYPE_CHECKING, List, Literal, Optional, Tuple, Union
|
|
5
5
|
|
|
6
|
-
from sqlalchemy import String, desc, func, or_, select
|
|
6
|
+
from sqlalchemy import String, and_, desc, func, or_, select
|
|
7
7
|
from sqlalchemy.exc import DBAPIError, IntegrityError, TimeoutError
|
|
8
8
|
from sqlalchemy.orm import Mapped, Session, mapped_column
|
|
9
9
|
|
|
@@ -61,6 +61,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
61
61
|
ascending: bool = True,
|
|
62
62
|
tags: Optional[List[str]] = None,
|
|
63
63
|
match_all_tags: bool = False,
|
|
64
|
+
actor: Optional["User"] = None,
|
|
65
|
+
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
|
66
|
+
access_type: AccessType = AccessType.ORGANIZATION,
|
|
67
|
+
join_model: Optional[Base] = None,
|
|
68
|
+
join_conditions: Optional[Union[Tuple, List]] = None,
|
|
64
69
|
**kwargs,
|
|
65
70
|
) -> List["SqlalchemyBase"]:
|
|
66
71
|
"""
|
|
@@ -94,6 +99,13 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
94
99
|
|
|
95
100
|
query = select(cls)
|
|
96
101
|
|
|
102
|
+
if join_model and join_conditions:
|
|
103
|
+
query = query.join(join_model, and_(*join_conditions))
|
|
104
|
+
|
|
105
|
+
# Apply access predicate if actor is provided
|
|
106
|
+
if actor:
|
|
107
|
+
query = cls.apply_access_predicate(query, actor, access, access_type)
|
|
108
|
+
|
|
97
109
|
# Handle tag filtering if the model has tags
|
|
98
110
|
if tags and hasattr(cls, "tags"):
|
|
99
111
|
query = select(cls)
|
|
@@ -118,7 +130,15 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
118
130
|
|
|
119
131
|
# Apply filtering logic from kwargs
|
|
120
132
|
for key, value in kwargs.items():
|
|
121
|
-
|
|
133
|
+
if "." in key:
|
|
134
|
+
# Handle joined table columns
|
|
135
|
+
table_name, column_name = key.split(".")
|
|
136
|
+
joined_table = locals().get(table_name) or globals().get(table_name)
|
|
137
|
+
column = getattr(joined_table, column_name)
|
|
138
|
+
else:
|
|
139
|
+
# Handle columns from main table
|
|
140
|
+
column = getattr(cls, key)
|
|
141
|
+
|
|
122
142
|
if isinstance(value, (list, tuple, set)):
|
|
123
143
|
query = query.where(column.in_(value))
|
|
124
144
|
else:
|
|
@@ -143,7 +163,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
|
143
163
|
|
|
144
164
|
# Text search
|
|
145
165
|
if query_text:
|
|
146
|
-
|
|
166
|
+
if hasattr(cls, "text"):
|
|
167
|
+
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
|
|
168
|
+
elif hasattr(cls, "name"):
|
|
169
|
+
# Special case for Agent model - search across name
|
|
170
|
+
query = query.filter(func.lower(cls.name).contains(func.lower(query_text)))
|
|
147
171
|
|
|
148
172
|
# Embedding search (for Passages)
|
|
149
173
|
is_ordered = False
|
letta/orm/tool.py
CHANGED
|
@@ -40,9 +40,6 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
|
|
40
40
|
source_type: Mapped[ToolSourceType] = mapped_column(String, doc="The type of the source code.", default=ToolSourceType.json)
|
|
41
41
|
source_code: Mapped[Optional[str]] = mapped_column(String, doc="The source code of the function.")
|
|
42
42
|
json_schema: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="The OAI compatable JSON schema of the function.")
|
|
43
|
-
module: Mapped[Optional[str]] = mapped_column(
|
|
44
|
-
String, nullable=True, doc="the module path from which this tool was derived in the codebase."
|
|
45
|
-
)
|
|
46
43
|
|
|
47
44
|
# relationships
|
|
48
45
|
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
|
letta/schemas/agent.py
CHANGED
|
@@ -95,8 +95,8 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
|
|
95
95
|
name: str = Field(default_factory=lambda: create_random_username(), description="The name of the agent.")
|
|
96
96
|
|
|
97
97
|
# memory creation
|
|
98
|
-
memory_blocks: List[CreateBlock] = Field(
|
|
99
|
-
|
|
98
|
+
memory_blocks: Optional[List[CreateBlock]] = Field(
|
|
99
|
+
None,
|
|
100
100
|
description="The blocks to create in the agent's in-context memory.",
|
|
101
101
|
)
|
|
102
102
|
# TODO: This is a legacy field and should be removed ASAP to force `tool_ids` usage
|
|
@@ -115,7 +115,12 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
|
|
115
115
|
initial_message_sequence: Optional[List[MessageCreate]] = Field(
|
|
116
116
|
None, description="The initial set of messages to put in the agent's in-context memory."
|
|
117
117
|
)
|
|
118
|
-
include_base_tools: bool = Field(
|
|
118
|
+
include_base_tools: bool = Field(
|
|
119
|
+
True, description="If true, attaches the Letta core tools (e.g. archival_memory and core_memory related functions)."
|
|
120
|
+
)
|
|
121
|
+
include_multi_agent_tools: bool = Field(
|
|
122
|
+
False, description="If true, attaches the Letta multi-agent tools (e.g. sending a message to another agent)."
|
|
123
|
+
)
|
|
119
124
|
description: Optional[str] = Field(None, description="The description of the agent.")
|
|
120
125
|
metadata_: Optional[Dict] = Field(None, description="The metadata of the agent.", alias="metadata_")
|
|
121
126
|
llm: Optional[str] = Field(
|
|
@@ -129,7 +134,8 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
|
|
129
134
|
context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
|
|
130
135
|
embedding_chunk_size: Optional[int] = Field(DEFAULT_EMBEDDING_CHUNK_SIZE, description="The embedding chunk size used by the agent.")
|
|
131
136
|
from_template: Optional[str] = Field(None, description="The template id used to configure the agent")
|
|
132
|
-
|
|
137
|
+
template: bool = Field(False, description="Whether the agent is a template")
|
|
138
|
+
project: Optional[str] = Field(None, description="The project slug that the agent will be associated with.")
|
|
133
139
|
tool_exec_environment_variables: Optional[Dict[str, str]] = Field(
|
|
134
140
|
None, description="The environment variables for tool execution specific to this agent."
|
|
135
141
|
)
|
letta/schemas/job.py
CHANGED
|
@@ -3,6 +3,7 @@ from typing import Optional
|
|
|
3
3
|
|
|
4
4
|
from pydantic import Field
|
|
5
5
|
|
|
6
|
+
from letta.orm.enums import JobType
|
|
6
7
|
from letta.schemas.enums import JobStatus
|
|
7
8
|
from letta.schemas.letta_base import OrmMetadataBase
|
|
8
9
|
|
|
@@ -12,6 +13,7 @@ class JobBase(OrmMetadataBase):
|
|
|
12
13
|
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.")
|
|
13
14
|
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
|
|
14
15
|
metadata_: Optional[dict] = Field(None, description="The metadata of the job.")
|
|
16
|
+
job_type: JobType = Field(default=JobType.JOB, description="The type of the job.")
|
|
15
17
|
|
|
16
18
|
|
|
17
19
|
class Job(JobBase):
|
letta/schemas/letta_base.py
CHANGED
|
@@ -52,8 +52,13 @@ class LettaBase(BaseModel):
|
|
|
52
52
|
@classmethod
|
|
53
53
|
def _id_regex_pattern(cls, prefix: str):
|
|
54
54
|
"""generates the regex pattern for a given id"""
|
|
55
|
+
if cls.__name__ in ("JobBase", "Job", "Run", "RunBase"):
|
|
56
|
+
prefix_pattern = "(job|run)"
|
|
57
|
+
else:
|
|
58
|
+
prefix_pattern = prefix
|
|
59
|
+
|
|
55
60
|
return (
|
|
56
|
-
r"^" +
|
|
61
|
+
r"^" + prefix_pattern + r"-" # prefix string
|
|
57
62
|
r"[a-fA-F0-9]{8}" # 8 hexadecimal characters
|
|
58
63
|
# r"[a-fA-F0-9]{4}-" # 4 hexadecimal characters
|
|
59
64
|
# r"[a-fA-F0-9]{4}-" # 4 hexadecimal characters
|
letta/schemas/letta_request.py
CHANGED
|
@@ -6,11 +6,8 @@ from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
|
|
6
6
|
from letta.schemas.message import MessageCreate
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
class
|
|
10
|
-
messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
|
|
11
|
-
|
|
9
|
+
class LettaRequestConfig(BaseModel):
|
|
12
10
|
# Flags to support the use of AssistantMessage message types
|
|
13
|
-
|
|
14
11
|
use_assistant_message: bool = Field(
|
|
15
12
|
default=True,
|
|
16
13
|
description="Whether the server should parse specific tool call arguments (default `send_message`) as `AssistantMessage` objects.",
|
|
@@ -25,6 +22,11 @@ class LettaRequest(BaseModel):
|
|
|
25
22
|
)
|
|
26
23
|
|
|
27
24
|
|
|
25
|
+
class LettaRequest(BaseModel):
    """Messages to be sent to an agent, plus per-request configuration.

    ``config`` carries the ``LettaRequestConfig`` options (e.g. ``use_assistant_message``).
    """

    messages: List[MessageCreate] = Field(..., description="The messages to be sent to the agent.")
    # NOTE(review): the default is one LettaRequestConfig instance created at class
    # definition time. Confirm that the pydantic version in use copies it per model
    # instance. default_factory=LettaRequestConfig would make that explicit, but it
    # changes the "default" shown in the generated JSON schema.
    config: LettaRequestConfig = Field(default=LettaRequestConfig(), description="Configuration options for the LettaRequest.")
|
|
28
|
+
|
|
29
|
+
|
|
28
30
|
class LettaStreamingRequest(LettaRequest):
|
|
29
31
|
stream_tokens: bool = Field(
|
|
30
32
|
default=False,
|
letta/schemas/llm_config.py
CHANGED
|
@@ -96,7 +96,7 @@ class LLMConfig(BaseModel):
|
|
|
96
96
|
model="memgpt-openai",
|
|
97
97
|
model_endpoint_type="openai",
|
|
98
98
|
model_endpoint="https://inference.memgpt.ai",
|
|
99
|
-
context_window=
|
|
99
|
+
context_window=8192,
|
|
100
100
|
)
|
|
101
101
|
else:
|
|
102
102
|
raise ValueError(f"Model {model_name} not supported.")
|
letta/schemas/message.py
CHANGED
|
@@ -149,9 +149,9 @@ class Message(BaseMessage):
|
|
|
149
149
|
# We need to unpack the actual message contents from the function call
|
|
150
150
|
try:
|
|
151
151
|
func_args = json.loads(tool_call.function.arguments)
|
|
152
|
-
message_string = func_args[
|
|
152
|
+
message_string = func_args[assistant_message_tool_kwarg]
|
|
153
153
|
except KeyError:
|
|
154
|
-
raise ValueError(f"Function call {tool_call.function.name} missing {
|
|
154
|
+
raise ValueError(f"Function call {tool_call.function.name} missing {assistant_message_tool_kwarg} argument")
|
|
155
155
|
messages.append(
|
|
156
156
|
AssistantMessage(
|
|
157
157
|
id=self.id,
|
|
@@ -708,8 +708,6 @@ class Message(BaseMessage):
|
|
|
708
708
|
},
|
|
709
709
|
]
|
|
710
710
|
for tc in self.tool_calls:
|
|
711
|
-
# TODO better way to pack?
|
|
712
|
-
# function_call_text = json.dumps(tc.to_dict())
|
|
713
711
|
function_name = tc.function["name"]
|
|
714
712
|
function_args = json.loads(tc.function["arguments"])
|
|
715
713
|
function_args_str = ",".join([f"{k}={v}" for k, v in function_args.items()])
|
letta/schemas/providers.py
CHANGED
letta/schemas/run.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from typing import Optional
|
|
2
|
+
|
|
3
|
+
from pydantic import Field
|
|
4
|
+
|
|
5
|
+
from letta.orm.enums import JobType
|
|
6
|
+
from letta.schemas.job import Job, JobBase
|
|
7
|
+
from letta.schemas.letta_request import LettaRequestConfig
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class RunBase(JobBase):
    """Base class for Run schemas that inherits from JobBase but uses 'run' prefix for IDs.

    Also overrides ``job_type`` so that runs default to ``JobType.RUN`` instead of
    ``JobBase``'s ``JobType.JOB`` default.
    """

    # Presumably read by the LettaBase ID helpers (e.g. generate_id_field) to build
    # 'run-...' IDs — confirm against LettaBase.
    __id_prefix__ = "run"
    # NOTE(review): redeclared as a plain default rather than Field(...), so JobBase's
    # description for job_type is likely absent from this model's schema — confirm.
    job_type: JobType = JobType.RUN
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Run(RunBase):
    """
    Representation of a run: a job whose ``job_type`` defaults to ``JobType.RUN`` and
    whose newly generated IDs use the 'run' prefix.

    Parameters:
        id (str): The unique identifier of the run.
        status (JobStatus): The status of the run.
        created_at (datetime): The unix timestamp of when the run was created.
        completed_at (datetime): The unix timestamp of when the run was completed.
        user_id (str): The unique identifier of the user associated with the run.
        request_config (LettaRequestConfig): The request configuration for the run.
    """

    id: str = RunBase.generate_id_field()
    user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the run.")
    request_config: Optional[LettaRequestConfig] = Field(None, description="The request configuration for the run.")

    @classmethod
    def from_job(cls, job: Job) -> "Run":
        """
        Build a Run from a Job by copying every field of ``job`` that is not None.

        Nothing is renamed or rewritten. In particular, ``id`` keeps whatever prefix
        the job's id has, and ``job_type`` is taken from the job rather than the
        Run default.

        Args:
            job: The Job instance to convert.

        Returns:
            A new Run carrying the job's non-None field values.
        """
        return cls(**job.model_dump(exclude_none=True))

    def to_job(self) -> Job:
        """
        Build a Job from this Run by passing every non-None field to ``Job``.

        The ``id`` is passed through unchanged; no prefix is rewritten.
        NOTE(review): ``request_config`` is included when set, and Job does not
        visibly declare that field — confirm how Job treats extra fields.

        Returns:
            A new Job built from this run's non-None field values.
        """
        return Job(**self.model_dump(exclude_none=True))
|