dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +2 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1491 -370
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -253
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +245 -159
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +573 -601
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -294
  44. dao_ai/tools/mcp.py +223 -155
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +331 -221
  53. dao_ai/utils.py +166 -20
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. dao_ai/chat_models.py +0 -204
  57. dao_ai/guardrails.py +0 -112
  58. dao_ai/tools/human_in_the_loop.py +0 -100
  59. dao_ai-0.0.28.dist-info/METADATA +0 -1168
  60. dao_ai-0.0.28.dist-info/RECORD +0 -41
  61. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
  62. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/genie.py CHANGED
@@ -1,284 +1,73 @@
- import bisect
+ """
+ Genie tool for natural language queries to databases.
+
+ This module provides the tool factory for creating LangGraph tools that
+ interact with Databricks Genie.
+
+ For the core Genie service and cache implementations, see:
+ - dao_ai.genie: GenieService, GenieServiceBase
+ - dao_ai.genie.cache: LRUCacheService, SemanticCacheService
+ """
+
  import json
  import os
- import time
- from dataclasses import asdict, dataclass
- from datetime import datetime
  from textwrap import dedent
- from typing import Annotated, Any, Callable, Optional, Union
+ from typing import Annotated, Any, Callable

- import mlflow
  import pandas as pd
- from databricks.sdk import WorkspaceClient
+ from databricks_ai_bridge.genie import Genie, GenieResponse
+ from langchain.tools import ToolRuntime, tool
  from langchain_core.messages import ToolMessage
- from langchain_core.tools import InjectedToolCallId, tool
- from langgraph.prebuilt import InjectedState
  from langgraph.types import Command
  from loguru import logger
- from pydantic import BaseModel, Field
-
- from dao_ai.config import AnyVariable, CompositeVariableModel, GenieRoomModel, value_of
-
- MAX_TOKENS_OF_DATA: int = 20000
- MAX_ITERATIONS: int = 50
- DEFAULT_POLLING_INTERVAL_SECS: int = 2
-
+ from pydantic import BaseModel

- def _count_tokens(text):
-     import tiktoken
-
-     encoding = tiktoken.encoding_for_model("gpt-4o")
-     return len(encoding.encode(text))
-
-
- @dataclass
- class GenieResponse:
-     conversation_id: str
-     result: Union[str, pd.DataFrame]
-     query: Optional[str] = ""
-     description: Optional[str] = ""
-
-     def to_json(self):
-         return json.dumps(asdict(self))
+ from dao_ai.config import (
+     AnyVariable,
+     CompositeVariableModel,
+     GenieLRUCacheParametersModel,
+     GenieRoomModel,
+     GenieSemanticCacheParametersModel,
+     value_of,
+ )
+ from dao_ai.genie import GenieService, GenieServiceBase
+ from dao_ai.genie.cache import CacheResult, LRUCacheService, SemanticCacheService
+ from dao_ai.state import AgentState, Context, SessionState


  class GenieToolInput(BaseModel):
-     """Input schema for the Genie tool."""
-
-     question: str = Field(
-         description="The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question."
-     )
+     """Input schema for Genie tool - only includes user-facing parameters."""

+     question: str

- def _truncate_result(dataframe: pd.DataFrame) -> str:
-     query_result = dataframe.to_markdown()
-     tokens_used = _count_tokens(query_result)
-
-     # If the full result fits, return it
-     if tokens_used <= MAX_TOKENS_OF_DATA:
-         return query_result.strip()
-
-     def is_too_big(n):
-         return _count_tokens(dataframe.iloc[:n].to_markdown()) > MAX_TOKENS_OF_DATA
-
-     # Use bisect_left to find the cutoff point of rows within the max token data limit in a O(log n) complexity
-     # Passing True, as this is the target value we are looking for when _is_too_big returns
-     cutoff = bisect.bisect_left(range(len(dataframe) + 1), True, key=is_too_big)
-
-     # Slice to the found limit
-     truncated_df = dataframe.iloc[:cutoff]
-
-     # Edge case: Cannot return any rows because of tokens so return an empty string
-     if len(truncated_df) == 0:
-         return ""
-
-     truncated_result = truncated_df.to_markdown()
-
-     # Double-check edge case if we overshot by one
-     if _count_tokens(truncated_result) > MAX_TOKENS_OF_DATA:
-         truncated_result = truncated_df.iloc[:-1].to_markdown()
-     return truncated_result
-
-
- @mlflow.trace(span_type="PARSER")
- def _parse_query_result(resp, truncate_results) -> Union[str, pd.DataFrame]:
-     output = resp["result"]
-     if not output:
-         return "EMPTY"
-
-     columns = resp["manifest"]["schema"]["columns"]
-     header = [str(col["name"]) for col in columns]
-     rows = []
-
-     for item in output["data_array"]:
-         row = []
-         for column, value in zip(columns, item):
-             type_name = column["type_name"]
-             if value is None:
-                 row.append(None)
-                 continue
-
-             if type_name in ["INT", "LONG", "SHORT", "BYTE"]:
-                 row.append(int(value))
-             elif type_name in ["FLOAT", "DOUBLE", "DECIMAL"]:
-                 row.append(float(value))
-             elif type_name == "BOOLEAN":
-                 row.append(value.lower() == "true")
-             elif type_name == "DATE" or type_name == "TIMESTAMP":
-                 row.append(datetime.strptime(value[:10], "%Y-%m-%d").date())
-             elif type_name == "BINARY":
-                 row.append(bytes(value, "utf-8"))
-             else:
-                 row.append(value)
-
-         rows.append(row)
-
-     dataframe = pd.DataFrame(rows, columns=header)
-
-     if truncate_results:
-         query_result = _truncate_result(dataframe)
-     else:
-         query_result = dataframe.to_markdown()
-
-     return query_result.strip()
-
-
- class Genie:
-     def __init__(
-         self,
-         space_id,
-         client: WorkspaceClient | None = None,
-         truncate_results: bool = False,
-         polling_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
-     ):
-         self.space_id = space_id
-         workspace_client = client or WorkspaceClient()
-         self.genie = workspace_client.genie
-         self.description = self.genie.get_space(space_id).description
-         self.headers = {
-             "Accept": "application/json",
-             "Content-Type": "application/json",
-         }
-         self.truncate_results = truncate_results
-         if polling_interval < 1 or polling_interval > 30:
-             raise ValueError("poll_interval must be between 1 and 30 seconds")
-         self.poll_interval = polling_interval
-
-     @mlflow.trace()
-     def start_conversation(self, content):
-         resp = self.genie._api.do(
-             "POST",
-             f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
-             body={"content": content},
-             headers=self.headers,
-         )
-         return resp
-
-     @mlflow.trace()
-     def create_message(self, conversation_id, content):
-         resp = self.genie._api.do(
-             "POST",
-             f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
-             body={"content": content},
-             headers=self.headers,
-         )
-         return resp
-
-     @mlflow.trace()
-     def poll_for_result(self, conversation_id, message_id):
-         @mlflow.trace()
-         def poll_query_results(attachment_id, query_str, description):
-             iteration_count = 0
-             while iteration_count < MAX_ITERATIONS:
-                 iteration_count += 1
-                 resp = self.genie._api.do(
-                     "GET",
-                     f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/query-result",
-                     headers=self.headers,
-                 )["statement_response"]
-                 state = resp["status"]["state"]
-                 if state == "SUCCEEDED":
-                     result = _parse_query_result(resp, self.truncate_results)
-                     return GenieResponse(
-                         conversation_id, result, query_str, description
-                     )
-                 elif state in ["RUNNING", "PENDING"]:
-                     logger.debug("Waiting for query result...")
-                     time.sleep(self.poll_interval)
-                 else:
-                     return GenieResponse(
-                         conversation_id,
-                         f"No query result: {resp['state']}",
-                         query_str,
-                         description,
-                     )
-             return GenieResponse(
-                 conversation_id,
-                 f"Genie query for result timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
-                 query_str,
-                 description,
-             )

-         @mlflow.trace()
-         def poll_result():
-             iteration_count = 0
-             while iteration_count < MAX_ITERATIONS:
-                 iteration_count += 1
-                 resp = self.genie._api.do(
-                     "GET",
-                     f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
-                     headers=self.headers,
-                 )
-                 if resp["status"] == "COMPLETED":
-                     # Check if attachments key exists in response
-                     attachments = resp.get("attachments", [])
-                     if not attachments:
-                         # Handle case where response has no attachments
-                         return GenieResponse(
-                             conversation_id,
-                             result=f"Genie query completed but no attachments found. Response: {resp}",
-                         )
-
-                     attachment = next((r for r in attachments if "query" in r), None)
-                     if attachment:
-                         query_obj = attachment["query"]
-                         description = query_obj.get("description", "")
-                         query_str = query_obj.get("query", "")
-                         attachment_id = attachment["attachment_id"]
-                         return poll_query_results(attachment_id, query_str, description)
-                     if resp["status"] == "COMPLETED":
-                         text_content = next(
-                             (r for r in attachments if "text" in r), None
-                         )
-                         if text_content:
-                             return GenieResponse(
-                                 conversation_id, result=text_content["text"]["content"]
-                             )
-                         return GenieResponse(
-                             conversation_id,
-                             result="Genie query completed but no text content found in attachments.",
-                         )
-                 elif resp["status"] in {"CANCELLED", "QUERY_RESULT_EXPIRED"}:
-                     return GenieResponse(
-                         conversation_id, result=f"Genie query {resp['status'].lower()}."
-                     )
-                 elif resp["status"] == "FAILED":
-                     return GenieResponse(
-                         conversation_id,
-                         result=f"Genie query failed with error: {resp.get('error', 'Unknown error')}",
-                     )
-                 # includes EXECUTING_QUERY, Genie can retry after this status
-                 else:
-                     logger.debug(f"Waiting...: {resp['status']}")
-                     time.sleep(self.poll_interval)
-             return GenieResponse(
-                 conversation_id,
-                 f"Genie query timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
-             )
+ def _response_to_json(response: GenieResponse) -> str:
+     """Convert GenieResponse to JSON string, handling DataFrame results."""
+     # Convert result to string if it's a DataFrame
+     result: str | pd.DataFrame = response.result
+     if isinstance(result, pd.DataFrame):
+         result = result.to_markdown()

-         return poll_result()
-
-     @mlflow.trace()
-     def ask_question(self, question: str, conversation_id: str | None = None):
-         logger.debug(
-             f"ask_question called with question: {question}, conversation_id: {conversation_id}"
-         )
-         if conversation_id:
-             resp = self.create_message(conversation_id, question)
-         else:
-             resp = self.start_conversation(question)
-         logger.debug(f"ask_question response: {resp}")
-         return self.poll_for_result(resp["conversation_id"], resp["message_id"])
+     data: dict[str, Any] = {
+         "result": result,
+         "query": response.query,
+         "description": response.description,
+         "conversation_id": response.conversation_id,
+     }
+     return json.dumps(data)


  def create_genie_tool(
      genie_room: GenieRoomModel | dict[str, Any],
-     name: Optional[str] = None,
-     description: Optional[str] = None,
-     persist_conversation: bool = False,
+     name: str | None = None,
+     description: str | None = None,
+     persist_conversation: bool = True,
      truncate_results: bool = False,
-     poll_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
- ) -> Callable[[str], GenieResponse]:
+     lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
+     semantic_cache_parameters: GenieSemanticCacheParametersModel
+     | dict[str, Any]
+     | None = None,
+ ) -> Callable[..., Command]:
      """
      Create a tool for interacting with Databricks Genie for natural language queries to databases.

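The `_response_to_json` helper above replaces the removed `GenieResponse.to_json()` method now that `GenieResponse` comes from `databricks_ai_bridge` rather than a local dataclass. A minimal sketch of its behavior, assuming the four fields this diff references (`result`, `query`, `description`, `conversation_id`) and keyword construction; the private-helper import is for illustration only:

    import pandas as pd
    from databricks_ai_bridge.genie import GenieResponse

    from dao_ai.tools.genie import _response_to_json  # private helper, imported only to illustrate

    response = GenieResponse(
        result=pd.DataFrame({"region": ["EMEA"], "revenue": [1200.5]}),
        query="SELECT region, SUM(revenue) FROM sales GROUP BY region",
        description="Revenue by region",
        conversation_id="conv-123",
    )

    # DataFrame results are rendered to markdown before JSON encoding, so the
    # ToolMessage payload is always a plain string. (to_markdown() requires the
    # optional tabulate dependency, as usual for pandas.)
    payload: str = _response_to_json(response)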
@@ -290,17 +79,37 @@ def create_genie_tool(
          genie_room: GenieRoomModel or dict containing Genie configuration
          name: Optional custom name for the tool. If None, uses default "genie_tool"
          description: Optional custom description for the tool. If None, uses default description
+         persist_conversation: Whether to persist conversation IDs across tool calls for
+             multi-turn conversations within the same Genie space
+         truncate_results: Whether to truncate large query results to fit token limits
+         lru_cache_parameters: Optional LRU cache configuration for SQL query caching
+         semantic_cache_parameters: Optional semantic cache configuration using pg_vector
+             for similarity-based query matching

      Returns:
          A LangGraph tool that processes natural language queries through Genie
      """
+     logger.debug(
+         "Creating Genie tool",
+         genie_room_type=type(genie_room).__name__,
+         persist_conversation=persist_conversation,
+         truncate_results=truncate_results,
+         name=name,
+         has_lru_cache=lru_cache_parameters is not None,
+         has_semantic_cache=semantic_cache_parameters is not None,
+     )

      if isinstance(genie_room, dict):
          genie_room = GenieRoomModel(**genie_room)

-     space_id: AnyVariable = genie_room.space_id or os.environ.get(
-         "DATABRICKS_GENIE_SPACE_ID"
-     )
+     if isinstance(lru_cache_parameters, dict):
+         lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)
+
+     if isinstance(semantic_cache_parameters, dict):
+         semantic_cache_parameters = GenieSemanticCacheParametersModel(
+             **semantic_cache_parameters
+         )
+
      space_id: AnyVariable = genie_room.space_id or os.environ.get(
          "DATABRICKS_GENIE_SPACE_ID"
      )
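Both cache-parameter arguments accept either a typed model or a plain dict, mirroring the existing `genie_room` handling, so YAML- or dict-driven configs can pass raw mappings. The coercion idiom in isolation, sketched with a hypothetical stand-in model rather than the real parameter models in `dao_ai.config`:

    from typing import Any

    from pydantic import BaseModel

    class CacheParams(BaseModel):  # hypothetical stand-in, not the library's model
        capacity: int = 128

    def normalize(params: CacheParams | dict[str, Any] | None) -> CacheParams | None:
        # Coerce dicts into the typed model; pass models (or None) through unchanged.
        if isinstance(params, dict):
            params = CacheParams(**params)
        return params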
@@ -308,13 +117,6 @@ def create_genie_tool(
          space_id = CompositeVariableModel(**space_id)
      space_id = value_of(space_id)

-     # genie: Genie = Genie(
-     #     space_id=space_id,
-     #     client=genie_room.workspace_client,
-     #     truncate_results=truncate_results,
-     #     polling_interval=poll_interval,
-     # )
-
      default_description: str = dedent("""
          This tool lets you have a conversation and chat with tabular data about <topic>. You should ask
          questions about the data and the tool will try to answer them.
@@ -337,51 +139,99 @@ Returns:
          GenieResponse: A response object containing the conversation ID and result from Genie."""
      tool_description = tool_description + function_docs

+     genie: Genie = Genie(
+         space_id=space_id,
+         client=genie_room.workspace_client,
+         truncate_results=truncate_results,
+     )
+
+     genie_service: GenieServiceBase = GenieService(genie)
+
+     # Wrap with semantic cache first (checked second due to decorator pattern)
+     if semantic_cache_parameters is not None:
+         genie_service = SemanticCacheService(
+             impl=genie_service,
+             parameters=semantic_cache_parameters,
+             workspace_client=genie_room.workspace_client,  # Pass workspace client for conversation history
+         ).initialize()  # Eagerly initialize to fail fast and create table
+
+     # Wrap with LRU cache last (checked first - fast O(1) exact match)
+     if lru_cache_parameters is not None:
+         genie_service = LRUCacheService(
+             impl=genie_service,
+             parameters=lru_cache_parameters,
+         )
+
      @tool(
          name_or_callable=tool_name,
          description=tool_description,
      )
      def genie_tool(
          question: Annotated[str, "The question to ask Genie about your data"],
-         state: Annotated[dict, InjectedState],
-         tool_call_id: Annotated[str, InjectedToolCallId],
+         runtime: ToolRuntime[Context, AgentState],
      ) -> Command:
-         genie: Genie = Genie(
-             space_id=space_id,
-             client=genie_room.workspace_client,
-             truncate_results=truncate_results,
-             polling_interval=poll_interval,
-         )
+         """Process a natural language question through Databricks Genie.

-         """Process a natural language question through Databricks Genie."""
-         # Get existing conversation mapping and retrieve conversation ID for this space
-         conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
-         existing_conversation_id: str | None = conversation_ids.get(space_id)
-         logger.debug(
-             f"Existing conversation ID for space {space_id}: {existing_conversation_id}"
+         Uses ToolRuntime to access state and context in a type-safe way.
+         """
+         # Access state through runtime
+         state: AgentState = runtime.state
+         tool_call_id: str = runtime.tool_call_id
+
+         # Ensure space_id is a string for state keys
+         space_id_str: str = str(space_id)
+
+         # Get session state (or create new one)
+         session: SessionState = state.get("session", SessionState())
+
+         # Get existing conversation ID from session
+         existing_conversation_id: str | None = session.genie.get_conversation_id(
+             space_id_str
+         )
+         logger.trace(
+             "Using existing conversation ID",
+             space_id=space_id_str,
+             conversation_id=existing_conversation_id,
          )

-         response: GenieResponse = genie.ask_question(
+         # Call ask_question which always returns CacheResult with cache metadata
+         cache_result: CacheResult = genie_service.ask_question(
              question, conversation_id=existing_conversation_id
          )
+         genie_response: GenieResponse = cache_result.response
+         cache_hit: bool = cache_result.cache_hit
+         cache_key: str | None = cache_result.served_by

-         current_conversation_id: str = response.conversation_id
+         current_conversation_id: str = genie_response.conversation_id
          logger.debug(
-             f"Current conversation ID for space {space_id}: {current_conversation_id}"
+             "Genie question answered",
+             space_id=space_id_str,
+             conversation_id=current_conversation_id,
+             cache_hit=cache_hit,
+             cache_key=cache_key,
          )

-         # Update the conversation mapping with the new conversation ID for this space
+         # Update session state with cache information
+         if persist_conversation:
+             session.genie.update_space(
+                 space_id=space_id_str,
+                 conversation_id=current_conversation_id,
+                 cache_hit=cache_hit,
+                 cache_key=cache_key,
+                 last_query=question,
+             )

+         # Build update dict with response and session
          update: dict[str, Any] = {
-             "messages": [ToolMessage(response.to_json(), tool_call_id=tool_call_id)],
+             "messages": [
+                 ToolMessage(
+                     _response_to_json(genie_response), tool_call_id=tool_call_id
+                 )
+             ],
          }

          if persist_conversation:
-             updated_conversation_ids: dict[str, str] = conversation_ids.copy()
-             updated_conversation_ids[space_id] = current_conversation_id
-             update["genie_conversation_ids"] = updated_conversation_ids
-
-         logger.debug(f"State update: {update}")
+             update["session"] = session

          return Command(update=update)
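Taken together, the new factory wires Genie behind a layered service: `LRUCacheService` is applied outermost, so an exact-match question is answered in O(1) without touching the semantic cache or Genie, a miss delegates to `SemanticCacheService`, and only then does the real `GenieService` run. A hedged usage sketch of the 0.1.2 signature, using only arguments visible in this diff (the cache-parameter field names live in `dao_ai.config` and are not shown here, so both caches are left disabled):

    from dao_ai.tools.genie import create_genie_tool

    genie_tool = create_genie_tool(
        genie_room={"space_id": "your-genie-space-id"},  # dict is coerced to GenieRoomModel
        persist_conversation=True,  # note: the default changed from False to True
        truncate_results=False,
        lru_cache_parameters=None,       # optional exact-match SQL query cache
        semantic_cache_parameters=None,  # optional pg_vector similarity cache
    )

The returned tool reads state through LangChain's `ToolRuntime` instead of the removed `InjectedState`/`InjectedToolCallId` annotations, and conversation IDs plus cache metadata now travel in `SessionState` under the `session` key rather than in a bare `genie_conversation_ids` dict.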