dao-ai 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/config.py +21 -6
- dao_ai/memory/postgres.py +31 -3
- dao_ai/models.py +137 -2
- dao_ai/providers/databricks.py +2 -0
- dao_ai/state.py +3 -0
- dao_ai/tools/genie.py +346 -34
- {dao_ai-0.0.20.dist-info → dao_ai-0.0.21.dist-info}/METADATA +1 -1
- {dao_ai-0.0.20.dist-info → dao_ai-0.0.21.dist-info}/RECORD +11 -11
- {dao_ai-0.0.20.dist-info → dao_ai-0.0.21.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.20.dist-info → dao_ai-0.0.21.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.20.dist-info → dao_ai-0.0.21.dist-info}/licenses/LICENSE +0 -0
dao_ai/config.py
CHANGED
@@ -427,7 +427,8 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksGenieSpace(
-                genie_space_id=self.space_id,
+                genie_space_id=value_of(self.space_id),
+                on_behalf_of_user=self.on_behalf_of_user,
             )
         ]
 
@@ -437,7 +438,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
         return self
 
 
-class VolumeModel(BaseModel, HasFullName):
+class VolumeModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -455,6 +456,13 @@ class VolumeModel(BaseModel, HasFullName):
         provider: ServiceProvider = DatabricksProvider(w=w)
         provider.create_volume(self)
 
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["files.files", "catalog.volumes"]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return []
+
 
 class VolumePathModel(BaseModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -683,7 +691,8 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksSQLWarehouse(
-                warehouse_id=self.warehouse_id,
+                warehouse_id=value_of(self.warehouse_id),
+                on_behalf_of_user=self.on_behalf_of_user,
             )
         ]
 
@@ -867,7 +876,7 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):
 
 class FactoryFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    args: Optional[dict[str,
+    args: Optional[dict[str, Any]] = Field(default_factory=dict)
     type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY
 
     @property
@@ -879,6 +888,12 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
 
         return [create_factory_tool(self, **kwargs)]
 
+    @model_validator(mode="after")
+    def update_args(self):
+        for key, value in self.args.items():
+            self.args[key] = value_of(value)
+        return self
+
 
 class TransportType(str, Enum):
     STREAMABLE_HTTP = "streamable_http"
@@ -1262,12 +1277,12 @@ class AppModel(BaseModel):
         if len(self.agents) > 1:
             default_agent: AgentModel = self.agents[0]
             self.orchestration = OrchestrationModel(
-
+                supervisor=SupervisorModel(model=default_agent.model)
             )
         elif len(self.agents) == 1:
             default_agent: AgentModel = self.agents[0]
             self.orchestration = OrchestrationModel(
-
+                swarm=SwarmModel(
                     model=default_agent.model, default_agent=default_agent
                 )
            )
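
The recurring change in this file is that resource IDs (`space_id`, `warehouse_id`) and factory args are now routed through `value_of(...)`, so each may be either a plain literal or a variable reference resolved at validation time, and the Databricks resource entries now carry `on_behalf_of_user`. A minimal sketch of the resolution contract this relies on (illustrative only; the real `value_of` and variable models live in `dao_ai/config.py` and may differ in detail):

```python
# Illustrative stand-ins, not dao_ai's actual classes.
from dataclasses import dataclass
from typing import Any


@dataclass
class VariableModel:
    value: str


def value_of(value: Any) -> Any:
    """Collapse a variable reference to its concrete value; pass literals through."""
    return value.value if isinstance(value, VariableModel) else value


assert value_of("01ef-genie-space") == "01ef-genie-space"                 # literal
assert value_of(VariableModel("01ef-genie-space")) == "01ef-genie-space"  # reference
```
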
dao_ai/memory/postgres.py
CHANGED
@@ -74,8 +74,17 @@ class AsyncPostgresPoolManager:
         async with cls._lock:
             for connection_key, pool in cls._pools.items():
                 try:
-
+                    # Use a short timeout to avoid blocking on pool closure
+                    await asyncio.wait_for(pool.close(), timeout=2.0)
                     logger.debug(f"Closed PostgreSQL pool: {connection_key}")
+                except asyncio.TimeoutError:
+                    logger.warning(
+                        f"Timeout closing pool {connection_key}, forcing closure"
+                    )
+                except asyncio.CancelledError:
+                    logger.warning(
+                        f"Pool closure cancelled for {connection_key} (shutdown in progress)"
+                    )
                 except Exception as e:
                     logger.error(f"Error closing pool {connection_key}: {e}")
             cls._pools.clear()
@@ -369,8 +378,27 @@ def _shutdown_pools():
 
 def _shutdown_async_pools():
     try:
-
-
+        # Try to get the current event loop first
+        try:
+            loop = asyncio.get_running_loop()
+            # If we're already in an event loop, create a task
+            loop.create_task(AsyncPostgresPoolManager.close_all_pools())
+            logger.debug("Scheduled async pool closure in running event loop")
+        except RuntimeError:
+            # No running loop, try to get or create one
+            try:
+                loop = asyncio.get_event_loop()
+                if loop.is_closed():
+                    # Loop is closed, create a new one
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                loop.run_until_complete(AsyncPostgresPoolManager.close_all_pools())
+                logger.debug("Successfully closed all asynchronous PostgreSQL pools")
+            except Exception as inner_e:
+                # If all else fails, just log the error
+                logger.warning(
+                    f"Could not close async pools cleanly during shutdown: {inner_e}"
+                )
     except Exception as e:
         logger.error(
             f"Error closing asynchronous PostgreSQL pools during shutdown: {e}"
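
The shutdown hook now distinguishes between being called from inside a running event loop (where it can only schedule a task) and from a plain synchronous context (where it must drive a loop itself). A self-contained sketch of that same pattern, with a placeholder coroutine standing in for `AsyncPostgresPoolManager.close_all_pools()`:

```python
import asyncio


async def close_all_pools() -> None:
    """Placeholder for AsyncPostgresPoolManager.close_all_pools()."""
    await asyncio.sleep(0)


def shutdown() -> None:
    try:
        # Inside a running loop we cannot block, so schedule the closure as a task.
        loop = asyncio.get_running_loop()
        loop.create_task(close_all_pools())
    except RuntimeError:
        # No running loop: create one and drive the closure to completion.
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(close_all_pools())
        finally:
            loop.close()


if __name__ == "__main__":
    shutdown()
```
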
dao_ai/models.py
CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Generator, Optional, Sequence, Union
 
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langgraph.graph.state import CompiledStateGraph
+from langgraph.types import StateSnapshot
 from loguru import logger
 from mlflow import MlflowClient
 from mlflow.pyfunc import ChatAgent, ChatModel, ResponsesAgent
@@ -59,6 +60,113 @@ def get_latest_model_version(model_name: str) -> int:
     return latest_version
 
 
+async def get_state_snapshot_async(
+    graph: CompiledStateGraph, thread_id: str
+) -> Optional[StateSnapshot]:
+    """
+    Retrieve the state snapshot from the graph for a given thread_id asynchronously.
+
+    This utility function accesses the graph's checkpointer to retrieve the current
+    state snapshot, which contains the full state values and metadata.
+
+    Args:
+        graph: The compiled LangGraph state machine
+        thread_id: The thread/conversation ID to retrieve state for
+
+    Returns:
+        StateSnapshot if found, None otherwise
+    """
+    logger.debug(f"Retrieving state snapshot for thread_id: {thread_id}")
+    try:
+        # Check if graph has a checkpointer
+        if graph.checkpointer is None:
+            logger.debug("No checkpointer available in graph")
+            return None
+
+        # Get the current state from the checkpointer (use async version)
+        config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
+        state_snapshot: Optional[StateSnapshot] = await graph.aget_state(config)
+
+        if state_snapshot is None:
+            logger.debug(f"No state found for thread_id: {thread_id}")
+            return None
+
+        return state_snapshot
+
+    except Exception as e:
+        logger.warning(f"Error retrieving state snapshot for thread {thread_id}: {e}")
+        return None
+
+
+def get_state_snapshot(
+    graph: CompiledStateGraph, thread_id: str
+) -> Optional[StateSnapshot]:
+    """
+    Retrieve the state snapshot from the graph for a given thread_id.
+
+    This is a synchronous wrapper around get_state_snapshot_async.
+    Use this for backward compatibility in synchronous contexts.
+
+    Args:
+        graph: The compiled LangGraph state machine
+        thread_id: The thread/conversation ID to retrieve state for
+
+    Returns:
+        StateSnapshot if found, None otherwise
+    """
+    import asyncio
+
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+    try:
+        return loop.run_until_complete(get_state_snapshot_async(graph, thread_id))
+    except Exception as e:
+        logger.warning(f"Error in synchronous state snapshot retrieval: {e}")
+        return None
+
+
+def get_genie_conversation_ids_from_state(
+    state_snapshot: Optional[StateSnapshot],
+) -> dict[str, str]:
+    """
+    Extract genie_conversation_ids from a state snapshot.
+
+    This function extracts the genie_conversation_ids dictionary from the state
+    snapshot values if present.
+
+    Args:
+        state_snapshot: The state snapshot to extract conversation IDs from
+
+    Returns:
+        A dictionary mapping genie space_id to conversation_id, or empty dict if not found
+    """
+    if state_snapshot is None:
+        return {}
+
+    try:
+        # Extract state values - these contain the actual state data
+        state_values: dict[str, Any] = state_snapshot.values
+
+        # Extract genie_conversation_ids from state values
+        genie_conversation_ids: dict[str, str] = state_values.get(
+            "genie_conversation_ids", {}
+        )
+
+        if genie_conversation_ids:
+            logger.debug(f"Retrieved genie_conversation_ids: {genie_conversation_ids}")
+            return genie_conversation_ids
+
+        return {}
+
+    except Exception as e:
+        logger.warning(f"Error extracting genie_conversation_ids from state: {e}")
+        return {}
+
+
 class LanggraphChatModel(ChatModel):
     """
     ChatModel that delegates requests to a LangGraph CompiledStateGraph.
@@ -257,7 +365,19 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
             )
 
-
+        # Retrieve genie_conversation_ids from state if available
+        custom_outputs: dict[str, Any] = custom_inputs.copy()
+        thread_id: Optional[str] = context.thread_id
+        if thread_id:
+            state_snapshot: Optional[StateSnapshot] = loop.run_until_complete(
+                get_state_snapshot_async(self.graph, thread_id)
+            )
+            genie_conversation_ids: dict[str, str] = (
+                get_genie_conversation_ids_from_state(state_snapshot)
+            )
+            if genie_conversation_ids:
+                custom_outputs["genie_conversation_ids"] = genie_conversation_ids
+
         return ResponsesAgentResponse(
             output=[output_item], custom_outputs=custom_outputs
         )
@@ -318,7 +438,22 @@ class LanggraphResponsesAgent(ResponsesAgent):
                     **self.create_text_delta(delta=content, item_id=item_id)
                 )
 
-
+        # Retrieve genie_conversation_ids from state if available
+        custom_outputs: dict[str, Any] = custom_inputs.copy()
+        thread_id: Optional[str] = context.thread_id
+
+        if thread_id:
+            state_snapshot: Optional[
+                StateSnapshot
+            ] = await get_state_snapshot_async(self.graph, thread_id)
+            genie_conversation_ids: dict[str, str] = (
+                get_genie_conversation_ids_from_state(state_snapshot)
+            )
+            if genie_conversation_ids:
+                custom_outputs["genie_conversation_ids"] = (
+                    genie_conversation_ids
+                )
+
         # Yield final output item
         yield ResponsesAgentStreamEvent(
             type="response.output_item.done",
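
Taken together, the new helpers let callers recover the Genie conversation map from any checkpointed thread. A hedged usage sketch, with a throwaway one-node graph and in-memory checkpointer standing in for a real agent graph (the graph construction here is illustrative, not a dao_ai API):

```python
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph

from dao_ai.models import get_genie_conversation_ids_from_state, get_state_snapshot

# Minimal graph with a checkpointer so state snapshots exist per thread.
builder = StateGraph(MessagesState)
builder.add_node("noop", lambda state: {})
builder.add_edge(START, "noop")
builder.add_edge("noop", END)
graph = builder.compile(checkpointer=MemorySaver())

graph.invoke({"messages": []}, config={"configurable": {"thread_id": "thread-123"}})
snapshot = get_state_snapshot(graph, thread_id="thread-123")
conversation_ids = get_genie_conversation_ids_from_state(snapshot)
# Maps genie space_id -> conversation_id; {} here since no Genie tool has run.
```
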
dao_ai/providers/databricks.py
CHANGED
@@ -226,6 +226,7 @@ class DatabricksProvider(ServiceProvider):
             config.resources.connections.values()
         )
         databases: Sequence[DatabaseModel] = list(config.resources.databases.values())
+        volumes: Sequence[VolumeModel] = list(config.resources.volumes.values())
 
         resources: Sequence[IsDatabricksResource] = (
             llms
@@ -236,6 +237,7 @@ class DatabricksProvider(ServiceProvider):
             + tables
             + connections
             + databases
+            + volumes
         )
 
         # Flatten all resources from all models into a single list
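
Volumes now take part in resource collection alongside the other model types; since `VolumeModel.as_resources()` returns an empty list, volumes contribute API scopes rather than `DatabricksResource` entries. A minimal sketch of the flatten step named in the trailing comment, using stand-in models (not dao_ai's classes):

```python
from typing import Sequence


class WarehouseStub:
    def as_resources(self) -> Sequence[str]:
        return ["sql-warehouse-resource"]  # stands in for DatabricksSQLWarehouse(...)


class VolumeStub:
    def as_resources(self) -> Sequence[str]:
        return []  # volumes contribute api_scopes, not resource entries


models = [WarehouseStub(), VolumeStub()]
flattened = [resource for model in models for resource in model.as_resources()]
print(flattened)  # ['sql-warehouse-resource']
```
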
dao_ai/state.py
CHANGED
@@ -31,6 +31,9 @@ class SharedState(MessagesState):
     is_valid: bool  # message validation node
     message_error: str
 
+    # A mapping of genie space_id to conversation_id
+    genie_conversation_ids: dict[str, str]  # Genie
+
 
 class Context(BaseModel):
     user_id: str | None = None
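
Because `SharedState` extends `MessagesState` (a TypedDict), any node can read or update the new channel through normal state mechanics. A hedged sketch of a node reading it (the node itself is illustrative, not part of dao_ai):

```python
from typing import Any

from dao_ai.state import SharedState


def inspect_genie_conversations(state: SharedState) -> dict[str, Any]:
    # TypedDict-style access; the key may be absent before the first Genie call.
    conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
    for space_id, conversation_id in conversation_ids.items():
        print(f"space {space_id} -> conversation {conversation_id}")
    return {}  # no state update
```
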
dao_ai/tools/genie.py
CHANGED
@@ -1,20 +1,284 @@
+import bisect
+import json
+import logging
 import os
+import time
+from dataclasses import asdict, dataclass
+from datetime import datetime
 from textwrap import dedent
-from typing import Any, Callable, Optional
+from typing import Annotated, Any, Callable, Optional, Union
 
-
-
-from
+import mlflow
+import pandas as pd
+from databricks.sdk import WorkspaceClient
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import InjectedToolCallId, tool
+from langgraph.prebuilt import InjectedState
+from langgraph.types import Command
+from loguru import logger
+from pydantic import BaseModel, Field
 
-from dao_ai.config import
-
-
+from dao_ai.config import AnyVariable, CompositeVariableModel, GenieRoomModel, value_of
+
+MAX_TOKENS_OF_DATA: int = 20000
+MAX_ITERATIONS: int = 50
+DEFAULT_POLLING_INTERVAL_SECS: int = 2
+
+
+def _count_tokens(text):
+    import tiktoken
+
+    encoding = tiktoken.encoding_for_model("gpt-4o")
+    return len(encoding.encode(text))
+
+
+@dataclass
+class GenieResponse:
+    conversation_id: str
+    result: Union[str, pd.DataFrame]
+    query: Optional[str] = ""
+    description: Optional[str] = ""
+
+    def to_json(self):
+        return json.dumps(asdict(self))
+
+
+class GenieToolInput(BaseModel):
+    """Input schema for the Genie tool."""
+
+    question: str = Field(
+        description="The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question."
+    )
+
+
+def _truncate_result(dataframe: pd.DataFrame) -> str:
+    query_result = dataframe.to_markdown()
+    tokens_used = _count_tokens(query_result)
+
+    # If the full result fits, return it
+    if tokens_used <= MAX_TOKENS_OF_DATA:
+        return query_result.strip()
+
+    def is_too_big(n):
+        return _count_tokens(dataframe.iloc[:n].to_markdown()) > MAX_TOKENS_OF_DATA
+
+    # Use bisect_left to find the cutoff point of rows within the max token data limit in a O(log n) complexity
+    # Passing True, as this is the target value we are looking for when _is_too_big returns
+    cutoff = bisect.bisect_left(range(len(dataframe) + 1), True, key=is_too_big)
+
+    # Slice to the found limit
+    truncated_df = dataframe.iloc[:cutoff]
+
+    # Edge case: Cannot return any rows because of tokens so return an empty string
+    if len(truncated_df) == 0:
+        return ""
+
+    truncated_result = truncated_df.to_markdown()
+
+    # Double-check edge case if we overshot by one
+    if _count_tokens(truncated_result) > MAX_TOKENS_OF_DATA:
+        truncated_result = truncated_df.iloc[:-1].to_markdown()
+    return truncated_result
+
+
+@mlflow.trace(span_type="PARSER")
+def _parse_query_result(resp, truncate_results) -> Union[str, pd.DataFrame]:
+    output = resp["result"]
+    if not output:
+        return "EMPTY"
+
+    columns = resp["manifest"]["schema"]["columns"]
+    header = [str(col["name"]) for col in columns]
+    rows = []
+
+    for item in output["data_array"]:
+        row = []
+        for column, value in zip(columns, item):
+            type_name = column["type_name"]
+            if value is None:
+                row.append(None)
+                continue
+
+            if type_name in ["INT", "LONG", "SHORT", "BYTE"]:
+                row.append(int(value))
+            elif type_name in ["FLOAT", "DOUBLE", "DECIMAL"]:
+                row.append(float(value))
+            elif type_name == "BOOLEAN":
+                row.append(value.lower() == "true")
+            elif type_name == "DATE" or type_name == "TIMESTAMP":
+                row.append(datetime.strptime(value[:10], "%Y-%m-%d").date())
+            elif type_name == "BINARY":
+                row.append(bytes(value, "utf-8"))
+            else:
+                row.append(value)
+
+        rows.append(row)
+
+    dataframe = pd.DataFrame(rows, columns=header)
+
+    if truncate_results:
+        query_result = _truncate_result(dataframe)
+    else:
+        query_result = dataframe.to_markdown()
+
+    return query_result.strip()
+
+
+class Genie:
+    def __init__(
+        self,
+        space_id,
+        client: WorkspaceClient | None = None,
+        truncate_results: bool = False,
+        polling_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
+    ):
+        self.space_id = space_id
+        workspace_client = client or WorkspaceClient()
+        self.genie = workspace_client.genie
+        self.description = self.genie.get_space(space_id).description
+        self.headers = {
+            "Accept": "application/json",
+            "Content-Type": "application/json",
+        }
+        self.truncate_results = truncate_results
+        if polling_interval < 1 or polling_interval > 30:
+            raise ValueError("poll_interval must be between 1 and 30 seconds")
+        self.poll_interval = polling_interval
+
+    @mlflow.trace()
+    def start_conversation(self, content):
+        resp = self.genie._api.do(
+            "POST",
+            f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
+            body={"content": content},
+            headers=self.headers,
+        )
+        return resp
+
+    @mlflow.trace()
+    def create_message(self, conversation_id, content):
+        resp = self.genie._api.do(
+            "POST",
+            f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
+            body={"content": content},
+            headers=self.headers,
+        )
+        return resp
+
+    @mlflow.trace()
+    def poll_for_result(self, conversation_id, message_id):
+        @mlflow.trace()
+        def poll_query_results(attachment_id, query_str, description):
+            iteration_count = 0
+            while iteration_count < MAX_ITERATIONS:
+                iteration_count += 1
+                resp = self.genie._api.do(
+                    "GET",
+                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/query-result",
+                    headers=self.headers,
+                )["statement_response"]
+                state = resp["status"]["state"]
+                if state == "SUCCEEDED":
+                    result = _parse_query_result(resp, self.truncate_results)
+                    return GenieResponse(
+                        conversation_id, result, query_str, description
+                    )
+                elif state in ["RUNNING", "PENDING"]:
+                    logging.debug("Waiting for query result...")
+                    time.sleep(self.poll_interval)
+                else:
+                    return GenieResponse(
+                        conversation_id,
+                        f"No query result: {resp['state']}",
+                        query_str,
+                        description,
+                    )
+            return GenieResponse(
+                conversation_id,
+                f"Genie query for result timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
+                query_str,
+                description,
+            )
+
+        @mlflow.trace()
+        def poll_result():
+            iteration_count = 0
+            while iteration_count < MAX_ITERATIONS:
+                iteration_count += 1
+                resp = self.genie._api.do(
+                    "GET",
+                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
+                    headers=self.headers,
+                )
+                if resp["status"] == "COMPLETED":
+                    # Check if attachments key exists in response
+                    attachments = resp.get("attachments", [])
+                    if not attachments:
+                        # Handle case where response has no attachments
+                        return GenieResponse(
+                            conversation_id,
+                            result=f"Genie query completed but no attachments found. Response: {resp}",
+                        )
+
+                    attachment = next((r for r in attachments if "query" in r), None)
+                    if attachment:
+                        query_obj = attachment["query"]
+                        description = query_obj.get("description", "")
+                        query_str = query_obj.get("query", "")
+                        attachment_id = attachment["attachment_id"]
+                        return poll_query_results(attachment_id, query_str, description)
+                    if resp["status"] == "COMPLETED":
+                        text_content = next(
+                            (r for r in attachments if "text" in r), None
+                        )
+                        if text_content:
+                            return GenieResponse(
+                                conversation_id, result=text_content["text"]["content"]
+                            )
+                        return GenieResponse(
+                            conversation_id,
+                            result="Genie query completed but no text content found in attachments.",
+                        )
+                elif resp["status"] in {"CANCELLED", "QUERY_RESULT_EXPIRED"}:
+                    return GenieResponse(
+                        conversation_id, result=f"Genie query {resp['status'].lower()}."
+                    )
+                elif resp["status"] == "FAILED":
+                    return GenieResponse(
+                        conversation_id,
+                        result=f"Genie query failed with error: {resp.get('error', 'Unknown error')}",
+                    )
+                # includes EXECUTING_QUERY, Genie can retry after this status
+                else:
+                    logging.debug(f"Waiting...: {resp['status']}")
+                    time.sleep(self.poll_interval)
+            return GenieResponse(
+                conversation_id,
+                f"Genie query timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
+            )
+
+        return poll_result()
+
+    @mlflow.trace()
+    def ask_question(self, question: str, conversation_id: str | None = None):
+        logger.debug(
+            f"ask_question called with question: {question}, conversation_id: {conversation_id}"
+        )
+        if conversation_id:
+            resp = self.create_message(conversation_id, question)
+        else:
+            resp = self.start_conversation(question)
+        logger.debug(f"ask_question response: {resp}")
+        return self.poll_for_result(resp["conversation_id"], resp["message_id"])
 
 
 def create_genie_tool(
     genie_room: GenieRoomModel | dict[str, Any],
     name: Optional[str] = None,
     description: Optional[str] = None,
+    persist_conversation: bool = True,
+    truncate_results: bool = False,
+    poll_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
 ) -> Callable[[str], GenieResponse]:
     """
     Create a tool for interacting with Databricks Genie for natural language queries to databases.
@@ -24,22 +288,33 @@ def create_genie_tool(
     answering questions about inventory, sales, and other structured retail data.
 
     Args:
-
-
+        genie_room: GenieRoomModel or dict containing Genie configuration
+        name: Optional custom name for the tool. If None, uses default "genie_tool"
+        description: Optional custom description for the tool. If None, uses default description
 
     Returns:
-        A
+        A LangGraph tool that processes natural language queries through Genie
     """
 
     if isinstance(genie_room, dict):
         genie_room = GenieRoomModel(**genie_room)
 
-    space_id:
-
-
-
-
+    space_id: AnyVariable = genie_room.space_id or os.environ.get(
+        "DATABRICKS_GENIE_SPACE_ID"
+    )
+    space_id: AnyVariable = genie_room.space_id or os.environ.get(
+        "DATABRICKS_GENIE_SPACE_ID"
     )
+    if isinstance(space_id, dict):
+        space_id = CompositeVariableModel(**space_id)
+    space_id = value_of(space_id)
+
+    # genie: Genie = Genie(
+    #     space_id=space_id,
+    #     client=genie_room.workspace_client,
+    #     truncate_results=truncate_results,
+    #     polling_interval=poll_interval,
+    # )
 
     default_description: str = dedent("""
         This tool lets you have a conversation and chat with tabular data about <topic>. You should ask
@@ -49,29 +324,66 @@ def create_genie_tool(
     Prefer to call this tool multiple times rather than asking a complex question.
     """)
 
-
-    description
+    tool_description: str = (
+        description if description is not None else default_description
+    )
+    tool_name: str = name if name is not None else "genie_tool"
 
-
-    Args:
-        question (str): The question to ask to ask Genie
+    function_docs = """
 
-
-
-    """)
+    Args:
+        question (str): The question to ask to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question.
 
-
+    Returns:
+        GenieResponse: A response object containing the conversation ID and result from Genie."""
+    tool_description = tool_description + function_docs
 
-
-
-
-
-
+    @tool(
+        name_or_callable=tool_name,
+        description=tool_description,
+    )
+    def genie_tool(
+        question: Annotated[str, "The question to ask Genie about your data"],
+        state: Annotated[dict, InjectedState],
+        tool_call_id: Annotated[str, InjectedToolCallId],
+    ) -> Command:
+        genie: Genie = Genie(
+            space_id=space_id,
+            client=genie_room.workspace_client,
+            truncate_results=truncate_results,
+            polling_interval=poll_interval,
+        )
 
-
+        """Process a natural language question through Databricks Genie."""
+        # Get existing conversation mapping and retrieve conversation ID for this space
+        conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
+        existing_conversation_id: str | None = conversation_ids.get(space_id)
+        logger.debug(
+            f"Existing conversation ID for space {space_id}: {existing_conversation_id}"
+        )
 
-
-
-
+        response: GenieResponse = genie.ask_question(
+            question, conversation_id=existing_conversation_id
+        )
+
+        current_conversation_id: str = response.conversation_id
+        logger.debug(
+            f"Current conversation ID for space {space_id}: {current_conversation_id}"
+        )
+
+        # Update the conversation mapping with the new conversation ID for this space
+
+        update: dict[str, Any] = {
+            "messages": [ToolMessage(response.to_json(), tool_call_id=tool_call_id)],
+        }
+
+        if persist_conversation:
+            updated_conversation_ids: dict[str, str] = conversation_ids.copy()
+            updated_conversation_ids[space_id] = current_conversation_id
+            update["genie_conversation_ids"] = updated_conversation_ids
+
+        logger.debug(f"State update: {update}")
+
+        return Command(update=update)
 
-    return
+    return genie_tool
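
The rewritten module embeds a vendored `Genie` client that polls the Genie REST API under MLflow tracing, and `create_genie_tool` now returns a LangGraph tool that threads `conversation_id` through graph state via `Command` updates. A hedged wiring sketch (the space id is hypothetical):

```python
from dao_ai.tools.genie import create_genie_tool

# The returned tool must run inside a graph whose state schema includes
# `genie_conversation_ids` (e.g. dao_ai's SharedState), since its Command
# update writes back to that key.
genie_tool = create_genie_tool(
    genie_room={"space_id": "01ef0123456789abcdef0123456789ab"},  # hypothetical id
    persist_conversation=True,   # reuse the space's conversation_id across calls
    truncate_results=True,       # cap tabular output near MAX_TOKENS_OF_DATA tokens
    poll_interval=2,
)
```
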
{dao_ai-0.0.20.dist-info → dao_ai-0.0.21.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.0.20
+Version: 0.0.21
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai
{dao_ai-0.0.20.dist-info → dao_ai-0.0.21.dist-info}/RECORD
CHANGED

@@ -3,14 +3,14 @@ dao_ai/agent_as_code.py,sha256=kPSeDz2-1jRaed1TMs4LA3VECoyqe9_Ed2beRLB9gXQ,472
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
 dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
 dao_ai/cli.py,sha256=Aez2TQW3Q8Ho1IaIkRggt0NevDxAAVPjXkePC5GPJF0,20429
-dao_ai/config.py,sha256=
+dao_ai/config.py,sha256=GeaM00wNlYecwe3HhqeG88Hprt0SvGg4HtC7g_m-v98,52386
 dao_ai/graph.py,sha256=gmD9mxODfXuvn9xWeBfewm1FiuVAWMLEdnZz7DNmSH0,7859
 dao_ai/guardrails.py,sha256=4TKArDONRy8RwHzOT1plZ1rhy3x9GF_aeGpPCRl6wYA,4016
 dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
-dao_ai/models.py,sha256=
+dao_ai/models.py,sha256=8r8GIG3EGxtVyWsRNI56lVaBjiNrPkzh4HdwMZRq8iw,31689
 dao_ai/nodes.py,sha256=SSuFNTXOdFaKg_aX-yUkQO7fM9wvNGu14lPXKDapU1U,8461
 dao_ai/prompts.py,sha256=vpmIbWs_szXUgNNDs5Gh2LcxKZti5pHDKSfoClUcgX0,1289
-dao_ai/state.py,sha256=
+dao_ai/state.py,sha256=_lF9krAYYjvFDMUwZzVKOn0ZnXKcOrbjWKdre0C5B54,1137
 dao_ai/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/utils.py,sha256=dkZTXNN6q0xwkrvSWdNq8937W2xGuLCRWRb6hRQM6kA,4217
 dao_ai/vector_search.py,sha256=jlaFS_iizJ55wblgzZmswMM3UOL-qOp2BGJc0JqXYSg,2839
@@ -19,22 +19,22 @@ dao_ai/hooks/core.py,sha256=ZShHctUSoauhBgdf1cecy9-D7J6-sGn-pKjuRMumW5U,6663
 dao_ai/memory/__init__.py,sha256=1kHx_p9abKYFQ6EYD05nuc1GS5HXVEpufmjBGw_7Uho,260
 dao_ai/memory/base.py,sha256=99nfr2UZJ4jmfTL_KrqUlRSCoRxzkZyWyx5WqeUoMdQ,338
 dao_ai/memory/core.py,sha256=g7chjBgVgx3iKjR2hghl0QL1j3802uIM_e7mgszur9M,4151
-dao_ai/memory/postgres.py,sha256=
+dao_ai/memory/postgres.py,sha256=aWHRLhPm-9ywjlQe2B4XSdLbeaiuVV88p4PiQJFNEWo,13924
 dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
-dao_ai/providers/databricks.py,sha256=
+dao_ai/providers/databricks.py,sha256=PX5mBvZaIxSJIAHWVnPXsho1XvxcoR3Qs3I9UavFRsY,28306
 dao_ai/tools/__init__.py,sha256=ye6MHaJY7tUnJ8336YJiLxuZr55zDPNdOw6gm7j5jlc,1103
 dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
 dao_ai/tools/core.py,sha256=Kei33S8vrmvPOAyrFNekaWmV2jqZ-IPS1QDSvU7RZF0,1984
-dao_ai/tools/genie.py,sha256=
+dao_ai/tools/genie.py,sha256=1CbLViNQ3KnmDtHXuwqCPug7rEhCGvuHP1NgsY-AJZ0,15050
 dao_ai/tools/human_in_the_loop.py,sha256=yk35MO9eNETnYFH-sqlgR-G24TrEgXpJlnZUustsLkI,3681
 dao_ai/tools/mcp.py,sha256=auEt_dwv4J26fr5AgLmwmnAsI894-cyuvkvjItzAUxs,4419
 dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
 dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
 dao_ai/tools/unity_catalog.py,sha256=uX_h52BuBAr4c9UeqSMI7DNz3BPRLeai5tBVW4sJqRI,13113
 dao_ai/tools/vector_search.py,sha256=EDYQs51zIPaAP0ma1D81wJT77GQ-v-cjb2XrFVWfWdg,2621
-dao_ai-0.0.20.dist-info/METADATA,sha256=
-dao_ai-0.0.20.dist-info/WHEEL,sha256=
-dao_ai-0.0.20.dist-info/entry_points.txt,sha256=
-dao_ai-0.0.20.dist-info/licenses/LICENSE,sha256=
-dao_ai-0.0.20.dist-info/RECORD,,
+dao_ai-0.0.21.dist-info/METADATA,sha256=PG-eOltuUpaJf4lYEw-DoVy5BFT9LbMCfe8GanIV7zQ,41380
+dao_ai-0.0.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dao_ai-0.0.21.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.0.21.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.0.21.dist-info/RECORD,,

{dao_ai-0.0.20.dist-info → dao_ai-0.0.21.dist-info}/WHEEL
File without changes

{dao_ai-0.0.20.dist-info → dao_ai-0.0.21.dist-info}/entry_points.txt
File without changes

{dao_ai-0.0.20.dist-info → dao_ai-0.0.21.dist-info}/licenses/LICENSE
File without changes