dao-ai 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/config.py CHANGED
@@ -427,7 +427,8 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksGenieSpace(
-                genie_space_id=self.space_id, on_behalf_of_user=self.on_behalf_of_user
+                genie_space_id=value_of(self.space_id),
+                on_behalf_of_user=self.on_behalf_of_user,
             )
         ]
 
@@ -437,7 +438,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
         return self
 
 
-class VolumeModel(BaseModel, HasFullName):
+class VolumeModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -455,6 +456,13 @@ class VolumeModel(BaseModel, HasFullName):
         provider: ServiceProvider = DatabricksProvider(w=w)
         provider.create_volume(self)
 
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["files.files", "catalog.volumes"]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return []
+
 
 class VolumePathModel(BaseModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -683,7 +691,8 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksSQLWarehouse(
-                warehouse_id=self.warehouse_id, on_behalf_of_user=self.on_behalf_of_user
+                warehouse_id=value_of(self.warehouse_id),
+                on_behalf_of_user=self.on_behalf_of_user,
             )
         ]
 
@@ -867,7 +876,7 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):
 
 
 class FactoryFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
+    args: Optional[dict[str, Any]] = Field(default_factory=dict)
     type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY
 
     @property
@@ -879,6 +888,12 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
 
         return [create_factory_tool(self, **kwargs)]
 
+    @model_validator(mode="after")
+    def update_args(self):
+        for key, value in self.args.items():
+            self.args[key] = value_of(value)
+        return self
+
 
 class TransportType(str, Enum):
     STREAMABLE_HTTP = "streamable_http"
@@ -1262,12 +1277,12 @@ class AppModel(BaseModel):
         if len(self.agents) > 1:
             default_agent: AgentModel = self.agents[0]
             self.orchestration = OrchestrationModel(
-                swarm=SupervisorModel(model=default_agent.model)
+                supervisor=SupervisorModel(model=default_agent.model)
            )
         elif len(self.agents) == 1:
             default_agent: AgentModel = self.agents[0]
             self.orchestration = OrchestrationModel(
-                supervisor=SwarmModel(
+                swarm=SwarmModel(
                     model=default_agent.model, default_agent=default_agent
                 )
            )
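The common thread in these config.py changes is that identifier fields (`space_id`, `warehouse_id`, factory `args`) may now hold variable references rather than plain literals, and are normalized through `value_of()` before they reach a Databricks resource constructor. The resolver itself is not part of this diff; a minimal sketch of the assumed behavior:

    # Illustrative sketch only -- dao_ai's real value_of() is not shown in
    # this diff. Assumes AnyVariable means "plain literal, or a variable model
    # that knows how to resolve itself".
    from typing import Any

    def value_of(value: Any) -> Any:
        # Variable models (e.g. CompositeVariableModel) expose a resolved
        # value; plain literals pass through unchanged.
        resolved = getattr(value, "value", None)
        return resolved if resolved is not None else value

Under that assumption, `DatabricksGenieSpace` and `DatabricksSQLWarehouse` always receive concrete strings, which is why the constructors above now wrap their IDs in `value_of(...)`.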
dao_ai/memory/postgres.py CHANGED
@@ -74,8 +74,17 @@ class AsyncPostgresPoolManager:
         async with cls._lock:
             for connection_key, pool in cls._pools.items():
                 try:
-                    await pool.close()
+                    # Use a short timeout to avoid blocking on pool closure
+                    await asyncio.wait_for(pool.close(), timeout=2.0)
                     logger.debug(f"Closed PostgreSQL pool: {connection_key}")
+                except asyncio.TimeoutError:
+                    logger.warning(
+                        f"Timeout closing pool {connection_key}, forcing closure"
+                    )
+                except asyncio.CancelledError:
+                    logger.warning(
+                        f"Pool closure cancelled for {connection_key} (shutdown in progress)"
+                    )
                 except Exception as e:
                     logger.error(f"Error closing pool {connection_key}: {e}")
             cls._pools.clear()
@@ -369,8 +378,27 @@ def _shutdown_pools():
 
 def _shutdown_async_pools():
     try:
-        asyncio.run(AsyncPostgresPoolManager.close_all_pools())
-        logger.debug("Successfully closed all asynchronous PostgreSQL pools")
+        # Try to get the current event loop first
+        try:
+            loop = asyncio.get_running_loop()
+            # If we're already in an event loop, create a task
+            loop.create_task(AsyncPostgresPoolManager.close_all_pools())
+            logger.debug("Scheduled async pool closure in running event loop")
+        except RuntimeError:
+            # No running loop, try to get or create one
+            try:
+                loop = asyncio.get_event_loop()
+                if loop.is_closed():
+                    # Loop is closed, create a new one
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                loop.run_until_complete(AsyncPostgresPoolManager.close_all_pools())
+                logger.debug("Successfully closed all asynchronous PostgreSQL pools")
+            except Exception as inner_e:
+                # If all else fails, just log the error
+                logger.warning(
+                    f"Could not close async pools cleanly during shutdown: {inner_e}"
+                )
     except Exception as e:
         logger.error(
             f"Error closing asynchronous PostgreSQL pools during shutdown: {e}"
dao_ai/models.py CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Generator, Optional, Sequence, Union
 
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langgraph.graph.state import CompiledStateGraph
+from langgraph.types import StateSnapshot
 from loguru import logger
 from mlflow import MlflowClient
 from mlflow.pyfunc import ChatAgent, ChatModel, ResponsesAgent
@@ -59,6 +60,113 @@ def get_latest_model_version(model_name: str) -> int:
     return latest_version
 
 
+async def get_state_snapshot_async(
+    graph: CompiledStateGraph, thread_id: str
+) -> Optional[StateSnapshot]:
+    """
+    Retrieve the state snapshot from the graph for a given thread_id asynchronously.
+
+    This utility function accesses the graph's checkpointer to retrieve the current
+    state snapshot, which contains the full state values and metadata.
+
+    Args:
+        graph: The compiled LangGraph state machine
+        thread_id: The thread/conversation ID to retrieve state for
+
+    Returns:
+        StateSnapshot if found, None otherwise
+    """
+    logger.debug(f"Retrieving state snapshot for thread_id: {thread_id}")
+    try:
+        # Check if graph has a checkpointer
+        if graph.checkpointer is None:
+            logger.debug("No checkpointer available in graph")
+            return None
+
+        # Get the current state from the checkpointer (use async version)
+        config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
+        state_snapshot: Optional[StateSnapshot] = await graph.aget_state(config)
+
+        if state_snapshot is None:
+            logger.debug(f"No state found for thread_id: {thread_id}")
+            return None
+
+        return state_snapshot
+
+    except Exception as e:
+        logger.warning(f"Error retrieving state snapshot for thread {thread_id}: {e}")
+        return None
+
+
+def get_state_snapshot(
+    graph: CompiledStateGraph, thread_id: str
+) -> Optional[StateSnapshot]:
+    """
+    Retrieve the state snapshot from the graph for a given thread_id.
+
+    This is a synchronous wrapper around get_state_snapshot_async.
+    Use this for backward compatibility in synchronous contexts.
+
+    Args:
+        graph: The compiled LangGraph state machine
+        thread_id: The thread/conversation ID to retrieve state for
+
+    Returns:
+        StateSnapshot if found, None otherwise
+    """
+    import asyncio
+
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+    try:
+        return loop.run_until_complete(get_state_snapshot_async(graph, thread_id))
+    except Exception as e:
+        logger.warning(f"Error in synchronous state snapshot retrieval: {e}")
+        return None
+
+
+def get_genie_conversation_ids_from_state(
+    state_snapshot: Optional[StateSnapshot],
+) -> dict[str, str]:
+    """
+    Extract genie_conversation_ids from a state snapshot.
+
+    This function extracts the genie_conversation_ids dictionary from the state
+    snapshot values if present.
+
+    Args:
+        state_snapshot: The state snapshot to extract conversation IDs from
+
+    Returns:
+        A dictionary mapping genie space_id to conversation_id, or empty dict if not found
+    """
+    if state_snapshot is None:
+        return {}
+
+    try:
+        # Extract state values - these contain the actual state data
+        state_values: dict[str, Any] = state_snapshot.values
+
+        # Extract genie_conversation_ids from state values
+        genie_conversation_ids: dict[str, str] = state_values.get(
+            "genie_conversation_ids", {}
+        )
+
+        if genie_conversation_ids:
+            logger.debug(f"Retrieved genie_conversation_ids: {genie_conversation_ids}")
+            return genie_conversation_ids
+
+        return {}
+
+    except Exception as e:
+        logger.warning(f"Error extracting genie_conversation_ids from state: {e}")
+        return {}
+
+
 class LanggraphChatModel(ChatModel):
     """
     ChatModel that delegates requests to a LangGraph CompiledStateGraph.
@@ -257,7 +365,19 @@ class LanggraphResponsesAgent(ResponsesAgent):
             text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
         )
 
-        custom_outputs = custom_inputs
+        # Retrieve genie_conversation_ids from state if available
+        custom_outputs: dict[str, Any] = custom_inputs.copy()
+        thread_id: Optional[str] = context.thread_id
+        if thread_id:
+            state_snapshot: Optional[StateSnapshot] = loop.run_until_complete(
+                get_state_snapshot_async(self.graph, thread_id)
+            )
+            genie_conversation_ids: dict[str, str] = (
+                get_genie_conversation_ids_from_state(state_snapshot)
+            )
+            if genie_conversation_ids:
+                custom_outputs["genie_conversation_ids"] = genie_conversation_ids
+
         return ResponsesAgentResponse(
             output=[output_item], custom_outputs=custom_outputs
         )
@@ -318,7 +438,22 @@ class LanggraphResponsesAgent(ResponsesAgent):
                     **self.create_text_delta(delta=content, item_id=item_id)
                 )
 
-            custom_outputs = custom_inputs
+            # Retrieve genie_conversation_ids from state if available
+            custom_outputs: dict[str, Any] = custom_inputs.copy()
+            thread_id: Optional[str] = context.thread_id
+
+            if thread_id:
+                state_snapshot: Optional[
+                    StateSnapshot
+                ] = await get_state_snapshot_async(self.graph, thread_id)
+                genie_conversation_ids: dict[str, str] = (
+                    get_genie_conversation_ids_from_state(state_snapshot)
+                )
+                if genie_conversation_ids:
+                    custom_outputs["genie_conversation_ids"] = genie_conversation_ids
+
             # Yield final output item
             yield ResponsesAgentStreamEvent(
                 type="response.output_item.done",
dao_ai/providers/databricks.py CHANGED
@@ -226,6 +226,7 @@ class DatabricksProvider(ServiceProvider):
             config.resources.connections.values()
         )
         databases: Sequence[DatabaseModel] = list(config.resources.databases.values())
+        volumes: Sequence[VolumeModel] = list(config.resources.volumes.values())
 
         resources: Sequence[IsDatabricksResource] = (
             llms
@@ -236,6 +237,7 @@ class DatabricksProvider(ServiceProvider):
             + tables
             + connections
             + databases
+            + volumes
         )
 
         # Flatten all resources from all models into a single list
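`VolumeModel` now implements `IsDatabricksResource`, so it flows through the same aggregation as LLMs, tables, and databases: it contributes no `DatabricksResource` entries (`as_resources()` returns `[]`) but still surfaces its `api_scopes`. A sketch of the flatten step the trailing comment refers to (names are illustrative, not the exact code):

    # Sketch of flattening the aggregated models into resource entries.
    from itertools import chain

    def flatten_resources(resources) -> list:
        # VolumeModel yields nothing here, yet still declares the
        # "files.files" and "catalog.volumes" API scopes.
        return list(chain.from_iterable(m.as_resources() for m in resources))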
dao_ai/state.py CHANGED
@@ -31,6 +31,9 @@ class SharedState(MessagesState):
     is_valid: bool  # message validation node
     message_error: str
 
+    # A mapping of genie space_id to conversation_id
+    genie_conversation_ids: dict[str, str]  # Genie
+
 
 class Context(BaseModel):
     user_id: str | None = None
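Because `genie_conversation_ids` lives in `SharedState`, it is checkpointed alongside the message history and survives across turns on the same thread. Tool nodes merge new mappings into it by returning a `Command` update; a minimal sketch with placeholder IDs:

    # Minimal sketch: merging a conversation ID into SharedState via a
    # Command update (the space and conversation IDs are placeholders).
    from typing import Any

    from langgraph.types import Command

    def record_genie_conversation(state: dict[str, Any]) -> Command:
        ids: dict[str, str] = dict(state.get("genie_conversation_ids", {}))
        ids["<space_id>"] = "<conversation_id>"
        return Command(update={"genie_conversation_ids": ids})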
dao_ai/tools/genie.py CHANGED
@@ -1,20 +1,284 @@
+import bisect
+import json
+import logging
 import os
+import time
+from dataclasses import asdict, dataclass
+from datetime import datetime
 from textwrap import dedent
-from typing import Any, Callable, Optional
+from typing import Annotated, Any, Callable, Optional, Union
 
-from databricks_ai_bridge.genie import GenieResponse
-from databricks_langchain.genie import Genie
-from langchain_core.tools import StructuredTool
+import mlflow
+import pandas as pd
+from databricks.sdk import WorkspaceClient
+from langchain_core.messages import ToolMessage
+from langchain_core.tools import InjectedToolCallId, tool
+from langgraph.prebuilt import InjectedState
+from langgraph.types import Command
+from loguru import logger
+from pydantic import BaseModel, Field
 
-from dao_ai.config import (
-    GenieRoomModel,
-)
+from dao_ai.config import AnyVariable, CompositeVariableModel, GenieRoomModel, value_of
+
+MAX_TOKENS_OF_DATA: int = 20000
+MAX_ITERATIONS: int = 50
+DEFAULT_POLLING_INTERVAL_SECS: int = 2
+
+
+def _count_tokens(text):
+    import tiktoken
+
+    encoding = tiktoken.encoding_for_model("gpt-4o")
+    return len(encoding.encode(text))
+
+
+@dataclass
+class GenieResponse:
+    conversation_id: str
+    result: Union[str, pd.DataFrame]
+    query: Optional[str] = ""
+    description: Optional[str] = ""
+
+    def to_json(self):
+        return json.dumps(asdict(self))
+
+
+class GenieToolInput(BaseModel):
+    """Input schema for the Genie tool."""
+
+    question: str = Field(
+        description="The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question."
+    )
+
+
+def _truncate_result(dataframe: pd.DataFrame) -> str:
+    query_result = dataframe.to_markdown()
+    tokens_used = _count_tokens(query_result)
+
+    # If the full result fits, return it
+    if tokens_used <= MAX_TOKENS_OF_DATA:
+        return query_result.strip()
+
+    def is_too_big(n):
+        return _count_tokens(dataframe.iloc[:n].to_markdown()) > MAX_TOKENS_OF_DATA
+
+    # Use bisect_left to find the cutoff point of rows within the token limit
+    # in O(log n); True is the target value is_too_big converges on
+    cutoff = bisect.bisect_left(range(len(dataframe) + 1), True, key=is_too_big)
+
+    # Slice to the found limit
+    truncated_df = dataframe.iloc[:cutoff]
+
+    # Edge case: no rows fit within the token budget, so return an empty string
+    if len(truncated_df) == 0:
+        return ""
+
+    truncated_result = truncated_df.to_markdown()
+
+    # Double-check in case we overshot by one row
+    if _count_tokens(truncated_result) > MAX_TOKENS_OF_DATA:
+        truncated_result = truncated_df.iloc[:-1].to_markdown()
+    return truncated_result
+
+
+@mlflow.trace(span_type="PARSER")
+def _parse_query_result(resp, truncate_results) -> Union[str, pd.DataFrame]:
+    output = resp["result"]
+    if not output:
+        return "EMPTY"
+
+    columns = resp["manifest"]["schema"]["columns"]
+    header = [str(col["name"]) for col in columns]
+    rows = []
+
+    for item in output["data_array"]:
+        row = []
+        for column, value in zip(columns, item):
+            type_name = column["type_name"]
+            if value is None:
+                row.append(None)
+                continue
+
+            if type_name in ["INT", "LONG", "SHORT", "BYTE"]:
+                row.append(int(value))
+            elif type_name in ["FLOAT", "DOUBLE", "DECIMAL"]:
+                row.append(float(value))
+            elif type_name == "BOOLEAN":
+                row.append(value.lower() == "true")
+            elif type_name == "DATE" or type_name == "TIMESTAMP":
+                row.append(datetime.strptime(value[:10], "%Y-%m-%d").date())
+            elif type_name == "BINARY":
+                row.append(bytes(value, "utf-8"))
+            else:
+                row.append(value)
+
+        rows.append(row)
+
+    dataframe = pd.DataFrame(rows, columns=header)
+
+    if truncate_results:
+        query_result = _truncate_result(dataframe)
+    else:
+        query_result = dataframe.to_markdown()
+
+    return query_result.strip()
+
+
+class Genie:
+    def __init__(
+        self,
+        space_id,
+        client: WorkspaceClient | None = None,
+        truncate_results: bool = False,
+        polling_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
+    ):
+        self.space_id = space_id
+        workspace_client = client or WorkspaceClient()
+        self.genie = workspace_client.genie
+        self.description = self.genie.get_space(space_id).description
+        self.headers = {
+            "Accept": "application/json",
+            "Content-Type": "application/json",
+        }
+        self.truncate_results = truncate_results
+        if polling_interval < 1 or polling_interval > 30:
+            raise ValueError("polling_interval must be between 1 and 30 seconds")
+        self.poll_interval = polling_interval
+
+    @mlflow.trace()
+    def start_conversation(self, content):
+        resp = self.genie._api.do(
+            "POST",
+            f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
+            body={"content": content},
+            headers=self.headers,
+        )
+        return resp
+
+    @mlflow.trace()
+    def create_message(self, conversation_id, content):
+        resp = self.genie._api.do(
+            "POST",
+            f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
+            body={"content": content},
+            headers=self.headers,
+        )
+        return resp
+
+    @mlflow.trace()
+    def poll_for_result(self, conversation_id, message_id):
+        @mlflow.trace()
+        def poll_query_results(attachment_id, query_str, description):
+            iteration_count = 0
+            while iteration_count < MAX_ITERATIONS:
+                iteration_count += 1
+                resp = self.genie._api.do(
+                    "GET",
+                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/query-result",
+                    headers=self.headers,
+                )["statement_response"]
+                state = resp["status"]["state"]
+                if state == "SUCCEEDED":
+                    result = _parse_query_result(resp, self.truncate_results)
+                    return GenieResponse(
+                        conversation_id, result, query_str, description
+                    )
+                elif state in ["RUNNING", "PENDING"]:
+                    logging.debug("Waiting for query result...")
+                    time.sleep(self.poll_interval)
+                else:
+                    return GenieResponse(
+                        conversation_id,
+                        f"No query result: {state}",
+                        query_str,
+                        description,
+                    )
+            return GenieResponse(
+                conversation_id,
+                f"Genie query for result timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
+                query_str,
+                description,
+            )
+
+        @mlflow.trace()
+        def poll_result():
+            iteration_count = 0
+            while iteration_count < MAX_ITERATIONS:
+                iteration_count += 1
+                resp = self.genie._api.do(
+                    "GET",
+                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
+                    headers=self.headers,
+                )
+                if resp["status"] == "COMPLETED":
+                    # Check if attachments key exists in response
+                    attachments = resp.get("attachments", [])
+                    if not attachments:
+                        # Handle case where response has no attachments
+                        return GenieResponse(
+                            conversation_id,
+                            result=f"Genie query completed but no attachments found. Response: {resp}",
+                        )
+
+                    attachment = next((r for r in attachments if "query" in r), None)
+                    if attachment:
+                        query_obj = attachment["query"]
+                        description = query_obj.get("description", "")
+                        query_str = query_obj.get("query", "")
+                        attachment_id = attachment["attachment_id"]
+                        return poll_query_results(attachment_id, query_str, description)
+                    text_content = next((r for r in attachments if "text" in r), None)
+                    if text_content:
+                        return GenieResponse(
+                            conversation_id, result=text_content["text"]["content"]
+                        )
+                    return GenieResponse(
+                        conversation_id,
+                        result="Genie query completed but no text content found in attachments.",
+                    )
+                elif resp["status"] in {"CANCELLED", "QUERY_RESULT_EXPIRED"}:
+                    return GenieResponse(
+                        conversation_id, result=f"Genie query {resp['status'].lower()}."
+                    )
+                elif resp["status"] == "FAILED":
+                    return GenieResponse(
+                        conversation_id,
+                        result=f"Genie query failed with error: {resp.get('error', 'Unknown error')}",
+                    )
+                # includes EXECUTING_QUERY, Genie can retry after this status
+                else:
+                    logging.debug(f"Waiting...: {resp['status']}")
+                    time.sleep(self.poll_interval)
+            return GenieResponse(
+                conversation_id,
+                f"Genie query timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
+            )
+
+        return poll_result()
+
+    @mlflow.trace()
+    def ask_question(self, question: str, conversation_id: str | None = None):
+        logger.debug(
+            f"ask_question called with question: {question}, conversation_id: {conversation_id}"
+        )
+        if conversation_id:
+            resp = self.create_message(conversation_id, question)
+        else:
+            resp = self.start_conversation(question)
+        logger.debug(f"ask_question response: {resp}")
+        return self.poll_for_result(resp["conversation_id"], resp["message_id"])
 
 
 def create_genie_tool(
     genie_room: GenieRoomModel | dict[str, Any],
     name: Optional[str] = None,
     description: Optional[str] = None,
+    persist_conversation: bool = True,
+    truncate_results: bool = False,
+    poll_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
 ) -> Callable[[str], GenieResponse]:
     """
     Create a tool for interacting with Databricks Genie for natural language queries to databases.
@@ -24,22 +288,33 @@ def create_genie_tool(
     answering questions about inventory, sales, and other structured retail data.
 
     Args:
-        space_id: Databricks workspace ID where Genie is configured. If None, tries to
-            get it from DATABRICKS_GENIE_SPACE_ID environment variable.
+        genie_room: GenieRoomModel or dict containing Genie configuration
+        name: Optional custom name for the tool. If None, uses default "genie_tool"
+        description: Optional custom description for the tool. If None, uses default description
 
     Returns:
-        A callable tool function that processes natural language queries through Genie
+        A LangGraph tool that processes natural language queries through Genie
     """
 
     if isinstance(genie_room, dict):
         genie_room = GenieRoomModel(**genie_room)
 
-    space_id: str = genie_room.space_id or os.environ.get("DATABRICKS_GENIE_SPACE_ID")
-
-    genie: Genie = Genie(
-        space_id=space_id,
-        client=genie_room.workspace_client,
+    space_id: AnyVariable = genie_room.space_id or os.environ.get(
+        "DATABRICKS_GENIE_SPACE_ID"
     )
+    if isinstance(space_id, dict):
+        space_id = CompositeVariableModel(**space_id)
+    space_id = value_of(space_id)
+
+    # genie: Genie = Genie(
+    #     space_id=space_id,
+    #     client=genie_room.workspace_client,
+    #     truncate_results=truncate_results,
+    #     polling_interval=poll_interval,
+    # )
 
     default_description: str = dedent("""
         This tool lets you have a conversation and chat with tabular data about <topic>. You should ask
@@ -49,29 +324,66 @@ def create_genie_tool(
         Prefer to call this tool multiple times rather than asking a complex question.
     """)
 
-    if description is None:
-        description = default_description
+    tool_description: str = (
+        description if description is not None else default_description
+    )
+    tool_name: str = name if name is not None else "genie_tool"
 
-    doc_signature: str = dedent("""
-        Args:
-            question (str): The question to ask to ask Genie
+    function_docs = """
 
-        Returns:
-            response (GenieResponse): An object containing the Genie response
-        """)
+    Args:
+        question (str): The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question.
 
-    doc: str = description + "\n" + doc_signature
+    Returns:
+        GenieResponse: A response object containing the conversation ID and result from Genie."""
+    tool_description = tool_description + function_docs
 
-    async def genie_tool(question: str) -> GenieResponse:
-        # Use sync API for now since Genie doesn't support async yet
-        # Can be easily updated to await when Genie gets async support
-        response: GenieResponse = genie.ask_question(question)
-        return response
+    @tool(
+        name_or_callable=tool_name,
+        description=tool_description,
+    )
+    def genie_tool(
+        question: Annotated[str, "The question to ask Genie about your data"],
+        state: Annotated[dict, InjectedState],
+        tool_call_id: Annotated[str, InjectedToolCallId],
+    ) -> Command:
+        """Process a natural language question through Databricks Genie."""
+        genie: Genie = Genie(
+            space_id=space_id,
+            client=genie_room.workspace_client,
+            truncate_results=truncate_results,
+            polling_interval=poll_interval,
+        )
 
-    name: str = name if name else genie_tool.__name__
+        # Get existing conversation mapping and retrieve conversation ID for this space
+        conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
+        existing_conversation_id: str | None = conversation_ids.get(space_id)
+        logger.debug(
+            f"Existing conversation ID for space {space_id}: {existing_conversation_id}"
+        )
 
-    structured_tool: StructuredTool = StructuredTool.from_function(
-        coroutine=genie_tool, name=name, description=doc, parse_docstring=False
-    )
+        response: GenieResponse = genie.ask_question(
+            question, conversation_id=existing_conversation_id
+        )
+
+        current_conversation_id: str = response.conversation_id
+        logger.debug(
+            f"Current conversation ID for space {space_id}: {current_conversation_id}"
+        )
+
+        # Update the conversation mapping with the new conversation ID for this space
+        update: dict[str, Any] = {
+            "messages": [ToolMessage(response.to_json(), tool_call_id=tool_call_id)],
+        }
+
+        if persist_conversation:
+            updated_conversation_ids: dict[str, str] = conversation_ids.copy()
+            updated_conversation_ids[space_id] = current_conversation_id
+            update["genie_conversation_ids"] = updated_conversation_ids
+
+        logger.debug(f"State update: {update}")
+
+        return Command(update=update)
 
-    return structured_tool
+    return genie_tool
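The rewritten tool replaces the `StructuredTool` coroutine with a LangGraph `@tool` that returns a `Command`, so conversation continuity lives in graph state (`genie_conversation_ids`) rather than in the tool closure. A hedged wiring sketch; the space ID is a placeholder:

    # Hypothetical wiring; the space ID below is a placeholder.
    from dao_ai.tools.genie import create_genie_tool

    sales_genie = create_genie_tool(
        genie_room={"space_id": "<your-genie-space-id>"},
        name="sales_genie",
        persist_conversation=True,  # reuse one Genie conversation per space per thread
    )
    # Bind sales_genie to a LangGraph agent: follow-up questions on the same
    # thread pass the stored conversation_id back to Genie, and the final
    # response surfaces the mapping via custom_outputs["genie_conversation_ids"].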
dao_ai-0.0.21.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.0.20
+Version: 0.0.21
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai
dao_ai-0.0.21.dist-info/RECORD CHANGED
@@ -3,14 +3,14 @@ dao_ai/agent_as_code.py,sha256=kPSeDz2-1jRaed1TMs4LA3VECoyqe9_Ed2beRLB9gXQ,472
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
 dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
 dao_ai/cli.py,sha256=Aez2TQW3Q8Ho1IaIkRggt0NevDxAAVPjXkePC5GPJF0,20429
-dao_ai/config.py,sha256=ZO5ei45gnhqg1BtD0R9aekJz4ClmiTw2GHhOk4Idil4,51958
+dao_ai/config.py,sha256=GeaM00wNlYecwe3HhqeG88Hprt0SvGg4HtC7g_m-v98,52386
 dao_ai/graph.py,sha256=gmD9mxODfXuvn9xWeBfewm1FiuVAWMLEdnZz7DNmSH0,7859
 dao_ai/guardrails.py,sha256=4TKArDONRy8RwHzOT1plZ1rhy3x9GF_aeGpPCRl6wYA,4016
 dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
-dao_ai/models.py,sha256=Xb23U-lhDG8KyNRIijcJ4InluadlaGNy4rrYx7Cjgfg,26939
+dao_ai/models.py,sha256=8r8GIG3EGxtVyWsRNI56lVaBjiNrPkzh4HdwMZRq8iw,31689
 dao_ai/nodes.py,sha256=SSuFNTXOdFaKg_aX-yUkQO7fM9wvNGu14lPXKDapU1U,8461
 dao_ai/prompts.py,sha256=vpmIbWs_szXUgNNDs5Gh2LcxKZti5pHDKSfoClUcgX0,1289
-dao_ai/state.py,sha256=GwbMbd1TWZx1T5iQrEOX6_rpxOitlmyeJ8dMr2o_pag,1031
+dao_ai/state.py,sha256=_lF9krAYYjvFDMUwZzVKOn0ZnXKcOrbjWKdre0C5B54,1137
 dao_ai/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/utils.py,sha256=dkZTXNN6q0xwkrvSWdNq8937W2xGuLCRWRb6hRQM6kA,4217
 dao_ai/vector_search.py,sha256=jlaFS_iizJ55wblgzZmswMM3UOL-qOp2BGJc0JqXYSg,2839
@@ -19,22 +19,22 @@ dao_ai/hooks/core.py,sha256=ZShHctUSoauhBgdf1cecy9-D7J6-sGn-pKjuRMumW5U,6663
 dao_ai/memory/__init__.py,sha256=1kHx_p9abKYFQ6EYD05nuc1GS5HXVEpufmjBGw_7Uho,260
 dao_ai/memory/base.py,sha256=99nfr2UZJ4jmfTL_KrqUlRSCoRxzkZyWyx5WqeUoMdQ,338
 dao_ai/memory/core.py,sha256=g7chjBgVgx3iKjR2hghl0QL1j3802uIM_e7mgszur9M,4151
-dao_ai/memory/postgres.py,sha256=pxxMjGotgqjrKhx0lVR3EAjSZTQgBpiPZOB0-cyjprc,12505
+dao_ai/memory/postgres.py,sha256=aWHRLhPm-9ywjlQe2B4XSdLbeaiuVV88p4PiQJFNEWo,13924
 dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
-dao_ai/providers/databricks.py,sha256=fZ8mGotfA3W3t5yUej2xGmGHSybjBFYr895mOctT418,28203
+dao_ai/providers/databricks.py,sha256=PX5mBvZaIxSJIAHWVnPXsho1XvxcoR3Qs3I9UavFRsY,28306
 dao_ai/tools/__init__.py,sha256=ye6MHaJY7tUnJ8336YJiLxuZr55zDPNdOw6gm7j5jlc,1103
 dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
 dao_ai/tools/core.py,sha256=Kei33S8vrmvPOAyrFNekaWmV2jqZ-IPS1QDSvU7RZF0,1984
-dao_ai/tools/genie.py,sha256=GzV5lfDYKmzW_lSLxAsPaTwnzX6GxQOB1UcLaTDqpfY,2787
+dao_ai/tools/genie.py,sha256=1CbLViNQ3KnmDtHXuwqCPug7rEhCGvuHP1NgsY-AJZ0,15050
 dao_ai/tools/human_in_the_loop.py,sha256=yk35MO9eNETnYFH-sqlgR-G24TrEgXpJlnZUustsLkI,3681
 dao_ai/tools/mcp.py,sha256=auEt_dwv4J26fr5AgLmwmnAsI894-cyuvkvjItzAUxs,4419
 dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
 dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
 dao_ai/tools/unity_catalog.py,sha256=uX_h52BuBAr4c9UeqSMI7DNz3BPRLeai5tBVW4sJqRI,13113
 dao_ai/tools/vector_search.py,sha256=EDYQs51zIPaAP0ma1D81wJT77GQ-v-cjb2XrFVWfWdg,2621
-dao_ai-0.0.20.dist-info/METADATA,sha256=gWNRLhswz5sCe1vxbBQ6dGlgiObI9nI829Q5DQRqRRY,41380
-dao_ai-0.0.20.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dao_ai-0.0.20.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
-dao_ai-0.0.20.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
-dao_ai-0.0.20.dist-info/RECORD,,
+dao_ai-0.0.21.dist-info/METADATA,sha256=PG-eOltuUpaJf4lYEw-DoVy5BFT9LbMCfe8GanIV7zQ,41380
+dao_ai-0.0.21.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dao_ai-0.0.21.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.0.21.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.0.21.dist-info/RECORD,,