dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63):
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +5 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1863 -338
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -228
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +261 -166
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +645 -172
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -295
  44. dao_ai/tools/mcp.py +220 -133
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +360 -40
  53. dao_ai/utils.py +218 -16
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
  57. dao_ai/chat_models.py +0 -204
  58. dao_ai/guardrails.py +0 -112
  59. dao_ai/tools/human_in_the_loop.py +0 -100
  60. dao_ai-0.0.25.dist-info/METADATA +0 -1165
  61. dao_ai-0.0.25.dist-info/RECORD +0 -41
  62. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,72 +1,392 @@
1
- from typing import Any, Optional, Sequence
1
+ """
2
+ Vector search tool for retrieving documents from Databricks Vector Search.
3
+
4
+ This module provides a tool factory for creating semantic search tools
5
+ with dynamic filter schemas based on table columns and FlashRank reranking support.
6
+ """
7
+
8
+ import json
9
+ import os
10
+ from typing import Any, Optional
2
11
 
3
12
  import mlflow
4
- from databricks_langchain.vector_search_retriever_tool import VectorSearchRetrieverTool
5
- from langchain_core.tools import BaseTool
13
+ from databricks.sdk import WorkspaceClient
14
+ from databricks.vector_search.reranker import DatabricksReranker
15
+ from databricks_langchain import DatabricksVectorSearch
16
+ from flashrank import Ranker, RerankRequest
17
+ from langchain_core.documents import Document
18
+ from langchain_core.tools import StructuredTool
19
+ from loguru import logger
20
+ from mlflow.entities import SpanType
21
+ from pydantic import BaseModel, ConfigDict, Field, create_model
6
22
 
7
23
  from dao_ai.config import (
24
+ RerankParametersModel,
8
25
  RetrieverModel,
26
+ SearchParametersModel,
9
27
  VectorStoreModel,
28
+ value_of,
29
+ )
30
+ from dao_ai.utils import normalize_host
31
+
32
# FilterItem lives at module level so it can be referenced in type hints.
# It is deliberately declared WITHOUT a class docstring: pydantic would emit a
# docstring as the schema "description", which the tool schema does not carry.
class FilterItem(BaseModel):
    # Reject unknown keys so the generated JSON schema is closed
    # (additionalProperties: false).
    model_config = ConfigDict(extra="forbid")

    # Column name, optionally followed by an operator suffix.
    key: str = Field(
        description="The filter key, which includes the column name and can include operators like 'NOT', '<', '>=', 'LIKE', 'OR'"
    )
    # Scalar to match, or a list of alternatives.
    value: Any = Field(
        description="The filter value, which can be a single value or an array of values"
    )
11
49
 
12
50
 
51
def _create_dynamic_input_schema(
    index_name: str, workspace_client: WorkspaceClient
) -> type[BaseModel]:
    """
    Create dynamic input schema with column information from the table.

    The table metadata lookup is best-effort: if it fails, the failure is
    logged at debug level and the schema is still returned, just without the
    "Available columns" hint in the filter description.

    Args:
        index_name: Full name of the vector search index
        workspace_client: Workspace client to query table metadata

    Returns:
        Pydantic model class for tool input with a required ``query`` string
        and an optional ``filters`` list of ``FilterItem``
    """

    # Try to get column information
    column_descriptions: list[str] = []
    try:
        table_info = workspace_client.tables.get(full_name=index_name)
        # ``columns`` may be absent on the response; treat that as "no columns".
        for column_info in table_info.columns or []:
            name = column_info.name
            # Skip unnamed columns and internal ones (prefixed with "__").
            if not name or name.startswith("__"):
                continue
            type_name = column_info.type_name
            # A column without a declared type is still filterable; list it by
            # name alone rather than aborting the whole enumeration.
            if type_name is None:
                column_descriptions.append(name)
            else:
                column_descriptions.append(f"{name} ({type_name.name})")
    except Exception as e:
        # Deliberately broad: column hints are optional and must never prevent
        # the tool from being created.
        logger.debug(
            "Could not retrieve column information for dynamic schema",
            index=index_name,
            error=str(e),
        )

    # Build filter description matching VectorSearchRetrieverTool format
    filter_description = (
        "Optional filters to refine vector search results as an array of key-value pairs. "
        "IMPORTANT: If unsure about filter values, try searching WITHOUT filters first to get broad results, "
        "then optionally add filters to narrow down if needed. This ensures you don't miss relevant results due to incorrect filter values. "
    )

    if column_descriptions:
        filter_description += (
            f"Available columns for filtering: {', '.join(column_descriptions)}. "
        )

    filter_description += (
        "Supports the following operators:\n\n"
        '- Inclusion: [{"key": "column", "value": value}] or [{"key": "column", "value": [value1, value2]}] (matches if the column equals any of the provided values)\n'
        '- Exclusion: [{"key": "column NOT", "value": value}]\n'
        '- Comparisons: [{"key": "column <", "value": value}], [{"key": "column >=", "value": value}], etc.\n'
        '- Pattern match: [{"key": "column LIKE", "value": "word"}] (matches full tokens separated by whitespace)\n'
        '- OR logic: [{"key": "column1 OR column2", "value": [value1, value2]}] '
        "(matches if column1 equals value1 or column2 equals value2; matches are position-specific)\n\n"
        "Examples:\n"
        '- Filter by category: [{"key": "category", "value": "electronics"}]\n'
        '- Filter by price range: [{"key": "price >=", "value": 100}, {"key": "price <", "value": 500}]\n'
        '- Exclude specific status: [{"key": "status NOT", "value": "archived"}]\n'
        '- Pattern matching: [{"key": "description LIKE", "value": "wireless"}]'
    )

    # Create the input model; extra="forbid" keeps the schema closed.
    VectorSearchInput = create_model(
        "VectorSearchInput",
        query=(
            str,
            Field(description="The search query string to find relevant documents"),
        ),
        filters=(
            Optional[list[FilterItem]],
            Field(default=None, description=filter_description),
        ),
        __config__=ConfigDict(extra="forbid"),
    )

    return VectorSearchInput
122
+
123
+
124
@mlflow.trace(name="rerank_documents", span_type=SpanType.RETRIEVER)
def _rerank_documents(
    query: str,
    documents: list[Document],
    ranker: Ranker,
    rerank_config: RerankParametersModel,
) -> list[Document]:
    """
    Rerank documents using FlashRank cross-encoder model.

    Each passage carries its position in ``documents`` as ``id`` so a result
    can be mapped back to the exact source document. Matching on text alone
    would give documents with identical ``page_content`` the metadata of the
    first one.

    Args:
        query: The search query string
        documents: List of documents to rerank
        ranker: The FlashRank Ranker instance
        rerank_config: Reranking configuration (``model`` is logged, ``top_n``
            caps the number of returned documents; falsy means "keep all")

    Returns:
        Reranked list of new Document objects with reranker_score added to a
        copy of the original metadata. Empty input yields an empty list.
    """
    if not documents:
        # Nothing to score; avoid invoking the model on an empty batch.
        return []

    logger.trace(
        "Starting reranking",
        documents_count=len(documents),
        model=rerank_config.model,
    )

    # Prepare passages for reranking, tagging each with its source position
    passages: list[dict[str, Any]] = [
        {"id": position, "text": doc.page_content, "meta": doc.metadata}
        for position, doc in enumerate(documents)
    ]

    # Create reranking request
    rerank_request: RerankRequest = RerankRequest(query=query, passages=passages)

    # Perform reranking
    results: list[dict[str, Any]] = ranker.rerank(rerank_request)

    # Apply top_n filtering
    top_n: int = rerank_config.top_n or len(documents)
    results = results[:top_n]
    logger.debug("Reranking complete", top_n=top_n, candidates_count=len(documents))

    # Fallback lookup (first document per distinct text) for results that do
    # not echo a usable "id" back.
    first_by_content: dict[str, Document] = {}
    for doc in documents:
        first_by_content.setdefault(doc.page_content, doc)

    # Convert back to Document objects with reranking scores
    reranked_docs: list[Document] = []
    for result in results:
        position = result.get("id")
        if isinstance(position, int) and 0 <= position < len(documents):
            orig_doc: Optional[Document] = documents[position]
        else:
            orig_doc = first_by_content.get(result["text"])
        if orig_doc is None:
            # Result could not be tied to any input document; drop it.
            continue
        reranked_docs.append(
            Document(
                page_content=orig_doc.page_content,
                metadata={
                    **orig_doc.metadata,
                    "reranker_score": result["score"],
                },
            )
        )

    logger.debug(
        "Documents reranked",
        input_count=len(documents),
        output_count=len(reranked_docs),
        model=rerank_config.model,
    )

    return reranked_docs
189
+
190
+
def _resolve_credential(
    vector_store: VectorStoreModel, attribute: str, env_var: str
) -> Optional[Any]:
    """Return a credential from the environment, else from the vector store config."""
    from_env = os.environ.get(env_var)
    if from_env:
        return from_env
    # getattr keeps this safe for objects lacking the attribute (e.g. mocks).
    configured = getattr(vector_store, attribute, None)
    # value_of is imported from dao_ai.config; presumably it resolves a
    # configured value/reference to its concrete value -- verify there.
    return value_of(configured) if configured else None


def _build_client_args(vector_store: VectorStoreModel) -> dict[str, Any]:
    """Build ``client_args`` for the vector search client from explicit credentials."""
    has_explicit_auth = any(
        (
            os.environ.get("DATABRICKS_TOKEN"),
            os.environ.get("DATABRICKS_CLIENT_ID"),
            getattr(vector_store, "pat", None),
            getattr(vector_store, "client_id", None),
            getattr(vector_store, "on_behalf_of_user", None),
        )
    )
    if not has_explicit_auth:
        return {}

    # NOTE(review): the nesting of the host/credential lookups under the
    # explicit-auth guard was reconstructed from an indentation-stripped diff;
    # confirm against the packaged source.
    client_args: dict[str, Any] = {}

    databricks_host = os.environ.get("DATABRICKS_HOST")
    if (
        not databricks_host
        and getattr(vector_store, "_workspace_client", None) is not None
    ):
        databricks_host = vector_store.workspace_client.config.host
    if databricks_host:
        client_args["workspace_url"] = normalize_host(databricks_host)

    # (client_args key, vector store attribute, environment variable);
    # the environment variable wins over the configured attribute.
    credential_sources = (
        ("personal_access_token", "pat", "DATABRICKS_TOKEN"),
        ("service_principal_client_id", "client_id", "DATABRICKS_CLIENT_ID"),
        (
            "service_principal_client_secret",
            "client_secret",
            "DATABRICKS_CLIENT_SECRET",
        ),
    )
    for arg_name, attribute, env_var in credential_sources:
        credential = _resolve_credential(vector_store, attribute, env_var)
        if credential:
            client_args[arg_name] = credential

    return client_args


def _serialize_documents(documents: list[Document]) -> str:
    """Serialize documents to a JSON array of ``page_content``/``metadata`` objects."""
    serialized_docs: list[dict[str, Any]] = [
        {
            "page_content": doc.page_content,
            # Objects exposing ``.item()`` (e.g. numpy scalars such as
            # float32/int64) are unwrapped to plain Python values.
            "metadata": {
                key: value.item() if hasattr(value, "item") else value
                for key, value in doc.metadata.items()
            },
        }
        for doc in documents
    ]
    # default=str is deliberate: the output is read by an LLM, so a leftover
    # non-JSON value is better stringified than failing the whole tool call.
    return json.dumps(serialized_docs, default=str)


def create_vector_search_tool(
    retriever: Optional[RetrieverModel | dict[str, Any]] = None,
    vector_store: Optional[VectorStoreModel | dict[str, Any]] = None,
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> StructuredTool:
    """
    Create a Vector Search tool with dynamic schema and optional reranking.

    Exactly one of ``retriever`` or ``vector_store`` must be given. If the
    FlashRank ranker cannot be initialized, a warning is logged and only
    FlashRank reranking is skipped; reranking via ``rerank.columns`` stays
    enabled.

    Args:
        retriever: Full retriever configuration with search parameters and reranking
        vector_store: Direct vector store reference (uses default search parameters)
        name: Optional custom name for the tool
        description: Optional custom description for the tool

    Returns:
        A LangChain StructuredTool with proper schema (additionalProperties: false).
        The tool returns a JSON string: a list of objects with ``page_content``
        and ``metadata``.

    Raises:
        ValueError: If neither or both of ``retriever`` and ``vector_store`` are
            provided, or if the resolved vector store has no index.
    """

    # Validate mutually exclusive parameters
    if retriever is None and vector_store is None:
        raise ValueError("Must provide either 'retriever' or 'vector_store' parameter")
    if retriever is not None and vector_store is not None:
        raise ValueError(
            "Cannot provide both 'retriever' and 'vector_store' parameters"
        )

    # Normalize either input form into a RetrieverModel
    if vector_store is not None:
        if isinstance(vector_store, dict):
            vector_store = VectorStoreModel(**vector_store)
        retriever = RetrieverModel(vector_store=vector_store)
    elif isinstance(retriever, dict):
        retriever = RetrieverModel(**retriever)

    store: VectorStoreModel = retriever.vector_store

    # Index is required
    if store.index is None:
        raise ValueError("vector_store.index is required for vector search")

    index_name: str = store.index.full_name
    columns: list[str] = list(retriever.columns or [])
    search_parameters: SearchParametersModel = retriever.search_parameters
    rerank_config: Optional[RerankParametersModel] = retriever.rerank

    # Initialize FlashRank ranker if configured
    ranker: Optional[Ranker] = None
    if rerank_config and rerank_config.model:
        logger.debug(
            "Initializing FlashRank ranker",
            model=rerank_config.model,
            top_n=rerank_config.top_n or "auto",
        )
        try:
            cache_dir = os.path.expanduser(rerank_config.cache_dir)
            ranker = Ranker(model_name=rerank_config.model, cache_dir=cache_dir)
            logger.success("FlashRank ranker initialized", model=rerank_config.model)
        except Exception as e:
            # Best-effort: ``ranker`` stays None, which disables FlashRank only.
            # ``rerank_config`` is intentionally kept so the independent
            # ``columns`` setting still configures DatabricksReranker below.
            logger.warning("Failed to initialize FlashRank ranker", error=str(e))

    client_args: dict[str, Any] = _build_client_args(store)

    logger.debug(
        "Creating vector search tool",
        name=name,
        index=index_name,
        client_args_keys=list(client_args.keys()) if client_args else [],
    )

    # Create DatabricksVectorSearch
    # Note: text_column should be None for Databricks-managed embeddings
    # (it's automatically determined from the index)
    vector_search: DatabricksVectorSearch = DatabricksVectorSearch(
        index_name=index_name,
        text_column=None,
        columns=columns,
        workspace_client=store.workspace_client,
        client_args=client_args if client_args else None,
        primary_key=store.primary_key,
        doc_uri=store.doc_uri,
        include_score=True,
        reranker=(
            DatabricksReranker(columns_to_rerank=rerank_config.columns)
            if rerank_config and rerank_config.columns
            else None
        ),
    )

    # Create dynamic input schema
    input_schema: type[BaseModel] = _create_dynamic_input_schema(
        index_name, store.workspace_client
    )

    # Define the tool function
    def vector_search_func(
        query: str, filters: Optional[list[FilterItem]] = None
    ) -> str:
        """Search for relevant documents using vector similarity."""
        # Convert FilterItem Pydantic models to dict format for DatabricksVectorSearch
        filters_dict: dict[str, Any] = {item.key: item.value for item in filters or []}

        # Merge with configured filters; on a key collision the configured
        # filter overrides the caller-supplied one.
        combined_filters: dict[str, Any] = {
            **filters_dict,
            **(search_parameters.filters or {}),
        }

        # Perform vector search
        logger.trace("Performing vector search", query_preview=query[:50])
        documents: list[Document] = vector_search.similarity_search(
            query=query,
            k=search_parameters.num_results or 5,
            filter=combined_filters if combined_filters else None,
            query_type=search_parameters.query_type or "ANN",
        )

        # Apply FlashRank reranking if configured
        if ranker and rerank_config:
            logger.debug("Applying FlashRank reranking")
            documents = _rerank_documents(query, documents, ranker, rerank_config)

        # Return as JSON string for LLM consumption
        return _serialize_documents(documents)

    # Create the StructuredTool
    tool: StructuredTool = StructuredTool.from_function(
        func=vector_search_func,
        name=name or f"vector_search_{store.index.name}",
        description=description or f"Search documents in {index_name}",
        args_schema=input_schema,
    )

    logger.success("Vector search tool created", name=tool.name, index=index_name)

    return tool