dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +5 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1863 -338
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -228
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +261 -166
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +645 -172
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -295
  44. dao_ai/tools/mcp.py +220 -133
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +360 -40
  53. dao_ai/utils.py +218 -16
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
  57. dao_ai/chat_models.py +0 -204
  58. dao_ai/guardrails.py +0 -112
  59. dao_ai/tools/human_in_the_loop.py +0 -100
  60. dao_ai-0.0.25.dist-info/METADATA +0 -1165
  61. dao_ai-0.0.25.dist-info/RECORD +0 -41
  62. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/tools/genie.py CHANGED
@@ -1,285 +1,73 @@
-import bisect
+"""
+Genie tool for natural language queries to databases.
+
+This module provides the tool factory for creating LangGraph tools that
+interact with Databricks Genie.
+
+For the core Genie service and cache implementations, see:
+- dao_ai.genie: GenieService, GenieServiceBase
+- dao_ai.genie.cache: LRUCacheService, SemanticCacheService
+"""
+
 import json
-import logging
 import os
-import time
-from dataclasses import asdict, dataclass
-from datetime import datetime
 from textwrap import dedent
-from typing import Annotated, Any, Callable, Optional, Union
+from typing import Annotated, Any, Callable

-import mlflow
 import pandas as pd
-from databricks.sdk import WorkspaceClient
+from databricks_ai_bridge.genie import Genie, GenieResponse
+from langchain.tools import ToolRuntime, tool
 from langchain_core.messages import ToolMessage
-from langchain_core.tools import InjectedToolCallId, tool
-from langgraph.prebuilt import InjectedState
 from langgraph.types import Command
 from loguru import logger
-from pydantic import BaseModel, Field
-
-from dao_ai.config import AnyVariable, CompositeVariableModel, GenieRoomModel, value_of
-
-MAX_TOKENS_OF_DATA: int = 20000
-MAX_ITERATIONS: int = 50
-DEFAULT_POLLING_INTERVAL_SECS: int = 2
-
+from pydantic import BaseModel

-def _count_tokens(text):
-    import tiktoken
-
-    encoding = tiktoken.encoding_for_model("gpt-4o")
-    return len(encoding.encode(text))
-
-
-@dataclass
-class GenieResponse:
-    conversation_id: str
-    result: Union[str, pd.DataFrame]
-    query: Optional[str] = ""
-    description: Optional[str] = ""
-
-    def to_json(self):
-        return json.dumps(asdict(self))
+from dao_ai.config import (
+    AnyVariable,
+    CompositeVariableModel,
+    GenieLRUCacheParametersModel,
+    GenieRoomModel,
+    GenieSemanticCacheParametersModel,
+    value_of,
+)
+from dao_ai.genie import GenieService, GenieServiceBase
+from dao_ai.genie.cache import CacheResult, LRUCacheService, SemanticCacheService
+from dao_ai.state import AgentState, Context, SessionState


 class GenieToolInput(BaseModel):
-    """Input schema for the Genie tool."""
-
-    question: str = Field(
-        description="The question to ask Genie about your data. Ask simple, clear questions about your tabular data. For complex analysis, ask multiple simple questions rather than one complex question."
-    )
+    """Input schema for Genie tool - only includes user-facing parameters."""

+    question: str

-def _truncate_result(dataframe: pd.DataFrame) -> str:
-    query_result = dataframe.to_markdown()
-    tokens_used = _count_tokens(query_result)
-
-    # If the full result fits, return it
-    if tokens_used <= MAX_TOKENS_OF_DATA:
-        return query_result.strip()
-
-    def is_too_big(n):
-        return _count_tokens(dataframe.iloc[:n].to_markdown()) > MAX_TOKENS_OF_DATA
-
-    # Use bisect_left to find the cutoff point of rows within the max token data limit in a O(log n) complexity
-    # Passing True, as this is the target value we are looking for when _is_too_big returns
-    cutoff = bisect.bisect_left(range(len(dataframe) + 1), True, key=is_too_big)
-
-    # Slice to the found limit
-    truncated_df = dataframe.iloc[:cutoff]
-
-    # Edge case: Cannot return any rows because of tokens so return an empty string
-    if len(truncated_df) == 0:
-        return ""
-
-    truncated_result = truncated_df.to_markdown()
-
-    # Double-check edge case if we overshot by one
-    if _count_tokens(truncated_result) > MAX_TOKENS_OF_DATA:
-        truncated_result = truncated_df.iloc[:-1].to_markdown()
-    return truncated_result
-
-
-@mlflow.trace(span_type="PARSER")
-def _parse_query_result(resp, truncate_results) -> Union[str, pd.DataFrame]:
-    output = resp["result"]
-    if not output:
-        return "EMPTY"
-
-    columns = resp["manifest"]["schema"]["columns"]
-    header = [str(col["name"]) for col in columns]
-    rows = []
-
-    for item in output["data_array"]:
-        row = []
-        for column, value in zip(columns, item):
-            type_name = column["type_name"]
-            if value is None:
-                row.append(None)
-                continue
-
-            if type_name in ["INT", "LONG", "SHORT", "BYTE"]:
-                row.append(int(value))
-            elif type_name in ["FLOAT", "DOUBLE", "DECIMAL"]:
-                row.append(float(value))
-            elif type_name == "BOOLEAN":
-                row.append(value.lower() == "true")
-            elif type_name == "DATE" or type_name == "TIMESTAMP":
-                row.append(datetime.strptime(value[:10], "%Y-%m-%d").date())
-            elif type_name == "BINARY":
-                row.append(bytes(value, "utf-8"))
-            else:
-                row.append(value)
-
-        rows.append(row)
-
-    dataframe = pd.DataFrame(rows, columns=header)
-
-    if truncate_results:
-        query_result = _truncate_result(dataframe)
-    else:
-        query_result = dataframe.to_markdown()
-
-    return query_result.strip()
-
-
-class Genie:
-    def __init__(
-        self,
-        space_id,
-        client: WorkspaceClient | None = None,
-        truncate_results: bool = False,
-        polling_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
-    ):
-        self.space_id = space_id
-        workspace_client = client or WorkspaceClient()
-        self.genie = workspace_client.genie
-        self.description = self.genie.get_space(space_id).description
-        self.headers = {
-            "Accept": "application/json",
-            "Content-Type": "application/json",
-        }
-        self.truncate_results = truncate_results
-        if polling_interval < 1 or polling_interval > 30:
-            raise ValueError("poll_interval must be between 1 and 30 seconds")
-        self.poll_interval = polling_interval
-
-    @mlflow.trace()
-    def start_conversation(self, content):
-        resp = self.genie._api.do(
-            "POST",
-            f"/api/2.0/genie/spaces/{self.space_id}/start-conversation",
-            body={"content": content},
-            headers=self.headers,
-        )
-        return resp
-
-    @mlflow.trace()
-    def create_message(self, conversation_id, content):
-        resp = self.genie._api.do(
-            "POST",
-            f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages",
-            body={"content": content},
-            headers=self.headers,
-        )
-        return resp
-
-    @mlflow.trace()
-    def poll_for_result(self, conversation_id, message_id):
-        @mlflow.trace()
-        def poll_query_results(attachment_id, query_str, description):
-            iteration_count = 0
-            while iteration_count < MAX_ITERATIONS:
-                iteration_count += 1
-                resp = self.genie._api.do(
-                    "GET",
-                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}/attachments/{attachment_id}/query-result",
-                    headers=self.headers,
-                )["statement_response"]
-                state = resp["status"]["state"]
-                if state == "SUCCEEDED":
-                    result = _parse_query_result(resp, self.truncate_results)
-                    return GenieResponse(
-                        conversation_id, result, query_str, description
-                    )
-                elif state in ["RUNNING", "PENDING"]:
-                    logging.debug("Waiting for query result...")
-                    time.sleep(self.poll_interval)
-                else:
-                    return GenieResponse(
-                        conversation_id,
-                        f"No query result: {resp['state']}",
-                        query_str,
-                        description,
-                    )
-            return GenieResponse(
-                conversation_id,
-                f"Genie query for result timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
-                query_str,
-                description,
-            )

-        @mlflow.trace()
-        def poll_result():
-            iteration_count = 0
-            while iteration_count < MAX_ITERATIONS:
-                iteration_count += 1
-                resp = self.genie._api.do(
-                    "GET",
-                    f"/api/2.0/genie/spaces/{self.space_id}/conversations/{conversation_id}/messages/{message_id}",
-                    headers=self.headers,
-                )
-                if resp["status"] == "COMPLETED":
-                    # Check if attachments key exists in response
-                    attachments = resp.get("attachments", [])
-                    if not attachments:
-                        # Handle case where response has no attachments
-                        return GenieResponse(
-                            conversation_id,
-                            result=f"Genie query completed but no attachments found. Response: {resp}",
-                        )
-
-                    attachment = next((r for r in attachments if "query" in r), None)
-                    if attachment:
-                        query_obj = attachment["query"]
-                        description = query_obj.get("description", "")
-                        query_str = query_obj.get("query", "")
-                        attachment_id = attachment["attachment_id"]
-                        return poll_query_results(attachment_id, query_str, description)
-                if resp["status"] == "COMPLETED":
-                    text_content = next(
-                        (r for r in attachments if "text" in r), None
-                    )
-                    if text_content:
-                        return GenieResponse(
-                            conversation_id, result=text_content["text"]["content"]
-                        )
-                    return GenieResponse(
-                        conversation_id,
-                        result="Genie query completed but no text content found in attachments.",
-                    )
-                elif resp["status"] in {"CANCELLED", "QUERY_RESULT_EXPIRED"}:
-                    return GenieResponse(
-                        conversation_id, result=f"Genie query {resp['status'].lower()}."
-                    )
-                elif resp["status"] == "FAILED":
-                    return GenieResponse(
-                        conversation_id,
-                        result=f"Genie query failed with error: {resp.get('error', 'Unknown error')}",
-                    )
-                # includes EXECUTING_QUERY, Genie can retry after this status
-                else:
-                    logging.debug(f"Waiting...: {resp['status']}")
-                    time.sleep(self.poll_interval)
-            return GenieResponse(
-                conversation_id,
-                f"Genie query timed out after {MAX_ITERATIONS} iterations of {self.poll_interval} seconds",
-            )
+def _response_to_json(response: GenieResponse) -> str:
+    """Convert GenieResponse to JSON string, handling DataFrame results."""
+    # Convert result to string if it's a DataFrame
+    result: str | pd.DataFrame = response.result
+    if isinstance(result, pd.DataFrame):
+        result = result.to_markdown()

-        return poll_result()
-
-    @mlflow.trace()
-    def ask_question(self, question: str, conversation_id: str | None = None):
-        logger.debug(
-            f"ask_question called with question: {question}, conversation_id: {conversation_id}"
-        )
-        if conversation_id:
-            resp = self.create_message(conversation_id, question)
-        else:
-            resp = self.start_conversation(question)
-        logger.debug(f"ask_question response: {resp}")
-        return self.poll_for_result(resp["conversation_id"], resp["message_id"])
+    data: dict[str, Any] = {
+        "result": result,
+        "query": response.query,
+        "description": response.description,
+        "conversation_id": response.conversation_id,
+    }
+    return json.dumps(data)


 def create_genie_tool(
     genie_room: GenieRoomModel | dict[str, Any],
-    name: Optional[str] = None,
-    description: Optional[str] = None,
-    persist_conversation: bool = False,
+    name: str | None = None,
+    description: str | None = None,
+    persist_conversation: bool = True,
     truncate_results: bool = False,
-    poll_interval: int = DEFAULT_POLLING_INTERVAL_SECS,
-) -> Callable[[str], GenieResponse]:
+    lru_cache_parameters: GenieLRUCacheParametersModel | dict[str, Any] | None = None,
+    semantic_cache_parameters: GenieSemanticCacheParametersModel
+    | dict[str, Any]
+    | None = None,
+) -> Callable[..., Command]:
     """
     Create a tool for interacting with Databricks Genie for natural language queries to databases.

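The new _response_to_json helper flattens a GenieResponse into the JSON payload of a ToolMessage, rendering DataFrame results as markdown text. A minimal sketch of that serialization, using a hypothetical stand-in dataclass rather than the real databricks_ai_bridge.genie.GenieResponse (whose exact constructor may differ):

import json
from dataclasses import dataclass

import pandas as pd


@dataclass
class StubGenieResponse:  # hypothetical stand-in for GenieResponse
    result: object  # str or pd.DataFrame
    query: str = ""
    description: str = ""
    conversation_id: str = ""


def response_to_json(response: StubGenieResponse) -> str:
    # DataFrames are rendered as markdown so they fit in a ToolMessage string
    result = response.result
    if isinstance(result, pd.DataFrame):
        result = result.to_markdown()
    return json.dumps(
        {
            "result": result,
            "query": response.query,
            "description": response.description,
            "conversation_id": response.conversation_id,
        }
    )


df = pd.DataFrame({"region": ["EMEA", "AMER"], "revenue": [1200, 3400]})
print(response_to_json(StubGenieResponse(result=df, conversation_id="conv-123")))
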
@@ -291,17 +79,37 @@ def create_genie_tool(
         genie_room: GenieRoomModel or dict containing Genie configuration
         name: Optional custom name for the tool. If None, uses default "genie_tool"
         description: Optional custom description for the tool. If None, uses default description
+        persist_conversation: Whether to persist conversation IDs across tool calls for
+            multi-turn conversations within the same Genie space
+        truncate_results: Whether to truncate large query results to fit token limits
+        lru_cache_parameters: Optional LRU cache configuration for SQL query caching
+        semantic_cache_parameters: Optional semantic cache configuration using pg_vector
+            for similarity-based query matching

     Returns:
         A LangGraph tool that processes natural language queries through Genie
     """
+    logger.debug(
+        "Creating Genie tool",
+        genie_room_type=type(genie_room).__name__,
+        persist_conversation=persist_conversation,
+        truncate_results=truncate_results,
+        name=name,
+        has_lru_cache=lru_cache_parameters is not None,
+        has_semantic_cache=semantic_cache_parameters is not None,
+    )

     if isinstance(genie_room, dict):
         genie_room = GenieRoomModel(**genie_room)

-    space_id: AnyVariable = genie_room.space_id or os.environ.get(
-        "DATABRICKS_GENIE_SPACE_ID"
-    )
+    if isinstance(lru_cache_parameters, dict):
+        lru_cache_parameters = GenieLRUCacheParametersModel(**lru_cache_parameters)
+
+    if isinstance(semantic_cache_parameters, dict):
+        semantic_cache_parameters = GenieSemanticCacheParametersModel(
+            **semantic_cache_parameters
+        )
+
     space_id: AnyVariable = genie_room.space_id or os.environ.get(
         "DATABRICKS_GENIE_SPACE_ID"
     )
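The dict-to-model coercion above lets callers pass cache settings either as validated Pydantic models or as plain dicts (for example, loaded from YAML config). A minimal sketch of the pattern with a hypothetical stand-in model; the real GenieLRUCacheParametersModel lives in dao_ai.config and defines its own fields:

from typing import Any

from pydantic import BaseModel


class LRUCacheParams(BaseModel):  # hypothetical stand-in model
    capacity: int = 128


def coerce(params: LRUCacheParams | dict[str, Any] | None) -> LRUCacheParams | None:
    # Accept a raw dict from config files and validate it into the model
    if isinstance(params, dict):
        return LRUCacheParams(**params)
    return params


print(coerce({"capacity": 64}))  # capacity=64, validated by Pydantic
print(coerce(None))              # None: caching disabled
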
@@ -309,13 +117,6 @@ def create_genie_tool(
         space_id = CompositeVariableModel(**space_id)
     space_id = value_of(space_id)

-    # genie: Genie = Genie(
-    #     space_id=space_id,
-    #     client=genie_room.workspace_client,
-    #     truncate_results=truncate_results,
-    #     polling_interval=poll_interval,
-    # )
-
     default_description: str = dedent("""
     This tool lets you have a conversation and chat with tabular data about <topic>. You should ask
     questions about the data and the tool will try to answer them.
@@ -338,51 +139,99 @@ Returns:
     GenieResponse: A response object containing the conversation ID and result from Genie."""
     tool_description = tool_description + function_docs

+    genie: Genie = Genie(
+        space_id=space_id,
+        client=genie_room.workspace_client,
+        truncate_results=truncate_results,
+    )
+
+    genie_service: GenieServiceBase = GenieService(genie)
+
+    # Wrap with semantic cache first (checked second due to decorator pattern)
+    if semantic_cache_parameters is not None:
+        genie_service = SemanticCacheService(
+            impl=genie_service,
+            parameters=semantic_cache_parameters,
+            workspace_client=genie_room.workspace_client,  # Pass workspace client for conversation history
+        ).initialize()  # Eagerly initialize to fail fast and create table
+
+    # Wrap with LRU cache last (checked first - fast O(1) exact match)
+    if lru_cache_parameters is not None:
+        genie_service = LRUCacheService(
+            impl=genie_service,
+            parameters=lru_cache_parameters,
+        )
+
     @tool(
         name_or_callable=tool_name,
         description=tool_description,
     )
     def genie_tool(
         question: Annotated[str, "The question to ask Genie about your data"],
-        state: Annotated[dict, InjectedState],
-        tool_call_id: Annotated[str, InjectedToolCallId],
+        runtime: ToolRuntime[Context, AgentState],
     ) -> Command:
-        genie: Genie = Genie(
-            space_id=space_id,
-            client=genie_room.workspace_client,
-            truncate_results=truncate_results,
-            polling_interval=poll_interval,
-        )
+        """Process a natural language question through Databricks Genie.

-        """Process a natural language question through Databricks Genie."""
-        # Get existing conversation mapping and retrieve conversation ID for this space
-        conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
-        existing_conversation_id: str | None = conversation_ids.get(space_id)
-        logger.debug(
-            f"Existing conversation ID for space {space_id}: {existing_conversation_id}"
+        Uses ToolRuntime to access state and context in a type-safe way.
+        """
+        # Access state through runtime
+        state: AgentState = runtime.state
+        tool_call_id: str = runtime.tool_call_id
+
+        # Ensure space_id is a string for state keys
+        space_id_str: str = str(space_id)
+
+        # Get session state (or create new one)
+        session: SessionState = state.get("session", SessionState())
+
+        # Get existing conversation ID from session
+        existing_conversation_id: str | None = session.genie.get_conversation_id(
+            space_id_str
+        )
+        logger.trace(
+            "Using existing conversation ID",
+            space_id=space_id_str,
+            conversation_id=existing_conversation_id,
         )

-        response: GenieResponse = genie.ask_question(
+        # Call ask_question which always returns CacheResult with cache metadata
+        cache_result: CacheResult = genie_service.ask_question(
             question, conversation_id=existing_conversation_id
         )
+        genie_response: GenieResponse = cache_result.response
+        cache_hit: bool = cache_result.cache_hit
+        cache_key: str | None = cache_result.served_by

-        current_conversation_id: str = response.conversation_id
+        current_conversation_id: str = genie_response.conversation_id
         logger.debug(
-            f"Current conversation ID for space {space_id}: {current_conversation_id}"
+            "Genie question answered",
+            space_id=space_id_str,
+            conversation_id=current_conversation_id,
+            cache_hit=cache_hit,
+            cache_key=cache_key,
         )

-        # Update the conversation mapping with the new conversation ID for this space
+        # Update session state with cache information
+        if persist_conversation:
+            session.genie.update_space(
+                space_id=space_id_str,
+                conversation_id=current_conversation_id,
+                cache_hit=cache_hit,
+                cache_key=cache_key,
+                last_query=question,
+            )

+        # Build update dict with response and session
         update: dict[str, Any] = {
-            "messages": [ToolMessage(response.to_json(), tool_call_id=tool_call_id)],
+            "messages": [
+                ToolMessage(
+                    _response_to_json(genie_response), tool_call_id=tool_call_id
+                )
+            ],
         }

         if persist_conversation:
-            updated_conversation_ids: dict[str, str] = conversation_ids.copy()
-            updated_conversation_ids[space_id] = current_conversation_id
-            update["genie_conversation_ids"] = updated_conversation_ids
-
-            logger.debug(f"State update: {update}")
+            update["session"] = session

         return Command(update=update)
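The cache wiring in create_genie_tool is a decorator pattern: the LRU service wraps the semantic service, which wraps the base GenieService, so a lookup hits the O(1) exact-match layer first and only falls through to Genie itself on a miss. A minimal sketch of that layering with illustrative stand-in classes (GenieStub and LRUCacheStub are not the dao_ai.genie API; only the wrap-and-fall-through structure mirrors the diff):

from dataclasses import dataclass


@dataclass
class CacheResult:  # mirrors the cache_hit/served_by metadata used in the tool
    response: str
    cache_hit: bool
    served_by: str | None


class GenieStub:
    """Innermost layer: always computes an answer (stands in for GenieService)."""

    def ask_question(self, question: str, conversation_id: str | None = None) -> CacheResult:
        return CacheResult(f"genie answer to {question!r}", False, None)


class LRUCacheStub:
    """Outer layer: exact-match cache consulted before the wrapped service."""

    def __init__(self, impl) -> None:
        self.impl = impl
        self.store: dict[str, str] = {}

    def ask_question(self, question: str, conversation_id: str | None = None) -> CacheResult:
        if question in self.store:  # O(1) exact match, checked first
            return CacheResult(self.store[question], True, "lru")
        result = self.impl.ask_question(question, conversation_id)  # fall through on miss
        self.store[question] = result.response
        return result


service = LRUCacheStub(GenieStub())
print(service.ask_question("total revenue by region?").cache_hit)  # False: forwarded to Genie
print(service.ask_question("total revenue by region?").cache_hit)  # True: served from the LRU layer

A semantic cache layer would sit between the two, matching on embedding similarity instead of exact text; the tool itself records cache_hit and served_by into the per-session Genie state, which replaces the raw genie_conversation_ids dict from 0.0.25.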