letta-nightly 0.8.15.dev20250719104256__py3-none-any.whl → 0.8.16.dev20250721070720__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +1 -1
- letta/agent.py +27 -11
- letta/agents/helpers.py +1 -1
- letta/agents/letta_agent.py +518 -322
- letta/agents/letta_agent_batch.py +1 -2
- letta/agents/voice_agent.py +15 -17
- letta/client/client.py +3 -3
- letta/constants.py +5 -0
- letta/embeddings.py +0 -2
- letta/errors.py +8 -0
- letta/functions/function_sets/base.py +3 -3
- letta/functions/helpers.py +2 -3
- letta/groups/sleeptime_multi_agent.py +0 -1
- letta/helpers/composio_helpers.py +2 -2
- letta/helpers/converters.py +1 -1
- letta/helpers/pinecone_utils.py +8 -0
- letta/helpers/tool_rule_solver.py +13 -18
- letta/llm_api/aws_bedrock.py +16 -2
- letta/llm_api/cohere.py +1 -1
- letta/llm_api/openai_client.py +1 -1
- letta/local_llm/grammars/gbnf_grammar_generator.py +1 -1
- letta/local_llm/llm_chat_completion_wrappers/zephyr.py +14 -14
- letta/local_llm/utils.py +1 -2
- letta/orm/agent.py +3 -3
- letta/orm/block.py +4 -4
- letta/orm/files_agents.py +0 -1
- letta/orm/identity.py +2 -0
- letta/orm/mcp_server.py +0 -2
- letta/orm/message.py +140 -14
- letta/orm/organization.py +5 -5
- letta/orm/passage.py +4 -4
- letta/orm/source.py +1 -1
- letta/orm/sqlalchemy_base.py +61 -39
- letta/orm/step.py +2 -0
- letta/otel/db_pool_monitoring.py +308 -0
- letta/otel/metric_registry.py +94 -1
- letta/otel/sqlalchemy_instrumentation.py +548 -0
- letta/otel/sqlalchemy_instrumentation_integration.py +124 -0
- letta/otel/tracing.py +37 -1
- letta/schemas/agent.py +0 -3
- letta/schemas/agent_file.py +283 -0
- letta/schemas/block.py +0 -3
- letta/schemas/file.py +28 -26
- letta/schemas/letta_message.py +15 -4
- letta/schemas/memory.py +1 -1
- letta/schemas/message.py +31 -26
- letta/schemas/openai/chat_completion_response.py +0 -1
- letta/schemas/providers.py +20 -0
- letta/schemas/source.py +11 -13
- letta/schemas/step.py +12 -0
- letta/schemas/tool.py +0 -4
- letta/serialize_schemas/marshmallow_agent.py +14 -1
- letta/serialize_schemas/marshmallow_block.py +23 -1
- letta/serialize_schemas/marshmallow_message.py +1 -3
- letta/serialize_schemas/marshmallow_tool.py +23 -1
- letta/server/db.py +110 -6
- letta/server/rest_api/app.py +85 -73
- letta/server/rest_api/routers/v1/agents.py +68 -53
- letta/server/rest_api/routers/v1/blocks.py +2 -2
- letta/server/rest_api/routers/v1/jobs.py +3 -0
- letta/server/rest_api/routers/v1/organizations.py +2 -2
- letta/server/rest_api/routers/v1/sources.py +18 -2
- letta/server/rest_api/routers/v1/tools.py +11 -12
- letta/server/rest_api/routers/v1/users.py +1 -1
- letta/server/rest_api/streaming_response.py +13 -5
- letta/server/rest_api/utils.py +8 -25
- letta/server/server.py +11 -4
- letta/server/ws_api/server.py +2 -2
- letta/services/agent_file_manager.py +616 -0
- letta/services/agent_manager.py +133 -46
- letta/services/block_manager.py +38 -17
- letta/services/file_manager.py +106 -21
- letta/services/file_processor/file_processor.py +93 -0
- letta/services/files_agents_manager.py +28 -0
- letta/services/group_manager.py +4 -5
- letta/services/helpers/agent_manager_helper.py +57 -9
- letta/services/identity_manager.py +22 -0
- letta/services/job_manager.py +210 -91
- letta/services/llm_batch_manager.py +9 -6
- letta/services/mcp/stdio_client.py +1 -2
- letta/services/mcp_manager.py +0 -1
- letta/services/message_manager.py +49 -26
- letta/services/passage_manager.py +0 -1
- letta/services/provider_manager.py +1 -1
- letta/services/source_manager.py +114 -5
- letta/services/step_manager.py +36 -4
- letta/services/telemetry_manager.py +9 -2
- letta/services/tool_executor/builtin_tool_executor.py +5 -1
- letta/services/tool_executor/core_tool_executor.py +3 -3
- letta/services/tool_manager.py +95 -20
- letta/services/user_manager.py +4 -12
- letta/settings.py +23 -6
- letta/system.py +1 -1
- letta/utils.py +26 -2
- {letta_nightly-0.8.15.dev20250719104256.dist-info → letta_nightly-0.8.16.dev20250721070720.dist-info}/METADATA +3 -2
- {letta_nightly-0.8.15.dev20250719104256.dist-info → letta_nightly-0.8.16.dev20250721070720.dist-info}/RECORD +99 -94
- {letta_nightly-0.8.15.dev20250719104256.dist-info → letta_nightly-0.8.16.dev20250721070720.dist-info}/LICENSE +0 -0
- {letta_nightly-0.8.15.dev20250719104256.dist-info → letta_nightly-0.8.16.dev20250721070720.dist-info}/WHEEL +0 -0
- {letta_nightly-0.8.15.dev20250719104256.dist-info → letta_nightly-0.8.16.dev20250721070720.dist-info}/entry_points.txt +0 -0
letta/schemas/source.py
CHANGED
@@ -9,11 +9,17 @@ from letta.schemas.letta_base import LettaBase
 
 class BaseSource(LettaBase):
     """
-    Shared attributes
+    Shared attributes across all source schemas.
     """
 
     __id_prefix__ = "source"
 
+    # Core source fields
+    name: str = Field(..., description="The name of the source.")
+    description: Optional[str] = Field(None, description="The description of the source.")
+    instructions: Optional[str] = Field(None, description="Instructions for how to use the source.")
+    metadata: Optional[dict] = Field(None, description="Metadata associated with the source.")
+
 
 class Source(BaseSource):
     """
@@ -29,9 +35,6 @@ class Source(BaseSource):
     """
 
     id: str = BaseSource.generate_id_field()
-    name: str = Field(..., description="The name of the source.")
-    description: Optional[str] = Field(None, description="The description of the source.")
-    instructions: Optional[str] = Field(None, description="Instructions for how to use the source.")
     embedding_config: EmbeddingConfig = Field(..., description="The embedding configuration used by the source.")
     organization_id: Optional[str] = Field(None, description="The ID of the organization that created the source.")
     metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="Metadata associated with the source.")
@@ -48,29 +51,24 @@ class SourceCreate(BaseSource):
     Schema for creating a new Source.
     """
 
-    # required
-    name: str = Field(..., description="The name of the source.")
     # TODO: @matt, make this required after shub makes the FE changes
-
-    embedding: Optional[str] = Field(None, description="The hande for the embedding config used by the source.")
+    embedding: Optional[str] = Field(None, description="The handle for the embedding config used by the source.")
     embedding_chunk_size: Optional[int] = Field(None, description="The chunk size of the embedding.")
 
     # TODO: remove (legacy config)
     embedding_config: Optional[EmbeddingConfig] = Field(None, description="(Legacy) The embedding configuration used by the source.")
 
-    # optional
-    description: Optional[str] = Field(None, description="The description of the source.")
-    instructions: Optional[str] = Field(None, description="Instructions for how to use the source.")
-    metadata: Optional[dict] = Field(None, description="Metadata associated with the source.")
-
 
 class SourceUpdate(BaseSource):
     """
     Schema for updating an existing Source.
     """
 
+    # Override base fields to make them optional for updates
     name: Optional[str] = Field(None, description="The name of the source.")
     description: Optional[str] = Field(None, description="The description of the source.")
     instructions: Optional[str] = Field(None, description="Instructions for how to use the source.")
     metadata: Optional[dict] = Field(None, description="Metadata associated with the source.")
+
+    # Additional update-specific fields
     embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the source.")
letta/schemas/step.py
CHANGED
@@ -1,8 +1,10 @@
+from enum import Enum, auto
 from typing import Dict, List, Literal, Optional
 
 from pydantic import Field
 
 from letta.schemas.letta_base import LettaBase
+from letta.schemas.letta_stop_reason import StopReasonType
 from letta.schemas.message import Message
 
 
@@ -28,6 +30,7 @@ class Step(StepBase):
     prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.")
     total_tokens: Optional[int] = Field(None, description="The total number of tokens processed by the agent during this step.")
     completion_tokens_details: Optional[Dict] = Field(None, description="Metadata for the agent.")
+    stop_reason: Optional[StopReasonType] = Field(None, description="The stop reason associated with the step.")
     tags: List[str] = Field([], description="Metadata tags.")
     tid: Optional[str] = Field(None, description="The unique identifier of the transaction that processed this step.")
     trace_id: Optional[str] = Field(None, description="The trace id of the agent step.")
@@ -36,3 +39,12 @@ class Step(StepBase):
         None, description="The feedback for this step. Must be either 'positive' or 'negative'."
     )
     project_id: Optional[str] = Field(None, description="The project that the agent that executed this step belongs to (cloud only).")
+
+
+class StepProgression(int, Enum):
+    START = auto()
+    STREAM_RECEIVED = auto()
+    RESPONSE_RECEIVED = auto()
+    STEP_LOGGED = auto()
+    LOGGED_TRACE = auto()
+    FINISHED = auto()
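The new StepProgression enum is an int-backed Enum built with auto(), so its members form an ordered sequence of lifecycle stages. An illustrative sketch (not taken from the package) of how such an enum can record how far a step advanced before stopping:

from letta.schemas.step import StepProgression

progression = StepProgression.START
# ... request dispatched, streaming begins ...
progression = StepProgression.STREAM_RECEIVED

# auto() assigns 1..6 in declaration order, and the int base class makes
# members comparable, so stages can be checked numerically.
if progression < StepProgression.STEP_LOGGED:
    print(f"step ended early at stage {progression.name}")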
letta/schemas/tool.py
CHANGED
@@ -50,7 +50,6 @@ class Tool(BaseTool):
     tool_type: ToolType = Field(ToolType.CUSTOM, description="The type of the tool.")
     description: Optional[str] = Field(None, description="The description of the tool.")
     source_type: Optional[str] = Field(None, description="The type of the source code.")
-    organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the tool.")
     name: Optional[str] = Field(None, description="The name of the function.")
     tags: List[str] = Field([], description="Metadata tags.")
 
@@ -147,9 +146,6 @@ class ToolCreate(LettaBase):
     return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
     pip_requirements: Optional[List[PipRequirement]] = Field(None, description="Optional list of pip packages required by this tool.")
 
-    # TODO should we put the HTTP / API fetch inside from_mcp?
-    # async def from_mcp(cls, mcp_server: str, mcp_tool_name: str) -> "ToolCreate":
-
     @classmethod
     def from_mcp(cls, mcp_server_name: str, mcp_tool: MCPTool) -> "ToolCreate":
         from letta.functions.helpers import generate_mcp_tool_wrapper
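The commented-out async from_mcp placeholder is removed, leaving the synchronous classmethod ToolCreate.from_mcp(mcp_server_name, mcp_tool). A hedged usage sketch follows; the MCPTool import path and constructor fields are assumptions based on the MCP tool schema (name, description, inputSchema) and may differ from the package's actual layout:

from letta.functions.mcp_client.types import MCPTool  # assumed import path
from letta.schemas.tool import ToolCreate

# Hypothetical MCP tool definition; field names follow the MCP tool schema.
mcp_tool = MCPTool(
    name="search_docs",
    description="Search the documentation index.",
    inputSchema={"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]},
)

tool_create = ToolCreate.from_mcp(mcp_server_name="docs-server", mcp_tool=mcp_tool)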
letta/serialize_schemas/marshmallow_agent.py
CHANGED
@@ -86,7 +86,9 @@ class MarshmallowAgentSchema(BaseSchema):
         - Marks messages as in-context, preserving the order of the original `message_ids`
         - Removes individual message `id` fields
         """
-
+        del data["id"]
+        del data["_created_by_id"]
+        del data["_last_updated_by_id"]
         data[self.FIELD_VERSION] = letta.__version__
 
         original_message_ids = data.pop(self.FIELD_MESSAGE_IDS, [])
@@ -107,6 +109,15 @@ class MarshmallowAgentSchema(BaseSchema):
 
         return data
 
+    @pre_load
+    def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
+        if self.Meta.model:
+            data["id"] = self.generate_id()
+            data["_created_by_id"] = self.actor.id
+            data["_last_updated_by_id"] = self.actor.id
+
+        return data
+
     @post_dump
     def hide_tool_exec_environment_variables(self, data: Dict, **kwargs):
         """Hide the value of tool_exec_environment_variables"""
@@ -135,4 +146,6 @@ class MarshmallowAgentSchema(BaseSchema):
             "identities",
             "is_deleted",
             "groups",
+            "batch_items",
+            "organization",
         )
letta/serialize_schemas/marshmallow_block.py
CHANGED
@@ -1,3 +1,7 @@
+from typing import Dict
+
+from marshmallow import post_dump, pre_load
+
 from letta.orm.block import Block
 from letta.schemas.block import Block as PydanticBlock
 from letta.serialize_schemas.marshmallow_base import BaseSchema
@@ -10,6 +14,24 @@ class SerializedBlockSchema(BaseSchema):
 
     __pydantic_model__ = PydanticBlock
 
+    @post_dump
+    def sanitize_ids(self, data: Dict, **kwargs) -> Dict:
+        # delete id
+        del data["id"]
+        del data["_created_by_id"]
+        del data["_last_updated_by_id"]
+
+        return data
+
+    @pre_load
+    def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
+        if self.Meta.model:
+            data["id"] = self.generate_id()
+            data["_created_by_id"] = self.actor.id
+            data["_last_updated_by_id"] = self.actor.id
+
+        return data
+
     class Meta(BaseSchema.Meta):
         model = Block
-        exclude = BaseSchema.Meta.exclude + ("agents", "identities", "is_deleted")
+        exclude = BaseSchema.Meta.exclude + ("agents", "identities", "is_deleted", "groups", "organization")
letta/serialize_schemas/marshmallow_message.py
CHANGED
@@ -23,7 +23,6 @@ class SerializedMessageSchema(BaseSchema):
         # agent dump will then get rid of message ids
         del data["_created_by_id"]
         del data["_last_updated_by_id"]
-        del data["organization"]
 
         return data
 
@@ -33,10 +32,9 @@ class SerializedMessageSchema(BaseSchema):
         # Skip regenerating ID, as agent dump will do it
         data["_created_by_id"] = self.actor.id
         data["_last_updated_by_id"] = self.actor.id
-        data["organization"] = self.actor.organization_id
 
         return data
 
     class Meta(BaseSchema.Meta):
         model = Message
-        exclude = BaseSchema.Meta.exclude + ("step", "job_message", "otid", "is_deleted")
+        exclude = BaseSchema.Meta.exclude + ("step", "job_message", "otid", "is_deleted", "organization")
letta/serialize_schemas/marshmallow_tool.py
CHANGED
@@ -1,3 +1,7 @@
+from typing import Dict
+
+from marshmallow import post_dump, pre_load
+
 from letta.orm import Tool
 from letta.schemas.tool import Tool as PydanticTool
 from letta.serialize_schemas.marshmallow_base import BaseSchema
@@ -10,6 +14,24 @@ class SerializedToolSchema(BaseSchema):
 
     __pydantic_model__ = PydanticTool
 
+    @post_dump
+    def sanitize_ids(self, data: Dict, **kwargs) -> Dict:
+        # delete id
+        del data["id"]
+        del data["_created_by_id"]
+        del data["_last_updated_by_id"]
+
+        return data
+
+    @pre_load
+    def regenerate_ids(self, data: Dict, **kwargs) -> Dict:
+        if self.Meta.model:
+            data["id"] = self.generate_id()
+            data["_created_by_id"] = self.actor.id
+            data["_last_updated_by_id"] = self.actor.id
+
+        return data
+
     class Meta(BaseSchema.Meta):
         model = Tool
-        exclude = BaseSchema.Meta.exclude + ("is_deleted",)
+        exclude = BaseSchema.Meta.exclude + ("is_deleted", "organization")
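The block, message, and tool serializers now share the same pair of hooks: a @post_dump step that strips database identifiers from exports, and a @pre_load step that regenerates them on import, attributed to the importing actor. A self-contained sketch of this pattern (letta's BaseSchema, actor, and generate_id helpers are stubbed out here, so this is not the package's actual code):

import uuid
from marshmallow import Schema, fields, post_dump, pre_load

class ExampleSerializedSchema(Schema):
    id = fields.Str()
    name = fields.Str()

    @post_dump
    def sanitize_ids(self, data, **kwargs):
        # Drop environment-specific identifiers so exports are portable.
        data.pop("id", None)
        return data

    @pre_load
    def regenerate_ids(self, data, **kwargs):
        # Re-create the identifier on import (letta attributes it to the importing actor).
        data["id"] = f"example-{uuid.uuid4()}"
        return data

schema = ExampleSerializedSchema()
exported = schema.dump({"id": "example-123", "name": "demo"})  # -> {"name": "demo"}
imported = schema.load(exported)                               # -> fresh "id" generated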
letta/server/db.py
CHANGED
@@ -1,20 +1,22 @@
 import os
 import threading
+import time
 import uuid
 from contextlib import asynccontextmanager, contextmanager
 from typing import Any, AsyncGenerator, Generator
 
+from opentelemetry import trace
 from rich.console import Console
 from rich.panel import Panel
 from rich.text import Text
-from sqlalchemy import Engine, NullPool, QueuePool, create_engine
+from sqlalchemy import Engine, NullPool, QueuePool, create_engine, event
 from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
 from sqlalchemy.orm import sessionmaker
 
 from letta.config import LettaConfig
 from letta.log import get_logger
 from letta.otel.tracing import trace_method
-from letta.settings import settings
+from letta.settings import DatabaseChoice, settings
 
 logger = get_logger(__name__)
 
@@ -36,6 +38,46 @@ def print_sqlite_schema_error():
     console.print(Panel(error_text, border_style="red"))
 
 
+@event.listens_for(Engine, "connect")
+def enable_sqlite_foreign_keys(dbapi_connection, connection_record):
+    """Enable foreign key constraints for SQLite connections."""
+    if "sqlite" in str(dbapi_connection):
+        cursor = dbapi_connection.cursor()
+        cursor.execute("PRAGMA foreign_keys=ON")
+        cursor.close()
+
+
+def on_connect(dbapi_connection, connection_record):
+    cursor = dbapi_connection.cursor()
+    cursor.execute("SELECT pg_backend_pid()")
+    pid = cursor.fetchone()[0]
+    connection_record.info["pid"] = pid
+    connection_record.info["connect_spawn_time_ms"] = time.perf_counter() * 1000
+    cursor.close()
+
+
+def on_close(dbapi_connection, connection_record):
+    connection_record.info.get("pid")
+    (time.perf_counter() * 1000) - connection_record.info.get("connect_spawn_time_ms")
+    # print(f"Connection closed: {pid}, duration: {duration:.6f}s")
+
+
+def on_checkout(dbapi_connection, connection_record, connection_proxy):
+    connection_record.info.get("pid")
+    connection_record.info["connect_checkout_time_ms"] = time.perf_counter() * 1000
+
+
+def on_checkin(dbapi_connection, connection_record):
+    pid = connection_record.info.get("pid")
+    duration = (time.perf_counter() * 1000) - connection_record.info.get("connect_checkout_time_ms")
+
+    tracer = trace.get_tracer("letta.db.connection")
+    with tracer.start_as_current_span("connect_release") as span:
+        span.set_attribute("db.connection.pid", pid)
+        span.set_attribute("db.connection.duration_ms", duration)
+        span.set_attribute("db.connection.operation", "checkin")
+
+
 @contextmanager
 def db_error_handler():
     """Context manager for handling database errors"""
@@ -43,6 +85,14 @@ def db_error_handler():
         yield
     except Exception as e:
         # Handle other SQLAlchemy errors
+        error_str = str(e)
+
+        # Don't exit for expected constraint violations that should be handled by the application
+        if "UNIQUE constraint failed" in error_str or "FOREIGN KEY constraint failed" in error_str:
+            # These are application-level errors that should be handled by the ORM
+            raise
+
+        # For other database errors, print error and exit
         print(e)
         print_sqlite_schema_error()
         # raise ValueError(f"SQLite DB error: {str(e)}")
@@ -73,7 +123,7 @@ class DatabaseRegistry:
             return
 
         # Postgres engine
-        if settings.
+        if settings.database_engine is DatabaseChoice.POSTGRES:
             self.logger.info("Creating postgres engine")
             self.config.recall_storage_type = "postgres"
             self.config.recall_storage_uri = settings.letta_pg_uri_no_default
@@ -99,6 +149,15 @@ class DatabaseRegistry:
             Base.metadata.create_all(bind=engine)
             self._engines["default"] = engine
 
+        # Set up connection monitoring
+        if settings.sqlalchemy_tracing and settings.database_engine is DatabaseChoice.POSTGRES:
+            event.listen(engine, "connect", on_connect)
+            event.listen(engine, "close", on_close)
+            event.listen(engine, "checkout", on_checkout)
+            event.listen(engine, "checkin", on_checkin)
+
+        self._setup_pool_monitoring(engine, "default")
+
         # Create session factory
         self._session_factories["default"] = sessionmaker(autocommit=False, autoflush=False, bind=self._engines["default"])
         self._initialized["sync"] = True
@@ -109,7 +168,7 @@ class DatabaseRegistry:
         if self._initialized.get("async") and not force:
             return
 
-        if settings.
+        if settings.database_engine is DatabaseChoice.POSTGRES:
             self.logger.info("Creating async postgres engine")
 
             # Create async engine - convert URI to async format
@@ -128,10 +187,27 @@ class DatabaseRegistry:
             self.logger.info("Creating sqlite engine " + engine_path)
             async_engine = create_async_engine(engine_path, **self._build_sqlalchemy_engine_args(is_async=True))
 
+            # Enable foreign keys for SQLite async connections
+            @event.listens_for(async_engine.sync_engine, "connect")
+            def enable_sqlite_foreign_keys_async(dbapi_connection, connection_record):
+                cursor = dbapi_connection.cursor()
+                cursor.execute("PRAGMA foreign_keys=ON")
+                cursor.close()
+
         # Create async session factory
         self._async_engines["default"] = async_engine
+
+        # Set up connection monitoring for async engine
+        if settings.sqlalchemy_tracing and settings.database_engine is DatabaseChoice.POSTGRES:
+            event.listen(async_engine.sync_engine, "connect", on_connect)
+            event.listen(async_engine.sync_engine, "close", on_close)
+            event.listen(async_engine.sync_engine, "checkout", on_checkout)
+            event.listen(async_engine.sync_engine, "checkin", on_checkin)
+
+        self._setup_pool_monitoring(async_engine, "default_async")
+
         self._async_session_factories["default"] = async_sessionmaker(
-            expire_on_commit=
+            expire_on_commit=False,
             close_resets_only=False,
             autocommit=False,
             autoflush=False,
@@ -149,7 +225,10 @@ class DatabaseRegistry:
             pool_cls = NullPool
         else:
             logger.info("Enabling pooling on SqlAlchemy")
-
+            # AsyncAdaptedQueuePool will be the default if none is provided for async but setting this explicitly.
+            from sqlalchemy import AsyncAdaptedQueuePool
+
+            pool_cls = QueuePool if not is_async else AsyncAdaptedQueuePool
 
         base_args = {
             "echo": settings.pg_echo,
@@ -207,11 +286,31 @@ class DatabaseRegistry:
 
         engine.connect = wrapped_connect
 
+    def _setup_pool_monitoring(self, engine: Engine | AsyncEngine, engine_name: str) -> None:
+        """Set up database pool monitoring for the given engine."""
+        if not settings.enable_db_pool_monitoring:
+            return
+
+        try:
+            from letta.otel.db_pool_monitoring import setup_pool_monitoring
+
+            setup_pool_monitoring(engine, engine_name)
+            self.logger.info(f"Database pool monitoring enabled for {engine_name}")
+        except ImportError:
+            self.logger.warning("Database pool monitoring not available - missing dependencies")
+        except Exception as e:
+            self.logger.warning(f"Failed to setup pool monitoring for {engine_name}: {e}")
+
     def get_engine(self, name: str = "default") -> Engine:
         """Get a database engine by name."""
         self.initialize_sync()
         return self._engines.get(name)
 
+    def get_async_engine(self, name: str = "default") -> Engine:
+        """Get a database engine by name."""
+        self.initialize_async()
+        return self._async_engines.get(name)
+
     def get_session_factory(self, name: str = "default") -> sessionmaker:
         """Get a session factory by name."""
         self.initialize_sync()
@@ -286,6 +385,11 @@ class DatabaseRegistry:
 db_registry = DatabaseRegistry()
 
 
+def get_db_registry() -> DatabaseRegistry:
+    """Get the global database registry instance."""
+    return db_registry
+
+
 def get_db():
     """Get a database session."""
     with db_registry.session() as session: