dao-ai 0.0.21__py3-none-any.whl → 0.0.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/config.py +98 -11
- dao_ai/graph.py +3 -0
- dao_ai/memory/core.py +10 -6
- dao_ai/memory/postgres.py +71 -30
- dao_ai/providers/databricks.py +280 -0
- dao_ai/tools/__init__.py +2 -0
- dao_ai/tools/genie.py +1 -1
- dao_ai/tools/mcp.py +125 -78
- dao_ai/tools/slack.py +136 -0
- dao_ai/utils.py +8 -0
- {dao_ai-0.0.21.dist-info → dao_ai-0.0.23.dist-info}/METADATA +28 -5
- {dao_ai-0.0.21.dist-info → dao_ai-0.0.23.dist-info}/RECORD +15 -14
- {dao_ai-0.0.21.dist-info → dao_ai-0.0.23.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.21.dist-info → dao_ai-0.0.23.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.21.dist-info → dao_ai-0.0.23.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
@@ -23,6 +23,7 @@ from databricks.sdk.credentials_provider import (
     ModelServingUserCredentials,
 )
 from databricks.sdk.service.catalog import FunctionInfo, TableInfo
+from databricks.sdk.service.database import DatabaseInstance
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
@@ -665,6 +666,10 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
         return [
             "catalog.connections",
             "serving.serving-endpoints",
+            "mcp.genie",
+            "mcp.functions",
+            "mcp.vectorsearch",
+            "mcp.external",
         ]

     def as_resources(self) -> Sequence[DatabricksResource]:
@@ -703,15 +708,18 @@ class WarehouseModel(BaseModel, IsDatabricksResource):


 class DatabaseModel(BaseModel, IsDatabricksResource):
-    model_config = ConfigDict(
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
+    instance_name: Optional[str] = None
     description: Optional[str] = None
-    host: Optional[AnyVariable]
+    host: Optional[AnyVariable] = None
     database: Optional[AnyVariable] = "databricks_postgres"
     port: Optional[AnyVariable] = 5432
     connection_kwargs: Optional[dict[str, Any]] = Field(default_factory=dict)
     max_pool_size: Optional[int] = 10
-    timeout_seconds: Optional[int] =
+    timeout_seconds: Optional[int] = 10
+    capacity: Optional[Literal["CU_1", "CU_2"]] = "CU_2"
+    node_count: Optional[int] = None
     user: Optional[AnyVariable] = None
     password: Optional[AnyVariable] = None
     client_id: Optional[AnyVariable] = None
@@ -725,11 +733,44 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksLakebase(
-                database_instance_name=self.
+                database_instance_name=self.instance_name,
                 on_behalf_of_user=self.on_behalf_of_user,
             )
         ]

+    @model_validator(mode="after")
+    def update_instance_name(self):
+        if self.instance_name is None:
+            self.instance_name = self.name
+
+        return self
+
+    @model_validator(mode="after")
+    def update_user(self):
+        if self.client_id or self.user:
+            return self
+
+        self.user = self.workspace_client.current_user.me().user_name
+        if not self.user:
+            raise ValueError(
+                "Unable to determine current user. Please provide a user name or OAuth credentials."
+            )
+
+        return self
+
+    @model_validator(mode="after")
+    def update_host(self):
+        if self.host is not None:
+            return self
+
+        existing_instance: DatabaseInstance = (
+            self.workspace_client.database.get_database_instance(
+                name=self.instance_name
+            )
+        )
+        self.host = existing_instance.read_write_dns
+        return self
+
     @model_validator(mode="after")
     def validate_auth_methods(self):
         oauth_fields: Sequence[Any] = [
@@ -739,7 +780,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         ]
         has_oauth: bool = all(field is not None for field in oauth_fields)

-        pat_fields: Sequence[Any] = [self.user
+        pat_fields: Sequence[Any] = [self.user]
         has_user_auth: bool = all(field is not None for field in pat_fields)

         if has_oauth and has_user_auth:
@@ -758,7 +799,14 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         return self

     @property
-    def
+    def connection_params(self) -> dict[str, Any]:
+        """
+        Get database connection parameters as a dictionary.
+
+        Returns a dict with connection parameters suitable for psycopg ConnectionPool.
+        If username is configured, it will be included; otherwise it will be omitted
+        to allow Lakebase to authenticate using the token's identity.
+        """
         from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider

@@ -766,7 +814,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):

         if self.client_id and self.client_secret and self.workspace_host:
             username = value_of(self.client_id)
-
+        elif self.user:
             username = value_of(self.user)

         host: str = value_of(self.host)
@@ -779,11 +827,48 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
             workspace_host=value_of(self.workspace_host),
             pat=value_of(self.password),
         )
-        token: str = provider.create_token()

+        token: str = provider.lakebase_password_provider(self.instance_name)
+
+        # Build connection parameters dictionary
+        params: dict[str, Any] = {
+            "dbname": database,
+            "host": host,
+            "port": port,
+            "password": token,
+            "sslmode": "require",
+        }
+
+        # Only include user if explicitly configured
+        if username:
+            params["user"] = username
+            logger.debug(
+                f"Connection params: dbname={database} user={username} host={host} port={port} password=******** sslmode=require"
+            )
+        else:
+            logger.debug(
+                f"Connection params: dbname={database} host={host} port={port} password=******** sslmode=require (using token identity)"
+            )
+
+        return params
+
+    @property
+    def connection_url(self) -> str:
+        """
+        Get database connection URL as a string (for backwards compatibility).
+
+        Note: It's recommended to use connection_params instead for better flexibility.
+        """
+        params = self.connection_params
+        parts = [f"{k}={v}" for k, v in params.items()]
+        return " ".join(parts)
+
+    def create(self, w: WorkspaceClient | None = None) -> None:
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: DatabricksProvider = DatabricksProvider()
+        provider.create_lakebase(self)
+        provider.create_lakebase_instance_role(self)


 class SearchParametersModel(BaseModel):
@@ -907,6 +992,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
     transport: TransportType = TransportType.STREAMABLE_HTTP
     command: Optional[str] = "python"
     url: Optional[AnyVariable] = None
+    connection: Optional[ConnectionModel] = None
     headers: dict[str, AnyVariable] = Field(default_factory=dict)
     args: list[str] = Field(default_factory=list)
     pat: Optional[AnyVariable] = None
@@ -1093,6 +1179,7 @@ class AgentModel(BaseModel):
 class SupervisorModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
+    tools: list[ToolModel] = Field(default_factory=list)
     prompt: Optional[str] = None
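
Taken together, the new `DatabaseModel` fields and the `connection_params` property change how a Lakebase connection is described. A minimal sketch of how they compose, assuming a reachable, authenticated Databricks workspace (the instance name below is illustrative, not from the package):

```python
from dao_ai.config import DatabaseModel

# instance_name falls back to name via update_instance_name; host is resolved
# from the instance's read_write_dns when omitted; capacity/node_count are new.
db = DatabaseModel(name="agent-memory-instance", capacity="CU_2")

# connection_params mints a short-lived Lakebase credential on each access and
# returns a psycopg-ready dict; "user" appears only when explicitly configured.
params = db.connection_params
print(sorted(params))  # e.g. ['dbname', 'host', 'password', 'port', 'sslmode']
```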
dao_ai/graph.py
CHANGED
@@ -27,6 +27,7 @@ from dao_ai.nodes import (
 )
 from dao_ai.prompts import make_prompt
 from dao_ai.state import Context, IncomingState, OutgoingState, SharedState
+from dao_ai.tools import create_tools


 def route_message(state: SharedState) -> str:
@@ -91,6 +92,8 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
     orchestration: OrchestrationModel = config.app.orchestration
     supervisor: SupervisorModel = orchestration.supervisor

+    tools += create_tools(orchestration.supervisor.tools)
+
     store: BaseStore = None
     if orchestration.memory and orchestration.memory.store:
         store = orchestration.memory.store.as_store()
dao_ai/memory/core.py
CHANGED
@@ -70,10 +70,14 @@ class StoreManager:
             case StorageType.POSTGRES:
                 from dao_ai.memory.postgres import PostgresStoreManager

-                store_manager = cls.store_managers.get(
+                store_manager = cls.store_managers.get(
+                    store_model.database.instance_name
+                )
                 if store_manager is None:
                     store_manager = PostgresStoreManager(store_model)
-                    cls.store_managers[store_model.database.
+                    cls.store_managers[store_model.database.instance_name] = (
+                        store_manager
+                    )
             case _:
                 raise ValueError(f"Unknown store type: {store_model.type}")
@@ -102,15 +106,15 @@ class CheckpointManager:
                 from dao_ai.memory.postgres import AsyncPostgresCheckpointerManager

                 checkpointer_manager = cls.checkpoint_managers.get(
-                    checkpointer_model.database.
+                    checkpointer_model.database.instance_name
                 )
                 if checkpointer_manager is None:
                     checkpointer_manager = AsyncPostgresCheckpointerManager(
                         checkpointer_model
                     )
-                    cls.checkpoint_managers[
+                    cls.checkpoint_managers[
+                        checkpointer_model.database.instance_name
+                    ] = checkpointer_manager
             case _:
                 raise ValueError(f"Unknown store type: {checkpointer_model.type}")
dao_ai/memory/postgres.py
CHANGED
@@ -20,6 +20,59 @@ from dao_ai.memory.base import (
 )


+def _create_pool(
+    connection_params: dict[str, Any],
+    database_name: str,
+    max_pool_size: int,
+    timeout_seconds: int,
+    kwargs: dict,
+) -> ConnectionPool:
+    """Create a connection pool using the provided connection parameters."""
+    logger.debug(
+        f"Connection params for {database_name}: {', '.join(k + '=' + (str(v) if k != 'password' else '***') for k, v in connection_params.items())}"
+    )
+
+    # Merge connection_params into kwargs for psycopg
+    connection_kwargs = kwargs | connection_params
+    pool = ConnectionPool(
+        conninfo="",  # Empty conninfo, params come from kwargs
+        min_size=1,
+        max_size=max_pool_size,
+        open=False,
+        timeout=timeout_seconds,
+        kwargs=connection_kwargs,
+    )
+    pool.open(wait=True, timeout=timeout_seconds)
+    logger.info(f"Successfully connected to {database_name}")
+    return pool
+
+
+async def _create_async_pool(
+    connection_params: dict[str, Any],
+    database_name: str,
+    max_pool_size: int,
+    timeout_seconds: int,
+    kwargs: dict,
+) -> AsyncConnectionPool:
+    """Create an async connection pool using the provided connection parameters."""
+    logger.debug(
+        f"Connection params for {database_name}: {', '.join(k + '=' + (str(v) if k != 'password' else '***') for k, v in connection_params.items())}"
+    )
+
+    # Merge connection_params into kwargs for psycopg
+    connection_kwargs = kwargs | connection_params
+    pool = AsyncConnectionPool(
+        conninfo="",  # Empty conninfo, params come from kwargs
+        max_size=max_pool_size,
+        open=False,
+        timeout=timeout_seconds,
+        kwargs=connection_kwargs,
+    )
+    await pool.open(wait=True, timeout=timeout_seconds)
+    logger.info(f"Successfully connected to {database_name}")
+    return pool
+
+
 class AsyncPostgresPoolManager:
     _pools: dict[str, AsyncConnectionPool] = {}
     _lock: asyncio.Lock = asyncio.Lock()
@@ -27,7 +80,7 @@ class AsyncPostgresPoolManager:
     @classmethod
     async def get_pool(cls, database: DatabaseModel) -> AsyncConnectionPool:
         connection_key: str = database.name
-
+        connection_params: dict[str, Any] = database.connection_params

         async with cls._lock:
             if connection_key in cls._pools:
@@ -41,23 +94,17 @@ class AsyncPostgresPoolManager:
                 "autocommit": True,
             } | database.connection_kwargs or {}

+            # Create connection pool
+            pool: AsyncConnectionPool = await _create_async_pool(
+                connection_params=connection_params,
+                database_name=database.name,
+                max_pool_size=database.max_pool_size,
+                timeout_seconds=database.timeout_seconds,
                 kwargs=kwargs,
             )

-            cls._pools[connection_key] = pool
-            return pool
-        except Exception as e:
-            logger.error(
-                f"Failed to create PostgreSQL pool for {database.name}: {e}"
-            )
-            raise e
+            cls._pools[connection_key] = pool
+            return pool

     @classmethod
     async def close_pool(cls, database: DatabaseModel):
@@ -218,7 +265,7 @@ class PostgresPoolManager:
     @classmethod
     def get_pool(cls, database: DatabaseModel) -> ConnectionPool:
         connection_key: str = str(database.name)
-
+        connection_params: dict[str, Any] = database.connection_params

         with cls._lock:
             if connection_key in cls._pools:
@@ -232,23 +279,17 @@ class PostgresPoolManager:
                 "autocommit": True,
             } | database.connection_kwargs or {}

+            # Create connection pool
+            pool: ConnectionPool = _create_pool(
+                connection_params=connection_params,
+                database_name=database.name,
+                max_pool_size=database.max_pool_size,
+                timeout_seconds=database.timeout_seconds,
                 kwargs=kwargs,
             )

-            cls._pools[connection_key] = pool
-            return pool
-        except Exception as e:
-            logger.error(
-                f"Failed to create PostgreSQL pool for {database.name}: {e}"
-            )
-            raise e
+            cls._pools[connection_key] = pool
+            return pool

     @classmethod
     def close_pool(cls, database: DatabaseModel):
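
The helpers above hand parameters to psycopg through the pool's `kwargs` mapping instead of a conninfo string. A standalone sketch of that pattern, with illustrative connection values (a reachable Lakebase instance is assumed; `pool.open` will raise otherwise):

```python
from psycopg_pool import ConnectionPool

connection_params = {  # same shape DatabaseModel.connection_params produces
    "dbname": "databricks_postgres",
    "host": "instance-1234.database.cloud.databricks.com",  # illustrative host
    "port": 5432,
    "password": "<short-lived Lakebase token>",
    "sslmode": "require",
}

pool = ConnectionPool(
    conninfo="",  # empty conninfo; parameters travel via kwargs, as in _create_pool
    min_size=1,
    max_size=10,
    open=False,
    timeout=10,
    kwargs={"autocommit": True} | connection_params,
)
pool.open(wait=True, timeout=10)
```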
dao_ai/providers/databricks.py
CHANGED
@@ -1,4 +1,5 @@
 import base64
+import uuid
 from importlib.metadata import version
 from pathlib import Path
 from typing import Any, Callable, Final, Sequence
@@ -21,6 +22,7 @@ from databricks.sdk.service.catalog import (
     VolumeInfo,
     VolumeType,
 )
+from databricks.sdk.service.database import DatabaseCredential
 from databricks.sdk.service.iam import User
 from databricks.sdk.service.workspace import GetSecretResponse
 from databricks.vector_search.client import VectorSearchClient
@@ -743,3 +745,281 @@ class DatabricksProvider(ServiceProvider):
                 break
         logger.debug(f"Vector search index found: {found_endpoint_name}")
         return found_endpoint_name
+
+    def create_lakebase(self, database: DatabaseModel) -> None:
+        """
+        Create a Lakebase database instance using the Databricks workspace client.
+
+        This method handles idempotent database creation, gracefully handling cases where:
+        - The database instance already exists
+        - The database is in an intermediate state (STARTING, UPDATING, etc.)
+
+        Args:
+            database: DatabaseModel containing the database configuration
+
+        Returns:
+            None
+
+        Raises:
+            Exception: If an unexpected error occurs during database creation
+        """
+        import time
+        from typing import Any
+
+        workspace_client: WorkspaceClient = database.workspace_client
+
+        try:
+            # First, check if the database instance already exists
+            existing_instance: Any = workspace_client.database.get_database_instance(
+                name=database.instance_name
+            )
+
+            if existing_instance:
+                logger.debug(
+                    f"Database instance {database.instance_name} already exists with state: {existing_instance.state}"
+                )
+
+                # Check if database is in an intermediate state
+                if existing_instance.state in ["STARTING", "UPDATING"]:
+                    logger.info(
+                        f"Database instance {database.instance_name} is in {existing_instance.state} state, waiting for it to become AVAILABLE..."
+                    )
+
+                    # Wait for database to reach a stable state
+                    max_wait_time: int = 600  # 10 minutes
+                    wait_interval: int = 10  # 10 seconds
+                    elapsed: int = 0
+
+                    while elapsed < max_wait_time:
+                        try:
+                            current_instance: Any = (
+                                workspace_client.database.get_database_instance(
+                                    name=database.instance_name
+                                )
+                            )
+                            current_state: str = current_instance.state
+                            logger.debug(f"Database instance state: {current_state}")
+
+                            if current_state == "AVAILABLE":
+                                logger.info(
+                                    f"Database instance {database.instance_name} is now AVAILABLE"
+                                )
+                                break
+                            elif current_state in ["STARTING", "UPDATING"]:
+                                logger.debug(
+                                    f"Database instance still in {current_state} state, waiting {wait_interval} seconds..."
+                                )
+                                time.sleep(wait_interval)
+                                elapsed += wait_interval
+                            elif current_state in ["STOPPED", "DELETING"]:
+                                logger.warning(
+                                    f"Database instance {database.instance_name} is in unexpected state: {current_state}"
+                                )
+                                break
+                            else:
+                                logger.warning(
+                                    f"Unknown database state: {current_state}, proceeding anyway"
+                                )
+                                break
+                        except NotFound:
+                            logger.warning(
+                                f"Database instance {database.instance_name} no longer exists, will attempt to recreate"
+                            )
+                            break
+                        except Exception as state_error:
+                            logger.warning(
+                                f"Could not check database state: {state_error}, proceeding anyway"
+                            )
+                            break
+
+                    if elapsed >= max_wait_time:
+                        logger.warning(
+                            f"Timed out waiting for database instance {database.instance_name} to become AVAILABLE after {max_wait_time} seconds"
+                        )
+
+                elif existing_instance.state == "AVAILABLE":
+                    logger.info(
+                        f"Database instance {database.instance_name} already exists and is AVAILABLE"
+                    )
+                    return
+                elif existing_instance.state in ["STOPPED", "DELETING"]:
+                    logger.warning(
+                        f"Database instance {database.instance_name} is in {existing_instance.state} state"
+                    )
+                    return
+                else:
+                    logger.info(
+                        f"Database instance {database.instance_name} already exists with state: {existing_instance.state}"
+                    )
+                    return
+
+        except NotFound:
+            # Database doesn't exist, proceed with creation
+            logger.debug(
+                f"Database instance {database.instance_name} not found, creating new instance..."
+            )
+
+            try:
+                # Resolve variable values for database parameters
+                from databricks.sdk.service.database import DatabaseInstance
+
+                capacity: str = database.capacity if database.capacity else "CU_2"
+
+                # Create the database instance object
+                database_instance: DatabaseInstance = DatabaseInstance(
+                    name=database.instance_name,
+                    capacity=capacity,
+                    node_count=database.node_count,
+                )
+
+                # Create the database instance via API
+                workspace_client.database.create_database_instance(
+                    database_instance=database_instance
+                )
+                logger.info(
+                    f"Successfully created database instance: {database.instance_name}"
+                )
+
+            except Exception as create_error:
+                error_msg: str = str(create_error)
+
+                # Handle case where database was created by another process concurrently
+                if (
+                    "already exists" in error_msg.lower()
+                    or "RESOURCE_ALREADY_EXISTS" in error_msg
+                ):
+                    logger.info(
+                        f"Database instance {database.instance_name} was created concurrently by another process"
+                    )
+                    return
+                else:
+                    # Re-raise unexpected errors
+                    logger.error(
+                        f"Error creating database instance {database.instance_name}: {create_error}"
+                    )
+                    raise
+
+        except Exception as e:
+            # Handle other unexpected errors
+            error_msg: str = str(e)
+
+            # Check if this is actually a "resource already exists" type error
+            if (
+                "already exists" in error_msg.lower()
+                or "RESOURCE_ALREADY_EXISTS" in error_msg
+            ):
+                logger.info(
+                    f"Database instance {database.instance_name} already exists (detected via exception)"
+                )
+                return
+            else:
+                logger.error(
+                    f"Unexpected error while handling database {database.instance_name}: {e}"
+                )
+                raise
+
+    def lakebase_password_provider(self, instance_name: str) -> str:
+        """
+        Ask Databricks to mint a fresh DB credential for this instance.
+        """
+        logger.debug(f"Generating password for lakebase instance: {instance_name}")
+        w: WorkspaceClient = self.w
+        cred: DatabaseCredential = w.database.generate_database_credential(
+            request_id=str(uuid.uuid4()),
+            instance_names=[instance_name],
+        )
+        return cred.token
+
+    def create_lakebase_instance_role(self, database: DatabaseModel) -> None:
+        """
+        Create a database instance role for a Lakebase instance.
+
+        This method creates a role with DATABRICKS_SUPERUSER membership for the
+        service principal specified in the database configuration.
+
+        Args:
+            database: DatabaseModel containing the database and service principal configuration
+
+        Returns:
+            None
+
+        Raises:
+            ValueError: If client_id is not provided in the database configuration
+            Exception: If an unexpected error occurs during role creation
+        """
+        from databricks.sdk.service.database import (
+            DatabaseInstanceRole,
+            DatabaseInstanceRoleIdentityType,
+            DatabaseInstanceRoleMembershipRole,
+        )
+
+        from dao_ai.config import value_of
+
+        # Validate that client_id is provided
+        if not database.client_id:
+            logger.warning(
+                f"client_id is required to create instance role for database {database.instance_name}"
+            )
+            return
+
+        # Resolve the client_id value
+        client_id: str = value_of(database.client_id)
+        role_name: str = client_id
+        instance_name: str = database.instance_name
+
+        logger.debug(
+            f"Creating instance role '{role_name}' for database {instance_name} with principal {client_id}"
+        )
+
+        try:
+            # Check if role already exists
+            try:
+                _ = self.w.database.get_database_instance_role(
+                    instance_name=instance_name,
+                    name=role_name,
+                )
+                logger.info(
+                    f"Instance role '{role_name}' already exists for database {instance_name}"
+                )
+                return
+            except NotFound:
+                # Role doesn't exist, proceed with creation
+                logger.debug(
+                    f"Instance role '{role_name}' not found, creating new role..."
+                )
+
+            # Create the database instance role
+            role: DatabaseInstanceRole = DatabaseInstanceRole(
+                name=role_name,
+                identity_type=DatabaseInstanceRoleIdentityType.SERVICE_PRINCIPAL,
+                membership_role=DatabaseInstanceRoleMembershipRole.DATABRICKS_SUPERUSER,
+            )
+
+            # Create the role using the API
+            self.w.database.create_database_instance_role(
+                instance_name=instance_name,
+                database_instance_role=role,
+            )
+
+            logger.info(
+                f"Successfully created instance role '{role_name}' for database {instance_name}"
+            )
+
+        except Exception as e:
+            error_msg: str = str(e)
+
+            # Handle case where role was created concurrently
+            if (
+                "already exists" in error_msg.lower()
+                or "RESOURCE_ALREADY_EXISTS" in error_msg
+            ):
+                logger.info(
+                    f"Instance role '{role_name}' was created concurrently for database {instance_name}"
+                )
+                return
+
+            # Re-raise unexpected errors
+            logger.error(
+                f"Error creating instance role '{role_name}' for database {instance_name}: {e}"
+            )
+            raise
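
`lakebase_password_provider` is a thin wrapper over the SDK's credential-minting call, which the sketch below exercises directly (instance name illustrative; workspace authentication assumed). Because these tokens expire, `connection_params` mints one per pool creation rather than caching it:

```python
import uuid

from databricks.sdk import WorkspaceClient

w = WorkspaceClient()
cred = w.database.generate_database_credential(
    request_id=str(uuid.uuid4()),       # idempotency key for the request
    instance_names=["agent-memory-instance"],
)
password = cred.token  # use as the Postgres password for the Lakebase instance
```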
dao_ai/tools/__init__.py
CHANGED
@@ -7,6 +7,7 @@ from dao_ai.tools.core import (
 from dao_ai.tools.genie import create_genie_tool
 from dao_ai.tools.mcp import create_mcp_tools
 from dao_ai.tools.python import create_factory_tool, create_python_tool
+from dao_ai.tools.slack import create_send_slack_message_tool
 from dao_ai.tools.time import (
     add_time_tool,
     current_time_tool,
@@ -27,6 +28,7 @@ __all__ = [
     "create_hooks",
     "create_mcp_tools",
     "create_python_tool",
+    "create_send_slack_message_tool",
     "create_tools",
     "create_uc_tools",
     "create_vector_search_tool",
dao_ai/tools/genie.py
CHANGED
@@ -276,7 +276,7 @@ def create_genie_tool(
     genie_room: GenieRoomModel | dict[str, Any],
     name: Optional[str] = None,
     description: Optional[str] = None,
-    persist_conversation: bool =
+    persist_conversation: bool = False,
     truncate_results: bool = False,
     poll_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
 ) -> Callable[[str], GenieResponse]:
dao_ai/tools/mcp.py
CHANGED
@@ -1,10 +1,14 @@
 import asyncio
 from typing import Any, Sequence

+from databricks_mcp import DatabricksOAuthClientProvider
 from langchain_core.runnables.base import RunnableLike
 from langchain_core.tools import tool as create_tool
 from langchain_mcp_adapters.client import MultiServerMCPClient
+from langchain_mcp_adapters.tools import load_mcp_tools
 from loguru import logger
+from mcp import ClientSession
+from mcp.client.streamable_http import streamablehttp_client
 from mcp.types import ListToolsResult, Tool

 from dao_ai.config import (
@@ -20,98 +24,141 @@ def create_mcp_tools(
     """
     Create tools for invoking Databricks MCP functions.

+    Supports both direct MCP connections and UC Connection-based MCP access.
     Uses session-based approach to handle authentication token expiration properly.
+
+    Based on: https://docs.databricks.com/aws/en/generative-ai/mcp/external-mcp
     """
     logger.debug(f"create_mcp_tools: {function}")

-    try:
-        provider = DatabricksProvider(
-            workspace_host=value_of(function.workspace_host),
-            client_id=value_of(function.client_id),
-            client_secret=value_of(function.client_secret),
-            pat=value_of(function.pat),
-        )
-        headers["Authorization"] = f"Bearer {provider.create_token()}"
-        logger.debug("Generated fresh authentication token")
-    except Exception as e:
-        logger.error(f"Failed to create fresh token: {e}")
-    else:
-        logger.debug("Using existing authentication token")
+    # Check if using UC Connection or direct MCP connection
+    if function.connection:
+        # Use UC Connection approach with DatabricksOAuthClientProvider
+        logger.debug(f"Using UC Connection for MCP: {function.connection.name}")
+        logger.debug(f"MCP URL: {function.url}")
+
+        async def _get_tools_with_connection():
+            """Get tools using DatabricksOAuthClientProvider."""
+            workspace_client = function.connection.workspace_client
+
+            async with streamablehttp_client(
+                function.url, auth=DatabricksOAuthClientProvider(workspace_client)
+            ) as (read_stream, write_stream, _):
+                async with ClientSession(read_stream, write_stream) as session:
+                    # Initialize and list tools
+                    await session.initialize()
+                    tools = await load_mcp_tools(session)
+                    return tools

-        return response
+        try:
+            langchain_tools = asyncio.run(_get_tools_with_connection())
+            logger.debug(
+                f"Retrieved {len(langchain_tools)} MCP tools via UC Connection"
+            )

+            # Wrap tools with human-in-the-loop if needed
+            wrapped_tools = [
+                as_human_in_the_loop(tool, function) for tool in langchain_tools
+            ]
+            return wrapped_tools

-    try:
-        async with client.session(function.name) as session:
-            return await session.list_tools()
         except Exception as e:
-            logger.error(f"Failed to
+            logger.error(f"Failed to get tools from MCP server via UC Connection: {e}")
+            raise RuntimeError(
+                f"Failed to list MCP tools for function '{function.name}' via UC Connection '{function.connection.name}': {e}"
+            )
+
+    else:
+        # Use direct MCP connection with MultiServerMCPClient
+        logger.debug("Using direct MCP connection with MultiServerMCPClient")
+
+        def _create_fresh_connection() -> dict[str, Any]:
+            """Create connection config with fresh authentication headers."""
+            logger.debug("Creating fresh connection...")
+
+            if function.transport == TransportType.STDIO:
+                return {
+                    "command": function.command,
+                    "args": function.args,
+                    "transport": function.transport,
+                }
+
+            # For HTTP transport, generate fresh headers
+            headers = function.headers.copy() if function.headers else {}
+
+            if "Authorization" not in headers:
+                logger.debug("Generating fresh authentication token for MCP function")
+
+                from dao_ai.config import value_of
+                from dao_ai.providers.databricks import DatabricksProvider
+
+                try:
+                    provider = DatabricksProvider(
+                        workspace_host=value_of(function.workspace_host),
+                        client_id=value_of(function.client_id),
+                        client_secret=value_of(function.client_secret),
+                        pat=value_of(function.pat),
+                    )
+                    headers["Authorization"] = f"Bearer {provider.create_token()}"
+                    logger.debug("Generated fresh authentication token")
+                except Exception as e:
+                    logger.error(f"Failed to create fresh token: {e}")
+            else:
+                logger.debug("Using existing authentication token")

+            return {
+                "url": function.url,
+                "transport": function.transport,
+                "headers": headers,
+            }
+
+        # Get available tools from MCP server
+        async def _list_mcp_tools():
             connection = _create_fresh_connection()
             client = MultiServerMCPClient({function.name: connection})

             try:
                 async with client.session(function.name) as session:
-                    return await session.
+                    return await session.list_tools()
             except Exception as e:
-                logger.error(f"
+                logger.error(f"Failed to list MCP tools: {e}")
+                return []

+        # Note: This still needs to run sync during tool creation/registration
+        # The actual tool execution will be async
+        try:
+            mcp_tools: list[Tool] | ListToolsResult = asyncio.run(_list_mcp_tools())
+            if isinstance(mcp_tools, ListToolsResult):
+                mcp_tools = mcp_tools.tools

+            logger.debug(f"Retrieved {len(mcp_tools)} MCP tools")
+        except Exception as e:
+            logger.error(f"Failed to get tools from MCP server: {e}")
+            raise RuntimeError(
+                f"Failed to list MCP tools for function '{function.name}' with transport '{function.transport}' and URL '{function.url}': {e}"
+            )
+
+        # Create wrapper tools with fresh session per invocation
+        def _create_tool_wrapper(mcp_tool: Tool) -> RunnableLike:
+            @create_tool(
+                mcp_tool.name,
+                description=mcp_tool.description or f"MCP tool: {mcp_tool.name}",
+                args_schema=mcp_tool.inputSchema,
+            )
+            async def tool_wrapper(**kwargs):
+                """Execute MCP tool with fresh session and authentication."""
+                logger.debug(f"Invoking MCP tool {mcp_tool.name} with fresh session")
+
+                connection = _create_fresh_connection()
+                client = MultiServerMCPClient({function.name: connection})
+
+                try:
+                    async with client.session(function.name) as session:
+                        return await session.call_tool(mcp_tool.name, kwargs)
+                except Exception as e:
+                    logger.error(f"MCP tool {mcp_tool.name} failed: {e}")
+                    raise
+
+            return as_human_in_the_loop(tool_wrapper, function)
+
+        return [_create_tool_wrapper(tool) for tool in mcp_tools]
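
A self-contained sketch of the UC Connection path introduced above, assuming `databricks-mcp` is installed and the endpoint URL is valid for your workspace (both the workspace host and connection name are illustrative):

```python
import asyncio

from databricks.sdk import WorkspaceClient
from databricks_mcp import DatabricksOAuthClientProvider
from langchain_mcp_adapters.tools import load_mcp_tools
from mcp import ClientSession
from mcp.client.streamable_http import streamablehttp_client

MCP_URL = "https://<workspace-host>/api/2.0/mcp/external/github_u2m_connection"


async def list_tools():
    workspace_client = WorkspaceClient()
    # OAuth comes from the workspace client; no manual Bearer header needed
    async with streamablehttp_client(
        MCP_URL, auth=DatabricksOAuthClientProvider(workspace_client)
    ) as (read_stream, write_stream, _):
        async with ClientSession(read_stream, write_stream) as session:
            await session.initialize()            # MCP handshake
            return await load_mcp_tools(session)  # LangChain-compatible tools


tools = asyncio.run(list_tools())
```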
dao_ai/tools/slack.py
ADDED
@@ -0,0 +1,136 @@
+from typing import Any, Callable, Optional
+
+from databricks.sdk.service.serving import ExternalFunctionRequestHttpMethod
+from langchain_core.tools import tool
+from loguru import logger
+from requests import Response
+
+from dao_ai.config import ConnectionModel
+
+
+def _find_channel_id_by_name(
+    connection: ConnectionModel, channel_name: str
+) -> Optional[str]:
+    """
+    Find a Slack channel ID by channel name using the conversations.list API.
+
+    Based on: https://docs.databricks.com/aws/en/generative-ai/agent-framework/slack-agent
+
+    Args:
+        connection: ConnectionModel with workspace_client
+        channel_name: Name of the Slack channel (with or without '#' prefix)
+
+    Returns:
+        Channel ID if found, None otherwise
+    """
+    # Remove '#' prefix if present
+    clean_name = channel_name.lstrip("#")
+
+    logger.debug(f"Looking up Slack channel ID for channel name: {clean_name}")
+
+    try:
+        # Call Slack API to list conversations
+        response: Response = connection.workspace_client.serving_endpoints.http_request(
+            conn=connection.name,
+            method=ExternalFunctionRequestHttpMethod.GET,
+            path="/api/conversations.list",
+        )
+
+        if response.status_code != 200:
+            logger.error(f"Failed to list Slack channels: {response.text}")
+            return None
+
+        # Parse response
+        data = response.json()
+
+        if not data.get("ok"):
+            logger.error(f"Slack API returned error: {data.get('error')}")
+            return None
+
+        # Search for channel by name
+        channels = data.get("channels", [])
+        for channel in channels:
+            if channel.get("name") == clean_name:
+                channel_id = channel.get("id")
+                logger.debug(
+                    f"Found channel ID '{channel_id}' for channel name '{clean_name}'"
+                )
+                return channel_id
+
+        logger.warning(f"Channel '{clean_name}' not found in Slack workspace")
+        return None
+
+    except Exception as e:
+        logger.error(f"Error looking up Slack channel: {e}")
+        return None
+
+
+def create_send_slack_message_tool(
+    connection: ConnectionModel | dict[str, Any],
+    channel_id: Optional[str] = None,
+    channel_name: Optional[str] = None,
+    name: Optional[str] = None,
+    description: Optional[str] = None,
+) -> Callable[[str], Any]:
+    """
+    Create a tool that sends a message to a Slack channel.
+
+    Args:
+        connection: Unity Catalog connection to Slack (ConnectionModel or dict)
+        channel_id: Slack channel ID (e.g., 'C1234567890'). If not provided, channel_name is used.
+        channel_name: Slack channel name (e.g., 'general' or '#general'). Used to look up channel_id if not provided.
+        name: Custom tool name (default: 'send_slack_message')
+        description: Custom tool description
+
+    Returns:
+        A tool function that sends messages to the specified Slack channel
+
+    Based on: https://docs.databricks.com/aws/en/generative-ai/agent-framework/slack-agent
+    """
+    logger.debug("create_send_slack_message_tool")
+
+    # Validate inputs
+    if channel_id is None and channel_name is None:
+        raise ValueError("Either channel_id or channel_name must be provided")
+
+    # Convert connection dict to ConnectionModel if needed
+    if isinstance(connection, dict):
+        connection = ConnectionModel(**connection)
+
+    # Look up channel_id from channel_name if needed
+    if channel_id is None and channel_name is not None:
+        logger.debug(f"Looking up channel_id for channel_name: {channel_name}")
+        channel_id = _find_channel_id_by_name(connection, channel_name)
+        if channel_id is None:
+            raise ValueError(f"Could not find Slack channel with name '{channel_name}'")
+        logger.debug(
+            f"Resolved channel_name '{channel_name}' to channel_id '{channel_id}'"
+        )
+
+    if name is None:
+        name = "send_slack_message"
+
+    if description is None:
+        description = "Send a message to a Slack channel"
+
+    @tool(
+        name_or_callable=name,
+        description=description,
+    )
+    def send_slack_message(text: str) -> str:
+        response: Response = connection.workspace_client.serving_endpoints.http_request(
+            conn=connection.name,
+            method=ExternalFunctionRequestHttpMethod.POST,
+            path="/api/chat.postMessage",
+            json={"channel": channel_id, "text": text},
+        )
+
+        if response.status_code == 200:
+            return "Successful request sent to Slack: " + response.text
+        else:
+            return (
+                "Encountered failure when executing request. Message from Call: "
+                + response.text
+            )
+
+    return send_slack_message
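
Usage sketch for the new tool, assuming a Unity Catalog connection named `slack_connection` that points at the Slack API (per the Databricks slack-agent guide referenced in the docstrings); the channel name is illustrative:

```python
from dao_ai.config import ConnectionModel
from dao_ai.tools import create_send_slack_message_tool

send_slack_message = create_send_slack_message_tool(
    connection=ConnectionModel(name="slack_connection"),
    channel_name="#alerts",  # resolved to a channel ID via conversations.list
)

# The returned object is a LangChain tool; invoke it directly or hand it to an agent.
print(send_slack_message.invoke({"text": "dao-ai 0.0.23 deployed"}))
```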
dao_ai/utils.py
CHANGED
@@ -43,6 +43,7 @@ def get_installed_packages() -> dict[str, str]:
     packages: Sequence[str] = [
         f"databricks-agents=={version('databricks-agents')}",
         f"databricks-langchain=={version('databricks-langchain')}",
+        f"databricks-mcp=={version('databricks-mcp')}",
         f"databricks-sdk[openai]=={version('databricks-sdk')}",
         f"duckduckgo-search=={version('duckduckgo-search')}",
         f"langchain=={version('langchain')}",
@@ -56,11 +57,14 @@ def get_installed_packages() -> dict[str, str]:
         f"langgraph-swarm=={version('langgraph-swarm')}",
         f"langmem=={version('langmem')}",
         f"loguru=={version('loguru')}",
+        f"mcp=={version('mcp')}",
         f"mlflow=={version('mlflow')}",
+        f"nest-asyncio=={version('nest-asyncio')}",
         f"openevals=={version('openevals')}",
         f"openpyxl=={version('openpyxl')}",
         f"psycopg[binary,pool]=={version('psycopg')}",
         f"pydantic=={version('pydantic')}",
+        f"pyyaml=={version('pyyaml')}",
         f"unitycatalog-ai[databricks]=={version('unitycatalog-ai')}",
         f"unitycatalog-langchain[databricks]=={version('unitycatalog-langchain')}",
     ]
@@ -112,3 +116,7 @@ def load_function(function_name: str) -> Callable[..., Any]:
     except (ImportError, AttributeError, TypeError) as e:
         # Provide a detailed error message that includes the original exception
         raise ImportError(f"Failed to import {function_name}: {e}")
+
+
+def is_in_model_serving() -> bool:
+    return os.environ.get("IS_IN_DB_MODEL_SERVING_ENV", "false").lower() == "true"
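
`is_in_model_serving` keys off the environment variable that Databricks Model Serving sets in its containers. A quick check of the new helper, setting the variable by hand to simulate the serving environment:

```python
import os

# Simulate the flag the serving runtime sets in production containers.
os.environ["IS_IN_DB_MODEL_SERVING_ENV"] = "true"

from dao_ai.utils import is_in_model_serving

assert is_in_model_serving()
```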
{dao_ai-0.0.21.dist-info → dao_ai-0.0.23.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.0.21
+Version: 0.0.23
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai
@@ -24,9 +24,10 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Classifier: Topic :: System :: Distributed Computing
 Requires-Python: >=3.12
-Requires-Dist: databricks-agents>=1.6.
+Requires-Dist: databricks-agents>=1.6.1
 Requires-Dist: databricks-langchain>=0.8.0
-Requires-Dist: databricks-
+Requires-Dist: databricks-mcp>=0.3.0
+Requires-Dist: databricks-sdk[openai]>=0.67.0
 Requires-Dist: duckduckgo-search>=8.0.2
 Requires-Dist: grandalf>=0.8
 Requires-Dist: langchain-mcp-adapters>=0.1.10
@@ -653,7 +654,7 @@ test:
 #### 4. MCP (Model Context Protocol) Tools (`type: mcp`)
 MCP tools allow interaction with external services that implement the Model Context Protocol, supporting both HTTP and stdio transports.

-**Configuration Example:**
+**Configuration Example (Direct URL):**
 ```yaml
 tools:
   weather_tool_mcp:
@@ -664,8 +665,30 @@ test:
         transport: streamable_http
         url: http://localhost:8000/mcp
 ```
+
+**Configuration Example (Unity Catalog Connection):**
+MCP tools can also use Unity Catalog Connections for secure, governed access with on-behalf-of-user capabilities. The connection provides OAuth authentication, while the URL specifies the endpoint:
+```yaml
+resources:
+  connections:
+    github_connection:
+      name: github_u2m_connection  # UC Connection name
+
+tools:
+  github_mcp:
+    name: github_mcp
+    function:
+      type: mcp
+      name: github_mcp
+      transport: streamable_http
+      url: https://workspace.databricks.com/api/2.0/mcp/external/github_u2m_connection  # MCP endpoint URL
+      connection: *github_connection  # UC Connection provides OAuth authentication
+```
+
 **Development:**
-Ensure the MCP service is running and accessible at the specified URL or command.
+- **For direct URL connections**: Ensure the MCP service is running and accessible at the specified URL or command. Provide OAuth credentials (client_id, client_secret) or PAT for authentication.
+- **For UC Connection**: URL is required to specify the endpoint. The connection provides OAuth authentication via the workspace client. Ensure the connection is configured in Unity Catalog with appropriate MCP scopes (`mcp.genie`, `mcp.functions`, `mcp.vectorsearch`, `mcp.external`).
+- The framework will handle the MCP protocol communication automatically, including session management and authentication.

 ### Configuring New Agents
{dao_ai-0.0.21.dist-info → dao_ai-0.0.23.dist-info}/RECORD
CHANGED
@@ -3,8 +3,8 @@ dao_ai/agent_as_code.py,sha256=kPSeDz2-1jRaed1TMs4LA3VECoyqe9_Ed2beRLB9gXQ,472
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
 dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
 dao_ai/cli.py,sha256=Aez2TQW3Q8Ho1IaIkRggt0NevDxAAVPjXkePC5GPJF0,20429
-dao_ai/config.py,sha256=
-dao_ai/graph.py,sha256=
+dao_ai/config.py,sha256=XHU6xkRAoTeiZYH5ns_fLwcR6EaxRAGeMwRSoW3n0S8,55431
+dao_ai/graph.py,sha256=APYc2y3cig4P52X4sOHSFSZNK8j5EtEPJLFwWeJ3KQQ,7956
 dao_ai/guardrails.py,sha256=4TKArDONRy8RwHzOT1plZ1rhy3x9GF_aeGpPCRl6wYA,4016
 dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
 dao_ai/models.py,sha256=8r8GIG3EGxtVyWsRNI56lVaBjiNrPkzh4HdwMZRq8iw,31689
@@ -12,29 +12,30 @@ dao_ai/nodes.py,sha256=SSuFNTXOdFaKg_aX-yUkQO7fM9wvNGu14lPXKDapU1U,8461
 dao_ai/prompts.py,sha256=vpmIbWs_szXUgNNDs5Gh2LcxKZti5pHDKSfoClUcgX0,1289
 dao_ai/state.py,sha256=_lF9krAYYjvFDMUwZzVKOn0ZnXKcOrbjWKdre0C5B54,1137
 dao_ai/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dao_ai/utils.py,sha256=
+dao_ai/utils.py,sha256=yXgqHrYdO5qDxgxUs2G5XJeLFgwg8D0BIJvbFkqSbhs,4519
 dao_ai/vector_search.py,sha256=jlaFS_iizJ55wblgzZmswMM3UOL-qOp2BGJc0JqXYSg,2839
 dao_ai/hooks/__init__.py,sha256=LlHGIuiZt6vGW8K5AQo1XJEkBP5vDVtMhq0IdjcLrD4,417
 dao_ai/hooks/core.py,sha256=ZShHctUSoauhBgdf1cecy9-D7J6-sGn-pKjuRMumW5U,6663
 dao_ai/memory/__init__.py,sha256=1kHx_p9abKYFQ6EYD05nuc1GS5HXVEpufmjBGw_7Uho,260
 dao_ai/memory/base.py,sha256=99nfr2UZJ4jmfTL_KrqUlRSCoRxzkZyWyx5WqeUoMdQ,338
-dao_ai/memory/core.py,sha256=
-dao_ai/memory/postgres.py,sha256=
+dao_ai/memory/core.py,sha256=DnEjQO3S7hXr3CDDd7C2eE7fQUmcCS_8q9BXEgjPH3U,4271
+dao_ai/memory/postgres.py,sha256=vvI3osjx1EoU5GBA6SCUstTBKillcmLl12hVgDMjfJY,15346
 dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
-dao_ai/providers/databricks.py,sha256=
-dao_ai/tools/__init__.py,sha256=
+dao_ai/providers/databricks.py,sha256=CFZ2RojcTjiJ1aGwNI3_0qCGf339w2o5h9CRDKNesLs,39825
+dao_ai/tools/__init__.py,sha256=G5-5Yi6zpQOH53b5IzLdtsC6g0Ep6leI5GxgxOmgw7Q,1203
 dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
 dao_ai/tools/core.py,sha256=Kei33S8vrmvPOAyrFNekaWmV2jqZ-IPS1QDSvU7RZF0,1984
-dao_ai/tools/genie.py,sha256=
+dao_ai/tools/genie.py,sha256=8HSOCzSg6PlBzBYXMmNfUnl-LO03p3Ki3fxLPm_dhPg,15051
 dao_ai/tools/human_in_the_loop.py,sha256=yk35MO9eNETnYFH-sqlgR-G24TrEgXpJlnZUustsLkI,3681
-dao_ai/tools/mcp.py,sha256=
+dao_ai/tools/mcp.py,sha256=RAAG97boEDJKlX7X_XUz-l-nH5DdqtHUG_I2zw1lWNk,6844
 dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
+dao_ai/tools/slack.py,sha256=SCvyVcD9Pv_XXPXePE_fSU1Pd8VLTEkKDLvoGTZWy2Y,4775
 dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
 dao_ai/tools/unity_catalog.py,sha256=uX_h52BuBAr4c9UeqSMI7DNz3BPRLeai5tBVW4sJqRI,13113
 dao_ai/tools/vector_search.py,sha256=EDYQs51zIPaAP0ma1D81wJT77GQ-v-cjb2XrFVWfWdg,2621
-dao_ai-0.0.
-dao_ai-0.0.
-dao_ai-0.0.
-dao_ai-0.0.
-dao_ai-0.0.
+dao_ai-0.0.23.dist-info/METADATA,sha256=6GfCnhhQN9t4x1LX8mUHsOTfr4mgdGR1Xx070pjIm_g,42638
+dao_ai-0.0.23.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dao_ai-0.0.23.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.0.23.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.0.23.dist-info/RECORD,,
{dao_ai-0.0.21.dist-info → dao_ai-0.0.23.dist-info}/WHEEL
File without changes

{dao_ai-0.0.21.dist-info → dao_ai-0.0.23.dist-info}/entry_points.txt
File without changes

{dao_ai-0.0.21.dist-info → dao_ai-0.0.23.dist-info}/licenses/LICENSE
File without changes