dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63):
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +5 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1863 -338
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -228
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +261 -166
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +645 -172
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -295
  44. dao_ai/tools/mcp.py +220 -133
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +360 -40
  53. dao_ai/utils.py +218 -16
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
  57. dao_ai/chat_models.py +0 -204
  58. dao_ai/guardrails.py +0 -112
  59. dao_ai/tools/human_in_the_loop.py +0 -100
  60. dao_ai-0.0.25.dist-info/METADATA +0 -1165
  61. dao_ai-0.0.25.dist-info/RECORD +0 -41
  62. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
@@ -1,20 +1,21 @@
1
- from typing import Any, Dict, Optional, Sequence, Union
1
+ from typing import Any, Dict, Optional, Sequence, Set
2
2
 
3
3
  from databricks.sdk import WorkspaceClient
4
- from databricks.sdk.service.catalog import PermissionsChange, Privilege
4
+ from databricks.sdk.service.catalog import FunctionInfo, PermissionsChange, Privilege
5
5
  from databricks_langchain import DatabricksFunctionClient, UCFunctionToolkit
6
6
  from langchain_core.runnables.base import RunnableLike
7
7
  from langchain_core.tools import StructuredTool
8
8
  from loguru import logger
9
+ from pydantic import BaseModel
10
+ from unitycatalog.ai.core.base import FunctionExecutionResult
9
11
 
10
12
  from dao_ai.config import (
11
13
  AnyVariable,
12
14
  CompositeVariableModel,
13
- ToolModel,
14
15
  UnityCatalogFunctionModel,
15
16
  value_of,
16
17
  )
17
- from dao_ai.tools.human_in_the_loop import as_human_in_the_loop
18
+ from dao_ai.utils import normalize_host
18
19
 
19
20
 
20
21
  def create_uc_tools(
@@ -33,45 +34,48 @@ def create_uc_tools(
33
34
  Returns:
34
35
  A sequence of BaseTool objects that wrap the specified UC functions
35
36
  """
37
+ original_function_model: UnityCatalogFunctionModel | None = None
38
+ workspace_client: WorkspaceClient | None = None
39
+ function_name: str
36
40
 
37
- logger.debug(f"create_uc_tools: {function}")
38
-
39
- original_function_model = None
40
41
  if isinstance(function, UnityCatalogFunctionModel):
41
42
  original_function_model = function
42
- function_name = function.full_name
43
+ function_name = function.resource.full_name
44
+ workspace_client = function.resource.workspace_client
43
45
  else:
44
46
  function_name = function
45
47
 
48
+ logger.trace("Creating UC tools", function_name=function_name)
49
+
46
50
  # Determine which tools to create
51
+ tools: list[RunnableLike]
47
52
  if original_function_model and original_function_model.partial_args:
48
- logger.debug("Found partial_args, creating custom tool with partial arguments")
49
- # Create a ToolModel wrapper for the with_partial_args function
50
- tool_model = ToolModel(
51
- name=original_function_model.name, function=original_function_model
53
+ logger.debug(
54
+ "Creating custom tool with partial arguments", function_name=function_name
52
55
  )
53
-
54
- # Use with_partial_args to create the authenticated tool
55
- tools = [with_partial_args(tool_model, original_function_model.partial_args)]
56
+ # Use with_partial_args directly with UnityCatalogFunctionModel
57
+ tools = [with_partial_args(original_function_model)]
56
58
  else:
57
59
  # Fallback to standard UC toolkit approach
58
- client: DatabricksFunctionClient = DatabricksFunctionClient()
60
+ client: DatabricksFunctionClient = DatabricksFunctionClient(
61
+ client=workspace_client
62
+ )
59
63
 
60
64
  toolkit: UCFunctionToolkit = UCFunctionToolkit(
61
65
  function_names=[function_name], client=client
62
66
  )
63
67
 
64
68
  tools = toolkit.tools or []
65
- logger.debug(f"Retrieved tools: {tools}")
69
+ logger.trace("Retrieved tools", tools_count=len(tools))
66
70
 
67
- # Apply human-in-the-loop wrapper to all tools and return
68
- return [as_human_in_the_loop(tool=tool, function=function_name) for tool in tools]
71
+ # HITL is now handled at middleware level via HumanInTheLoopMiddleware
72
+ return list(tools)
69
73
 
70
74
 
71
75
  def _execute_uc_function(
72
76
  client: DatabricksFunctionClient,
73
77
  function_name: str,
74
- partial_args: Dict[str, str] = None,
78
+ partial_args: Optional[Dict[str, str]] = None,
75
79
  **kwargs: Any,
76
80
  ) -> str:
77
81
  """Execute Unity Catalog function with partial args and provided parameters."""
@@ -83,18 +87,30 @@ def _execute_uc_function(
83
87
  all_params.update(kwargs)
84
88
 
85
89
  logger.debug(
86
- f"Calling UC function {function_name} with parameters: {list(all_params.keys())}"
90
+ "Calling UC function",
91
+ function_name=function_name,
92
+ parameters=list(all_params.keys()),
87
93
  )
88
94
 
89
- result = client.execute_function(function_name=function_name, parameters=all_params)
95
+ result: FunctionExecutionResult = client.execute_function(
96
+ function_name=function_name, parameters=all_params
97
+ )
90
98
 
91
99
  # Handle errors and extract result
92
- if hasattr(result, "error") and result.error:
93
- logger.error(f"Unity Catalog function error: {result.error}")
100
+ if result.error:
101
+ logger.error(
102
+ "Unity Catalog function error",
103
+ function_name=function_name,
104
+ error=result.error,
105
+ )
94
106
  raise RuntimeError(f"Function execution failed: {result.error}")
95
107
 
96
- result_value: str = result.value if hasattr(result, "value") else str(result)
97
- logger.debug(f"UC function result: {result_value}")
108
+ result_value: str = result.value if result.value is not None else str(result)
109
+ logger.trace(
110
+ "UC function result",
111
+ function_name=function_name,
112
+ result_length=len(str(result_value)),
113
+ )
98
114
  return result_value
99
115
 
100
116
 
@@ -113,21 +129,29 @@ def _grant_function_permissions(
113
129
  """
114
130
  try:
115
131
  # Initialize workspace client
116
- workspace_client = WorkspaceClient(host=host) if host else WorkspaceClient()
132
+ workspace_client: WorkspaceClient = (
133
+ WorkspaceClient(host=host) if host else WorkspaceClient()
134
+ )
117
135
 
118
136
  # Parse the function name to get catalog and schema
119
- parts = function_name.split(".")
137
+ parts: list[str] = function_name.split(".")
120
138
  if len(parts) != 3:
121
139
  logger.warning(
122
- f"Invalid function name format: {function_name}. Expected catalog.schema.function"
140
+ "Invalid function name format, expected catalog.schema.function",
141
+ function_name=function_name,
123
142
  )
124
143
  return
125
144
 
145
+ catalog_name: str
146
+ schema_name: str
147
+ func_name: str
126
148
  catalog_name, schema_name, func_name = parts
127
- schema_full_name = f"{catalog_name}.{schema_name}"
149
+ schema_full_name: str = f"{catalog_name}.{schema_name}"
128
150
 
129
151
  logger.debug(
130
- f"Granting comprehensive permissions on function {function_name} to principal {client_id}"
152
+ "Granting comprehensive permissions",
153
+ function_name=function_name,
154
+ principal=client_id,
131
155
  )
132
156
 
133
157
  # 1. Grant EXECUTE permission on the function
@@ -139,9 +163,13 @@ def _grant_function_permissions(
139
163
  PermissionsChange(principal=client_id, add=[Privilege.EXECUTE])
140
164
  ],
141
165
  )
142
- logger.debug(f"Granted EXECUTE on function {function_name}")
166
+ logger.trace("Granted EXECUTE permission", function_name=function_name)
143
167
  except Exception as e:
144
- logger.warning(f"Failed to grant EXECUTE on function {function_name}: {e}")
168
+ logger.warning(
169
+ "Failed to grant EXECUTE permission",
170
+ function_name=function_name,
171
+ error=str(e),
172
+ )
145
173
 
146
174
  # 2. Grant USE_SCHEMA permission on the schema
147
175
  try:
@@ -155,10 +183,12 @@ def _grant_function_permissions(
155
183
  )
156
184
  ],
157
185
  )
158
- logger.debug(f"Granted USE_SCHEMA on schema {schema_full_name}")
186
+ logger.trace("Granted USE_SCHEMA permission", schema=schema_full_name)
159
187
  except Exception as e:
160
188
  logger.warning(
161
- f"Failed to grant USE_SCHEMA on schema {schema_full_name}: {e}"
189
+ "Failed to grant USE_SCHEMA permission",
190
+ schema=schema_full_name,
191
+ error=str(e),
162
192
  )
163
193
 
164
194
  # 3. Grant USE_CATALOG and BROWSE permissions on the catalog
@@ -173,25 +203,34 @@ def _grant_function_permissions(
173
203
  )
174
204
  ],
175
205
  )
176
- logger.debug(f"Granted USE_CATALOG and BROWSE on catalog {catalog_name}")
206
+ logger.trace(
207
+ "Granted USE_CATALOG and BROWSE permissions", catalog=catalog_name
208
+ )
177
209
  except Exception as e:
178
210
  logger.warning(
179
- f"Failed to grant catalog permissions on {catalog_name}: {e}"
211
+ "Failed to grant catalog permissions",
212
+ catalog=catalog_name,
213
+ error=str(e),
180
214
  )
181
215
 
182
216
  logger.debug(
183
- f"Successfully granted comprehensive permissions on {function_name} to {client_id}"
217
+ "Successfully granted comprehensive permissions",
218
+ function_name=function_name,
219
+ principal=client_id,
184
220
  )
185
221
 
186
222
  except Exception as e:
187
223
  logger.warning(
188
- f"Failed to grant permissions on function {function_name} to {client_id}: {e}"
224
+ "Failed to grant permissions",
225
+ function_name=function_name,
226
+ principal=client_id,
227
+ error=str(e),
189
228
  )
190
229
  # Don't fail the tool creation if permission granting fails
191
230
  pass
192
231
 
193
232
 
194
- def _create_filtered_schema(original_schema: type, exclude_fields: set[str]) -> type:
233
+ def _create_filtered_schema(original_schema: type, exclude_fields: Set[str]) -> type:
195
234
  """
196
235
  Create a new Pydantic model that excludes specified fields from the original schema.
197
236
 
@@ -202,23 +241,27 @@ def _create_filtered_schema(original_schema: type, exclude_fields: set[str]) ->
202
241
  Returns:
203
242
  A new Pydantic model class with the specified fields removed
204
243
  """
205
- from pydantic import BaseModel, Field, create_model
206
- from pydantic.fields import PydanticUndefined
244
+ from pydantic import Field, create_model
245
+ from pydantic.fields import FieldInfo, PydanticUndefined
207
246
 
208
247
  try:
209
248
  # Get the original model's fields (Pydantic v2)
210
- original_fields = original_schema.model_fields
211
- filtered_field_definitions = {}
249
+ original_fields: dict[str, FieldInfo] = original_schema.model_fields
250
+ filtered_field_definitions: dict[str, tuple[type, FieldInfo]] = {}
212
251
 
213
- for name, field in original_fields.items():
214
- if name not in exclude_fields:
252
+ field_name: str
253
+ field: FieldInfo
254
+ for field_name, field in original_fields.items():
255
+ if field_name not in exclude_fields:
215
256
  # Reconstruct the field definition for create_model
216
- field_type = field.annotation
217
- field_default = (
257
+ field_type: type = field.annotation
258
+ field_default: Any = (
218
259
  field.default if field.default is not PydanticUndefined else ...
219
260
  )
220
- field_info = Field(default=field_default, description=field.description)
221
- filtered_field_definitions[name] = (field_type, field_info)
261
+ field_info: FieldInfo = Field(
262
+ default=field_default, description=field.description
263
+ )
264
+ filtered_field_definitions[field_name] = (field_type, field_info)
222
265
 
223
266
  # If no fields remain after filtering, return a generic empty schema
224
267
  if not filtered_field_definitions:
@@ -231,18 +274,18 @@ def _create_filtered_schema(original_schema: type, exclude_fields: set[str]) ->
231
274
  return EmptySchema
232
275
 
233
276
  # Create the new model dynamically
234
- model_name = f"Filtered{original_schema.__name__}"
235
- docstring = getattr(
277
+ model_name: str = f"Filtered{original_schema.__name__}"
278
+ docstring: str = getattr(
236
279
  original_schema, "__doc__", "Filtered Unity Catalog function parameters."
237
280
  )
238
281
 
239
- filtered_model = create_model(
282
+ filtered_model: type[BaseModel] = create_model(
240
283
  model_name, __doc__=docstring, **filtered_field_definitions
241
284
  )
242
285
  return filtered_model
243
286
 
244
287
  except Exception as e:
245
- logger.warning(f"Failed to create filtered schema: {e}")
288
+ logger.warning("Failed to create filtered schema", error=str(e))
246
289
 
247
290
  # Fallback to generic schema
248
291
  class GenericFilteredSchema(BaseModel):
@@ -254,8 +297,7 @@ def _create_filtered_schema(original_schema: type, exclude_fields: set[str]) ->
254
297
 
255
298
 
256
299
  def with_partial_args(
257
- tool: Union[ToolModel, Dict[str, Any]],
258
- partial_args: dict[str, AnyVariable] = {},
300
+ uc_function: UnityCatalogFunctionModel,
259
301
  ) -> StructuredTool:
260
302
  """
261
303
  Create a Unity Catalog tool with partial arguments pre-filled.
@@ -264,35 +306,64 @@ def with_partial_args(
264
306
  already resolved, so the caller only needs to provide the remaining parameters.
265
307
 
266
308
  Args:
267
- tool: ToolModel containing the Unity Catalog function configuration
268
- partial_args: Dictionary of arguments to pre-fill in the tool
309
+ uc_function: UnityCatalogFunctionModel containing the function configuration
310
+ and partial_args to pre-fill.
269
311
 
270
312
  Returns:
271
313
  StructuredTool: A LangChain tool with partial arguments pre-filled
272
314
  """
273
315
  from unitycatalog.ai.langchain.toolkit import generate_function_input_params_schema
274
316
 
275
- logger.debug(f"with_partial_args: {tool}")
317
+ from dao_ai.config import ServicePrincipalModel
318
+
319
+ partial_args: dict[str, AnyVariable] = uc_function.partial_args or {}
276
320
 
277
321
  # Convert dict-based variables to CompositeVariableModel and resolve their values
278
- resolved_args = {}
322
+ resolved_args: dict[str, Any] = {}
323
+ k: str
324
+ v: AnyVariable
279
325
  for k, v in partial_args.items():
280
326
  if isinstance(v, dict):
281
327
  resolved_args[k] = value_of(CompositeVariableModel(**v))
282
328
  else:
283
329
  resolved_args[k] = value_of(v)
284
330
 
285
- logger.debug(f"Resolved partial args: {resolved_args.keys()}")
286
-
287
- if isinstance(tool, dict):
288
- tool = ToolModel(**tool)
331
+ # Handle service_principal - expand into client_id and client_secret
332
+ if "service_principal" in resolved_args:
333
+ sp: Any = resolved_args.pop("service_principal")
334
+ if isinstance(sp, dict):
335
+ sp = ServicePrincipalModel(**sp)
336
+ if isinstance(sp, ServicePrincipalModel):
337
+ if "client_id" not in resolved_args:
338
+ resolved_args["client_id"] = value_of(sp.client_id)
339
+ if "client_secret" not in resolved_args:
340
+ resolved_args["client_secret"] = value_of(sp.client_secret)
341
+
342
+ # Normalize host/workspace_host - accept either key, ensure https:// scheme
343
+ if "workspace_host" in resolved_args and "host" not in resolved_args:
344
+ resolved_args["host"] = normalize_host(resolved_args.pop("workspace_host"))
345
+ elif "host" in resolved_args:
346
+ resolved_args["host"] = normalize_host(resolved_args["host"])
347
+
348
+ # Default host from WorkspaceClient if not provided
349
+ if "host" not in resolved_args:
350
+ from dao_ai.utils import get_default_databricks_host
351
+
352
+ host: str | None = get_default_databricks_host()
353
+ if host:
354
+ resolved_args["host"] = host
355
+
356
+ # Get function info from the resource
357
+ function_name: str = uc_function.resource.full_name
358
+ tool_name: str = uc_function.resource.name or function_name.replace(".", "_")
359
+ workspace_client: WorkspaceClient = uc_function.resource.workspace_client
289
360
 
290
- unity_catalog_function = tool.function
291
- if isinstance(unity_catalog_function, dict):
292
- unity_catalog_function = UnityCatalogFunctionModel(**unity_catalog_function)
293
-
294
- function_name: str = unity_catalog_function.full_name
295
- logger.debug(f"Creating UC tool with partial args for: {function_name}")
361
+ logger.debug(
362
+ "Creating UC tool with partial args",
363
+ function_name=function_name,
364
+ tool_name=tool_name,
365
+ partial_args=list(resolved_args.keys()),
366
+ )
296
367
 
297
368
  # Grant permissions if we have credentials
298
369
  if "client_id" in resolved_args:
@@ -301,36 +372,44 @@ def with_partial_args(
301
372
  try:
302
373
  _grant_function_permissions(function_name, client_id, host)
303
374
  except Exception as e:
304
- logger.warning(f"Failed to grant permissions: {e}")
375
+ logger.warning(
376
+ "Failed to grant permissions", function_name=function_name, error=str(e)
377
+ )
305
378
 
306
- # Create the client for function execution
307
- client: DatabricksFunctionClient = DatabricksFunctionClient()
379
+ # Create the client for function execution using the resource's workspace client
380
+ client: DatabricksFunctionClient = DatabricksFunctionClient(client=workspace_client)
308
381
 
309
382
  # Try to get the function schema for better tool definition
383
+ schema_model: type[BaseModel]
384
+ tool_description: str
310
385
  try:
311
- function_info = client.get_function(function_name)
386
+ function_info: FunctionInfo = client.get_function(function_name)
312
387
  schema_info = generate_function_input_params_schema(function_info)
313
388
  tool_description = (
314
389
  function_info.comment or f"Unity Catalog function: {function_name}"
315
390
  )
316
391
 
317
- logger.debug(
318
- f"Generated schema for function {function_name}: {schema_info.pydantic_model}"
392
+ logger.trace(
393
+ "Generated function schema",
394
+ function_name=function_name,
395
+ schema=schema_info.pydantic_model.__name__,
319
396
  )
320
- logger.debug(f"Tool description: {tool_description}")
321
397
 
322
398
  # Create a modified schema that excludes partial args
323
- original_schema = schema_info.pydantic_model
399
+ original_schema: type = schema_info.pydantic_model
324
400
  schema_model = _create_filtered_schema(original_schema, resolved_args.keys())
325
- logger.debug(
326
- f"Filtered schema excludes partial args: {list(resolved_args.keys())}"
401
+ logger.trace(
402
+ "Filtered schema to exclude partial args",
403
+ function_name=function_name,
404
+ excluded_args=list(resolved_args.keys()),
327
405
  )
328
406
 
329
407
  except Exception as e:
330
- logger.warning(f"Could not introspect function {function_name}: {e}")
331
- # Fallback to a generic schema
332
- from pydantic import BaseModel
408
+ logger.warning(
409
+ "Could not introspect function", function_name=function_name, error=str(e)
410
+ )
333
411
 
412
+ # Fallback to a generic schema
334
413
  class GenericUCParams(BaseModel):
335
414
  """Generic parameters for Unity Catalog function."""
336
415
 
@@ -350,14 +429,12 @@ def with_partial_args(
350
429
  )
351
430
 
352
431
  # Set the function name for the decorator
353
- uc_function_wrapper.__name__ = tool.name or function_name.replace(".", "_")
432
+ uc_function_wrapper.__name__ = tool_name
354
433
 
355
434
  # Create the tool using LangChain's StructuredTool
356
- from langchain_core.tools import StructuredTool
357
-
358
- partial_tool = StructuredTool.from_function(
435
+ partial_tool: StructuredTool = StructuredTool.from_function(
359
436
  func=uc_function_wrapper,
360
- name=tool.name or function_name.replace(".", "_"),
437
+ name=tool_name,
361
438
  description=tool_description,
362
439
  args_schema=schema_model,
363
440
  )