dao-ai 0.0.20__py3-none-any.whl → 0.0.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/config.py +114 -17
- dao_ai/graph.py +3 -0
- dao_ai/memory/core.py +10 -6
- dao_ai/memory/postgres.py +102 -33
- dao_ai/models.py +137 -2
- dao_ai/providers/databricks.py +282 -0
- dao_ai/state.py +3 -0
- dao_ai/tools/genie.py +346 -34
- dao_ai/utils.py +4 -0
- {dao_ai-0.0.20.dist-info → dao_ai-0.0.22.dist-info}/METADATA +3 -3
- {dao_ai-0.0.20.dist-info → dao_ai-0.0.22.dist-info}/RECORD +14 -14
- {dao_ai-0.0.20.dist-info → dao_ai-0.0.22.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.20.dist-info → dao_ai-0.0.22.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.20.dist-info → dao_ai-0.0.22.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
@@ -23,6 +23,7 @@ from databricks.sdk.credentials_provider import (
     ModelServingUserCredentials,
 )
 from databricks.sdk.service.catalog import FunctionInfo, TableInfo
+from databricks.sdk.service.database import DatabaseInstance
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
@@ -427,7 +428,8 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksGenieSpace(
-                genie_space_id=self.space_id,
+                genie_space_id=value_of(self.space_id),
+                on_behalf_of_user=self.on_behalf_of_user,
             )
         ]
@@ -437,7 +439,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
         return self


-class VolumeModel(BaseModel, HasFullName):
+class VolumeModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -455,6 +457,13 @@ class VolumeModel(BaseModel, HasFullName):
         provider: ServiceProvider = DatabricksProvider(w=w)
         provider.create_volume(self)

+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["files.files", "catalog.volumes"]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return []
+

 class VolumePathModel(BaseModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -683,7 +692,8 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksSQLWarehouse(
-                warehouse_id=self.warehouse_id,
+                warehouse_id=value_of(self.warehouse_id),
+                on_behalf_of_user=self.on_behalf_of_user,
             )
         ]
@@ -694,15 +704,18 @@ class WarehouseModel(BaseModel, IsDatabricksResource):


 class DatabaseModel(BaseModel, IsDatabricksResource):
-    model_config = ConfigDict(
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
+    instance_name: Optional[str] = None
     description: Optional[str] = None
-    host: Optional[AnyVariable]
+    host: Optional[AnyVariable] = None
     database: Optional[AnyVariable] = "databricks_postgres"
     port: Optional[AnyVariable] = 5432
     connection_kwargs: Optional[dict[str, Any]] = Field(default_factory=dict)
     max_pool_size: Optional[int] = 10
-    timeout_seconds: Optional[int] =
+    timeout_seconds: Optional[int] = 10
+    capacity: Optional[Literal["CU_1", "CU_2"]] = "CU_2"
+    node_count: Optional[int] = None
     user: Optional[AnyVariable] = None
     password: Optional[AnyVariable] = None
     client_id: Optional[AnyVariable] = None
@@ -716,11 +729,44 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksLakebase(
-                database_instance_name=self.
+                database_instance_name=self.instance_name,
                 on_behalf_of_user=self.on_behalf_of_user,
             )
         ]

+    @model_validator(mode="after")
+    def update_instance_name(self):
+        if self.instance_name is None:
+            self.instance_name = self.name
+
+        return self
+
+    @model_validator(mode="after")
+    def update_user(self):
+        if self.client_id or self.user:
+            return self
+
+        self.user = self.workspace_client.current_user.me().user_name
+        if not self.user:
+            raise ValueError(
+                "Unable to determine current user. Please provide a user name or OAuth credentials."
+            )
+
+        return self
+
+    @model_validator(mode="after")
+    def update_host(self):
+        if self.host is not None:
+            return self
+
+        existing_instance: DatabaseInstance = (
+            self.workspace_client.database.get_database_instance(
+                name=self.instance_name
+            )
+        )
+        self.host = existing_instance.read_write_dns
+        return self
+
     @model_validator(mode="after")
     def validate_auth_methods(self):
         oauth_fields: Sequence[Any] = [
@@ -730,7 +776,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         ]
         has_oauth: bool = all(field is not None for field in oauth_fields)

-        pat_fields: Sequence[Any] = [self.user
+        pat_fields: Sequence[Any] = [self.user]
         has_user_auth: bool = all(field is not None for field in pat_fields)

         if has_oauth and has_user_auth:
@@ -749,7 +795,14 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         return self

     @property
-    def
+    def connection_params(self) -> dict[str, Any]:
+        """
+        Get database connection parameters as a dictionary.
+
+        Returns a dict with connection parameters suitable for psycopg ConnectionPool.
+        If username is configured, it will be included; otherwise it will be omitted
+        to allow Lakebase to authenticate using the token's identity.
+        """
         from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider

@@ -757,7 +810,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):

         if self.client_id and self.client_secret and self.workspace_host:
             username = value_of(self.client_id)
-
+        elif self.user:
             username = value_of(self.user)

         host: str = value_of(self.host)
@@ -770,11 +823,48 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
                 workspace_host=value_of(self.workspace_host),
                 pat=value_of(self.password),
             )
-            token: str = provider.create_token()

-
-
-
+        token: str = provider.lakebase_password_provider(self.instance_name)
+
+        # Build connection parameters dictionary
+        params: dict[str, Any] = {
+            "dbname": database,
+            "host": host,
+            "port": port,
+            "password": token,
+            "sslmode": "require",
+        }
+
+        # Only include user if explicitly configured
+        if username:
+            params["user"] = username
+            logger.debug(
+                f"Connection params: dbname={database} user={username} host={host} port={port} password=******** sslmode=require"
+            )
+        else:
+            logger.debug(
+                f"Connection params: dbname={database} host={host} port={port} password=******** sslmode=require (using token identity)"
+            )
+
+        return params
+
+    @property
+    def connection_url(self) -> str:
+        """
+        Get database connection URL as a string (for backwards compatibility).
+
+        Note: It's recommended to use connection_params instead for better flexibility.
+        """
+        params = self.connection_params
+        parts = [f"{k}={v}" for k, v in params.items()]
+        return " ".join(parts)
+
+    def create(self, w: WorkspaceClient | None = None) -> None:
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: DatabricksProvider = DatabricksProvider()
+        provider.create_lakebase(self)
+        provider.create_lakebase_instance_role(self)


 class SearchParametersModel(BaseModel):
@@ -867,7 +957,7 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):

 class FactoryFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    args: Optional[dict[str,
+    args: Optional[dict[str, Any]] = Field(default_factory=dict)
     type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY

     @property
@@ -879,6 +969,12 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):

         return [create_factory_tool(self, **kwargs)]

+    @model_validator(mode="after")
+    def update_args(self):
+        for key, value in self.args.items():
+            self.args[key] = value_of(value)
+        return self
+

 class TransportType(str, Enum):
     STREAMABLE_HTTP = "streamable_http"
@@ -1078,6 +1174,7 @@ class AgentModel(BaseModel):
 class SupervisorModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
+    tools: list[ToolModel] = Field(default_factory=list)
     prompt: Optional[str] = None


@@ -1262,12 +1359,12 @@ class AppModel(BaseModel):
         if len(self.agents) > 1:
             default_agent: AgentModel = self.agents[0]
             self.orchestration = OrchestrationModel(
-
+                supervisor=SupervisorModel(model=default_agent.model)
             )
         elif len(self.agents) == 1:
             default_agent: AgentModel = self.agents[0]
             self.orchestration = OrchestrationModel(
-
+                swarm=SwarmModel(
                     model=default_agent.model, default_agent=default_agent
                 )
             )
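
Note: the new `connection_params` property returns a libpq-style parameter dict rather than a DSN string, and `connection_url` merely flattens it. A standalone sketch of that flattening, with illustrative placeholder values (no Databricks access needed; the dict keys follow the diff above):

    params = {
        "dbname": "databricks_postgres",
        "host": "instance.example.com",   # placeholder for read_write_dns
        "port": 5432,
        "password": "<oauth-token>",      # placeholder for the Lakebase token
        "sslmode": "require",
    }
    # Mirrors DatabaseModel.connection_url: space-separated key=value pairs,
    # i.e. libpq "keyword/value" conninfo format.
    conninfo = " ".join(f"{k}={v}" for k, v in params.items())
    print(conninfo)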
dao_ai/graph.py
CHANGED
@@ -27,6 +27,7 @@ from dao_ai.nodes import (
 )
 from dao_ai.prompts import make_prompt
 from dao_ai.state import Context, IncomingState, OutgoingState, SharedState
+from dao_ai.tools import create_tools


 def route_message(state: SharedState) -> str:
@@ -91,6 +92,8 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
     orchestration: OrchestrationModel = config.app.orchestration
     supervisor: SupervisorModel = orchestration.supervisor

+    tools += create_tools(orchestration.supervisor.tools)
+
     store: BaseStore = None
     if orchestration.memory and orchestration.memory.store:
         store = orchestration.memory.store.as_store()
dao_ai/memory/core.py
CHANGED
@@ -70,10 +70,14 @@ class StoreManager:
             case StorageType.POSTGRES:
                 from dao_ai.memory.postgres import PostgresStoreManager

-                store_manager = cls.store_managers.get(
+                store_manager = cls.store_managers.get(
+                    store_model.database.instance_name
+                )
                 if store_manager is None:
                     store_manager = PostgresStoreManager(store_model)
-                    cls.store_managers[store_model.database.
+                    cls.store_managers[store_model.database.instance_name] = (
+                        store_manager
+                    )
             case _:
                 raise ValueError(f"Unknown store type: {store_model.type}")

@@ -102,15 +106,15 @@ class CheckpointManager:
                 from dao_ai.memory.postgres import AsyncPostgresCheckpointerManager

                 checkpointer_manager = cls.checkpoint_managers.get(
-                    checkpointer_model.database.
+                    checkpointer_model.database.instance_name
                 )
                 if checkpointer_manager is None:
                     checkpointer_manager = AsyncPostgresCheckpointerManager(
                         checkpointer_model
                     )
-                    cls.checkpoint_managers[
-
-
+                    cls.checkpoint_managers[
+                        checkpointer_model.database.instance_name
+                    ] = checkpointer_manager
             case _:
                 raise ValueError(f"Unknown store type: {checkpointer_model.type}")

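
Both managers are now cached per Lakebase `instance_name`, so two store or checkpointer configurations pointing at the same instance share one manager. The caching pattern in isolation, as a generic sketch (not dao_ai code; names are illustrative):

    from typing import Callable

    class ManagerCache:
        _managers: dict[str, object] = {}

        @classmethod
        def get(cls, instance_name: str, factory: Callable[[], object]) -> object:
            # One manager per Lakebase instance; repeated lookups reuse it.
            manager = cls._managers.get(instance_name)
            if manager is None:
                manager = factory()
                cls._managers[instance_name] = manager
            return manager

    # Usage: a second get() with the same instance_name returns the same object.
    a = ManagerCache.get("my-instance", object)
    b = ManagerCache.get("my-instance", object)
    assert a is b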
dao_ai/memory/postgres.py
CHANGED
@@ -20,6 +20,59 @@ from dao_ai.memory.base import (
 )


+def _create_pool(
+    connection_params: dict[str, Any],
+    database_name: str,
+    max_pool_size: int,
+    timeout_seconds: int,
+    kwargs: dict,
+) -> ConnectionPool:
+    """Create a connection pool using the provided connection parameters."""
+    logger.debug(
+        f"Connection params for {database_name}: {', '.join(k + '=' + (str(v) if k != 'password' else '***') for k, v in connection_params.items())}"
+    )
+
+    # Merge connection_params into kwargs for psycopg
+    connection_kwargs = kwargs | connection_params
+    pool = ConnectionPool(
+        conninfo="",  # Empty conninfo, params come from kwargs
+        min_size=1,
+        max_size=max_pool_size,
+        open=False,
+        timeout=timeout_seconds,
+        kwargs=connection_kwargs,
+    )
+    pool.open(wait=True, timeout=timeout_seconds)
+    logger.info(f"Successfully connected to {database_name}")
+    return pool
+
+
+async def _create_async_pool(
+    connection_params: dict[str, Any],
+    database_name: str,
+    max_pool_size: int,
+    timeout_seconds: int,
+    kwargs: dict,
+) -> AsyncConnectionPool:
+    """Create an async connection pool using the provided connection parameters."""
+    logger.debug(
+        f"Connection params for {database_name}: {', '.join(k + '=' + (str(v) if k != 'password' else '***') for k, v in connection_params.items())}"
+    )
+
+    # Merge connection_params into kwargs for psycopg
+    connection_kwargs = kwargs | connection_params
+    pool = AsyncConnectionPool(
+        conninfo="",  # Empty conninfo, params come from kwargs
+        max_size=max_pool_size,
+        open=False,
+        timeout=timeout_seconds,
+        kwargs=connection_kwargs,
+    )
+    await pool.open(wait=True, timeout=timeout_seconds)
+    logger.info(f"Successfully connected to {database_name}")
+    return pool
+
+
 class AsyncPostgresPoolManager:
     _pools: dict[str, AsyncConnectionPool] = {}
     _lock: asyncio.Lock = asyncio.Lock()
@@ -27,7 +80,7 @@ class AsyncPostgresPoolManager:
     @classmethod
     async def get_pool(cls, database: DatabaseModel) -> AsyncConnectionPool:
         connection_key: str = database.name
-
+        connection_params: dict[str, Any] = database.connection_params

        async with cls._lock:
            if connection_key in cls._pools:
@@ -41,23 +94,17 @@ class AsyncPostgresPoolManager:
                 "autocommit": True,
             } | database.connection_kwargs or {}

-
-
-
-
-
+            # Create connection pool
+            pool: AsyncConnectionPool = await _create_async_pool(
+                connection_params=connection_params,
+                database_name=database.name,
+                max_pool_size=database.max_pool_size,
+                timeout_seconds=database.timeout_seconds,
                 kwargs=kwargs,
             )

-
-
-            cls._pools[connection_key] = pool
-            return pool
-        except Exception as e:
-            logger.error(
-                f"Failed to create PostgreSQL pool for {database.name}: {e}"
-            )
-            raise e
+            cls._pools[connection_key] = pool
+            return pool

     @classmethod
     async def close_pool(cls, database: DatabaseModel):
@@ -74,8 +121,17 @@ class AsyncPostgresPoolManager:
         async with cls._lock:
             for connection_key, pool in cls._pools.items():
                 try:
-
+                    # Use a short timeout to avoid blocking on pool closure
+                    await asyncio.wait_for(pool.close(), timeout=2.0)
                     logger.debug(f"Closed PostgreSQL pool: {connection_key}")
+                except asyncio.TimeoutError:
+                    logger.warning(
+                        f"Timeout closing pool {connection_key}, forcing closure"
+                    )
+                except asyncio.CancelledError:
+                    logger.warning(
+                        f"Pool closure cancelled for {connection_key} (shutdown in progress)"
+                    )
                 except Exception as e:
                     logger.error(f"Error closing pool {connection_key}: {e}")
         cls._pools.clear()
@@ -209,7 +265,7 @@ class PostgresPoolManager:
     @classmethod
     def get_pool(cls, database: DatabaseModel) -> ConnectionPool:
         connection_key: str = str(database.name)
-
+        connection_params: dict[str, Any] = database.connection_params

         with cls._lock:
             if connection_key in cls._pools:
@@ -223,23 +279,17 @@ class PostgresPoolManager:
                 "autocommit": True,
             } | database.connection_kwargs or {}

-
-
-
-
-
+            # Create connection pool
+            pool: ConnectionPool = _create_pool(
+                connection_params=connection_params,
+                database_name=database.name,
+                max_pool_size=database.max_pool_size,
+                timeout_seconds=database.timeout_seconds,
                 kwargs=kwargs,
             )

-
-
-            cls._pools[connection_key] = pool
-            return pool
-        except Exception as e:
-            logger.error(
-                f"Failed to create PostgreSQL pool for {database.name}: {e}"
-            )
-            raise e
+            cls._pools[connection_key] = pool
+            return pool

     @classmethod
     def close_pool(cls, database: DatabaseModel):
@@ -369,8 +419,27 @@ def _shutdown_pools():

 def _shutdown_async_pools():
     try:
-
-
+        # Try to get the current event loop first
+        try:
+            loop = asyncio.get_running_loop()
+            # If we're already in an event loop, create a task
+            loop.create_task(AsyncPostgresPoolManager.close_all_pools())
+            logger.debug("Scheduled async pool closure in running event loop")
+        except RuntimeError:
+            # No running loop, try to get or create one
+            try:
+                loop = asyncio.get_event_loop()
+                if loop.is_closed():
+                    # Loop is closed, create a new one
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                loop.run_until_complete(AsyncPostgresPoolManager.close_all_pools())
+                logger.debug("Successfully closed all asynchronous PostgreSQL pools")
+            except Exception as inner_e:
+                # If all else fails, just log the error
+                logger.warning(
+                    f"Could not close async pools cleanly during shutdown: {inner_e}"
+                )
     except Exception as e:
         logger.error(
             f"Error closing asynchronous PostgreSQL pools during shutdown: {e}"
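
The `_create_pool` helpers pass an empty `conninfo` and rely on psycopg_pool forwarding the `kwargs` dict to each underlying `psycopg.connect()` call. A standalone sketch of that pattern (connection values are illustrative placeholders, not real endpoints):

    from psycopg_pool import ConnectionPool

    # Illustrative parameters; in dao_ai these come from
    # DatabaseModel.connection_params merged with connection_kwargs.
    connection_kwargs = {
        "dbname": "databricks_postgres",
        "host": "instance.example.com",
        "port": 5432,
        "password": "<oauth-token>",
        "sslmode": "require",
        "autocommit": True,
    }

    pool = ConnectionPool(
        conninfo="",               # all parameters come from kwargs, not a DSN
        min_size=1,
        max_size=10,
        open=False,                # open explicitly so failures surface here
        timeout=10,
        kwargs=connection_kwargs,  # forwarded to psycopg.connect() per connection
    )
    pool.open(wait=True, timeout=10)

Opening with `wait=True` makes pool creation fail fast with a clear error instead of deferring connection problems to first use.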
dao_ai/models.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Generator, Optional, Sequence, Union

 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langgraph.graph.state import CompiledStateGraph
+from langgraph.types import StateSnapshot
 from loguru import logger
 from mlflow import MlflowClient
 from mlflow.pyfunc import ChatAgent, ChatModel, ResponsesAgent
@@ -59,6 +60,113 @@ def get_latest_model_version(model_name: str) -> int:
     return latest_version


+async def get_state_snapshot_async(
+    graph: CompiledStateGraph, thread_id: str
+) -> Optional[StateSnapshot]:
+    """
+    Retrieve the state snapshot from the graph for a given thread_id asynchronously.
+
+    This utility function accesses the graph's checkpointer to retrieve the current
+    state snapshot, which contains the full state values and metadata.
+
+    Args:
+        graph: The compiled LangGraph state machine
+        thread_id: The thread/conversation ID to retrieve state for
+
+    Returns:
+        StateSnapshot if found, None otherwise
+    """
+    logger.debug(f"Retrieving state snapshot for thread_id: {thread_id}")
+    try:
+        # Check if graph has a checkpointer
+        if graph.checkpointer is None:
+            logger.debug("No checkpointer available in graph")
+            return None
+
+        # Get the current state from the checkpointer (use async version)
+        config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
+        state_snapshot: Optional[StateSnapshot] = await graph.aget_state(config)
+
+        if state_snapshot is None:
+            logger.debug(f"No state found for thread_id: {thread_id}")
+            return None
+
+        return state_snapshot
+
+    except Exception as e:
+        logger.warning(f"Error retrieving state snapshot for thread {thread_id}: {e}")
+        return None
+
+
+def get_state_snapshot(
+    graph: CompiledStateGraph, thread_id: str
+) -> Optional[StateSnapshot]:
+    """
+    Retrieve the state snapshot from the graph for a given thread_id.
+
+    This is a synchronous wrapper around get_state_snapshot_async.
+    Use this for backward compatibility in synchronous contexts.
+
+    Args:
+        graph: The compiled LangGraph state machine
+        thread_id: The thread/conversation ID to retrieve state for
+
+    Returns:
+        StateSnapshot if found, None otherwise
+    """
+    import asyncio
+
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+    try:
+        return loop.run_until_complete(get_state_snapshot_async(graph, thread_id))
+    except Exception as e:
+        logger.warning(f"Error in synchronous state snapshot retrieval: {e}")
+        return None
+
+
+def get_genie_conversation_ids_from_state(
+    state_snapshot: Optional[StateSnapshot],
+) -> dict[str, str]:
+    """
+    Extract genie_conversation_ids from a state snapshot.
+
+    This function extracts the genie_conversation_ids dictionary from the state
+    snapshot values if present.
+
+    Args:
+        state_snapshot: The state snapshot to extract conversation IDs from
+
+    Returns:
+        A dictionary mapping genie space_id to conversation_id, or empty dict if not found
+    """
+    if state_snapshot is None:
+        return {}
+
+    try:
+        # Extract state values - these contain the actual state data
+        state_values: dict[str, Any] = state_snapshot.values
+
+        # Extract genie_conversation_ids from state values
+        genie_conversation_ids: dict[str, str] = state_values.get(
+            "genie_conversation_ids", {}
+        )
+
+        if genie_conversation_ids:
+            logger.debug(f"Retrieved genie_conversation_ids: {genie_conversation_ids}")
+            return genie_conversation_ids
+
+        return {}
+
+    except Exception as e:
+        logger.warning(f"Error extracting genie_conversation_ids from state: {e}")
+        return {}
+
+
 class LanggraphChatModel(ChatModel):
     """
     ChatModel that delegates requests to a LangGraph CompiledStateGraph.
@@ -257,7 +365,19 @@ class LanggraphResponsesAgent(ResponsesAgent):
             text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
         )

-
+        # Retrieve genie_conversation_ids from state if available
+        custom_outputs: dict[str, Any] = custom_inputs.copy()
+        thread_id: Optional[str] = context.thread_id
+        if thread_id:
+            state_snapshot: Optional[StateSnapshot] = loop.run_until_complete(
+                get_state_snapshot_async(self.graph, thread_id)
+            )
+            genie_conversation_ids: dict[str, str] = (
+                get_genie_conversation_ids_from_state(state_snapshot)
+            )
+            if genie_conversation_ids:
+                custom_outputs["genie_conversation_ids"] = genie_conversation_ids
+
         return ResponsesAgentResponse(
             output=[output_item], custom_outputs=custom_outputs
         )
@@ -318,7 +438,22 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 **self.create_text_delta(delta=content, item_id=item_id)
             )

-
+        # Retrieve genie_conversation_ids from state if available
+        custom_outputs: dict[str, Any] = custom_inputs.copy()
+        thread_id: Optional[str] = context.thread_id
+
+        if thread_id:
+            state_snapshot: Optional[
+                StateSnapshot
+            ] = await get_state_snapshot_async(self.graph, thread_id)
+            genie_conversation_ids: dict[str, str] = (
+                get_genie_conversation_ids_from_state(state_snapshot)
+            )
+            if genie_conversation_ids:
+                custom_outputs["genie_conversation_ids"] = (
+                    genie_conversation_ids
+                )
+
         # Yield final output item
         yield ResponsesAgentStreamEvent(
             type="response.output_item.done",
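
`get_genie_conversation_ids_from_state` only reads the snapshot's `values` mapping, so its behavior can be pinned down with a stand-in object. A minimal sketch (here `SimpleNamespace` substitutes for langgraph's `StateSnapshot`; the key name matches the diff above, the IDs are illustrative):

    from types import SimpleNamespace

    # Stand-in snapshot: only the `.values` attribute is consulted.
    snapshot = SimpleNamespace(
        values={"genie_conversation_ids": {"genie-space-id": "conversation-id"}}
    )

    # Same extraction the helper performs, with an empty-dict fallback.
    ids = snapshot.values.get("genie_conversation_ids", {})
    assert ids == {"genie-space-id": "conversation-id"}

Surfacing these IDs in `custom_outputs` suggests a round trip: a client can echo them back as `custom_inputs` on its next request so Genie follow-up questions continue in the same conversation.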