dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +5 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1863 -338
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -228
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +261 -166
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +645 -172
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -295
  44. dao_ai/tools/mcp.py +220 -133
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +360 -40
  53. dao_ai/utils.py +218 -16
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
  57. dao_ai/chat_models.py +0 -204
  58. dao_ai/guardrails.py +0 -112
  59. dao_ai/tools/human_in_the_loop.py +0 -100
  60. dao_ai-0.0.25.dist-info/METADATA +0 -1165
  61. dao_ai-0.0.25.dist-info/RECORD +0 -41
  62. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/utils.py CHANGED
@@ -3,9 +3,11 @@ import importlib.metadata
3
3
  import os
4
4
  import re
5
5
  import site
6
- from importlib.metadata import version
6
+ from importlib.metadata import PackageNotFoundError, version
7
+ from pathlib import Path
7
8
  from typing import Any, Callable, Sequence
8
9
 
10
+ from langchain_core.tools import BaseTool
9
11
  from loguru import logger
10
12
 
11
13
  import dao_ai
@@ -18,15 +20,15 @@ def is_lib_provided(lib_name: str, pip_requirements: Sequence[str]) -> bool:
18
20
  )
19
21
 
20
22
 
21
def _is_path_within(path: str, directory: str) -> bool:
    """Return True if ``path`` equals ``directory`` or lies underneath it.

    A bare ``str.startswith`` check would wrongly report
    ``/env/site-packages-old/x.py`` as being inside ``/env/site-packages``.
    Appending a trailing separator to ``directory`` forces the match to end
    on a path-component boundary.

    Both arguments are expected to be absolute, normalized paths (for
    example, as produced by ``os.path.abspath``).
    """
    if path == directory:
        return True
    # os.path.join(d, "") appends exactly one trailing separator ("/" stays "/").
    return path.startswith(os.path.join(directory, ""))


def is_installed() -> bool:
    """Report whether the imported ``dao_ai`` package lives in a site-packages dir.

    Compares the absolute path of ``dao_ai.__file__`` against every global
    site-packages directory and, when one is reported, the user
    site-packages directory.

    Returns:
        True if ``dao_ai`` was imported from inside one of those
        directories, False otherwise.
    """
    current_file = os.path.abspath(dao_ai.__file__)
    site_packages = [os.path.abspath(path) for path in site.getsitepackages()]
    user_site = site.getusersitepackages()
    if user_site:
        site_packages.append(os.path.abspath(user_site))

    found: bool = any(
        _is_path_within(current_file, pkg_path) for pkg_path in site_packages
    )
    logger.trace(
        "Checking if dao_ai is installed", is_installed=found, current_file=current_file
    )
    return found
32
34
 
@@ -37,6 +39,114 @@ def normalize_name(name: str) -> str:
37
39
  return normalized.strip("_")
38
40
 
39
41
 
42
+ def normalize_host(host: str | None) -> str | None:
43
+ """Ensure host URL has https:// scheme.
44
+
45
+ The DATABRICKS_HOST environment variable should always include the https://
46
+ scheme, but some environments (e.g., Databricks Apps infrastructure) may
47
+ provide the host without it. This function normalizes the host to ensure
48
+ it has the proper scheme.
49
+
50
+ Args:
51
+ host: The host URL, with or without scheme
52
+
53
+ Returns:
54
+ The host URL with https:// scheme, or None if host is None/empty
55
+ """
56
+ if not host:
57
+ return None
58
+ host = host.strip()
59
+ if not host:
60
+ return None
61
+ if not host.startswith("http://") and not host.startswith("https://"):
62
+ return f"https://{host}"
63
+ return host
64
+
65
+
66
def get_default_databricks_host() -> str | None:
    """Get the default Databricks workspace host.

    Attempts to get the host from:
    1. DATABRICKS_HOST environment variable
    2. WorkspaceClient ambient authentication (e.g., from ~/.databrickscfg)

    The environment variable is normalized *before* deciding whether to use
    it. An empty or whitespace-only DATABRICKS_HOST therefore falls through to
    the WorkspaceClient lookup instead of short-circuiting to None.

    Returns:
        The Databricks workspace host URL (with a scheme, as produced by
        ``normalize_host``), or None if not available.
    """
    host: str | None = normalize_host(os.environ.get("DATABRICKS_HOST"))
    if host:
        return host

    # Best-effort fallback: any failure here (e.g. the SDK is not importable
    # or the client cannot be constructed) means "no default host".
    try:
        from databricks.sdk import WorkspaceClient

        w: WorkspaceClient = WorkspaceClient()
        return normalize_host(w.config.host)
    except Exception as e:
        logger.trace(
            "Could not get default Databricks host from WorkspaceClient",
            error=str(e),
        )
        return None
90
+
91
+
92
def dao_ai_version() -> str:
    """Resolve the dao-ai version string.

    Lookup order:

    1. Installed distribution metadata via ``importlib.metadata.version``.
    2. The ``[project].version`` entry of the ``pyproject.toml`` located in
       ``Path(__file__).parents[2]``. This step runs only when the
       distribution metadata is missing, and it needs ``tomllib``
       (Python 3.11+) or the ``tomli`` backport.
    3. The placeholder ``"dev"``.

    Problems in step 2 are logged as warnings, not raised.

    Returns:
        str: The resolved version, or "dev" if it cannot be determined.
    """
    placeholder = "dev"

    try:
        return version("dao-ai")
    except PackageNotFoundError:
        logger.trace(
            "dao-ai package not installed, attempting to read version from pyproject.toml"
        )

    # Prefer the stdlib parser; fall back to the API-compatible backport.
    toml_parser = None
    for parser_name in ("tomllib", "tomli"):
        try:
            toml_parser = importlib.import_module(parser_name)
        except ImportError:
            continue
        break

    if toml_parser is None:
        logger.warning(
            "Cannot determine dao-ai version: package not installed and tomllib/tomli not available"
        )
        return placeholder

    try:
        pyproject_path = Path(__file__).parents[2] / "pyproject.toml"
        if not pyproject_path.exists():
            logger.warning(
                "Cannot determine dao-ai version: pyproject.toml not found",
                path=str(pyproject_path),
            )
            return placeholder

        with pyproject_path.open("rb") as toml_file:
            parsed = toml_parser.load(toml_file)
        resolved = parsed.get("project", {}).get("version", placeholder)
        logger.trace(
            "Read version from pyproject.toml",
            version=resolved,
            path=str(pyproject_path),
        )
        return resolved
    except Exception as e:
        logger.warning(
            "Cannot determine dao-ai version from pyproject.toml", error=str(e)
        )
        return placeholder
148
+
149
+
40
150
  def get_installed_packages() -> dict[str, str]:
41
151
  """Get all installed packages with versions"""
42
152
 
@@ -45,16 +155,14 @@ def get_installed_packages() -> dict[str, str]:
45
155
  f"databricks-langchain=={version('databricks-langchain')}",
46
156
  f"databricks-mcp=={version('databricks-mcp')}",
47
157
  f"databricks-sdk[openai]=={version('databricks-sdk')}",
48
- f"duckduckgo-search=={version('duckduckgo-search')}",
158
+ f"ddgs=={version('ddgs')}",
159
+ f"flashrank=={version('flashrank')}",
49
160
  f"langchain=={version('langchain')}",
50
161
  f"langchain-mcp-adapters=={version('langchain-mcp-adapters')}",
51
162
  f"langchain-openai=={version('langchain-openai')}",
52
163
  f"langchain-tavily=={version('langchain-tavily')}",
53
164
  f"langgraph=={version('langgraph')}",
54
165
  f"langgraph-checkpoint-postgres=={version('langgraph-checkpoint-postgres')}",
55
- f"langgraph-prebuilt=={version('langgraph-prebuilt')}",
56
- f"langgraph-supervisor=={version('langgraph-supervisor')}",
57
- f"langgraph-swarm=={version('langgraph-swarm')}",
58
166
  f"langmem=={version('langmem')}",
59
167
  f"loguru=={version('loguru')}",
60
168
  f"mcp=={version('mcp')}",
@@ -65,6 +173,7 @@ def get_installed_packages() -> dict[str, str]:
65
173
  f"psycopg[binary,pool]=={version('psycopg')}",
66
174
  f"pydantic=={version('pydantic')}",
67
175
  f"pyyaml=={version('pyyaml')}",
176
+ f"tomli=={version('tomli')}",
68
177
  f"unitycatalog-ai[databricks]=={version('unitycatalog-ai')}",
69
178
  f"unitycatalog-langchain[databricks]=={version('unitycatalog-langchain')}",
70
179
  ]
@@ -85,18 +194,18 @@ def load_function(function_name: str) -> Callable[..., Any]:
85
194
  "module.submodule.function_name"
86
195
 
87
196
  Returns:
88
- The imported callable function
197
+ The imported callable function or langchain tool
89
198
 
90
199
  Raises:
91
200
  ImportError: If the module cannot be imported
92
201
  AttributeError: If the function doesn't exist in the module
93
- TypeError: If the resolved object is not callable
202
+ TypeError: If the resolved object is not callable or invocable
94
203
 
95
204
  Example:
96
205
  >>> func = callable_from_fqn("dao_ai.models.get_latest_model_version")
97
206
  >>> version = func("my_model")
98
207
  """
99
- logger.debug(f"Loading function: {function_name}")
208
+ logger.trace("Loading function", function_name=function_name)
100
209
 
101
210
  try:
102
211
  # Split the FQN into module path and function name
@@ -106,11 +215,16 @@ def load_function(function_name: str) -> Callable[..., Any]:
106
215
  module = importlib.import_module(module_path)
107
216
 
108
217
  # Get the function from the module
109
- func = getattr(module, func_name)
218
+ func: Any = getattr(module, func_name)
219
+
220
+ # Verify that the resolved object is callable or is a LangChain tool
221
+ # In langchain 1.x, StructuredTool objects are not directly callable
222
+ # but have an invoke() method
223
+ is_callable: bool = callable(func)
224
+ is_langchain_tool: bool = isinstance(func, BaseTool)
110
225
 
111
- # Verify that the resolved object is callable
112
- if not callable(func):
113
- raise TypeError(f"Function {func_name} is not callable.")
226
+ if not is_callable and not is_langchain_tool:
227
+ raise TypeError(f"Function {func_name} is not callable or invocable.")
114
228
 
115
229
  return func
116
230
  except (ImportError, AttributeError, TypeError) as e:
@@ -118,5 +232,93 @@ def load_function(function_name: str) -> Callable[..., Any]:
118
232
  raise ImportError(f"Failed to import {function_name}: {e}")
119
233
 
120
234
 
235
def type_from_fqn(type_name: str) -> type:
    """
    Load a type from a fully qualified name (FQN).

    Dynamically imports and returns a type (class) from a module using its
    fully qualified name. Useful for loading Pydantic models, dataclasses,
    or any Python type specified as a string in configuration files.

    Every error this function raises itself starts with
    ``"Failed to load type '<type_name>': "``. Other exceptions raised by
    ``importlib.import_module`` propagate unchanged, so their original type
    and arguments are preserved. This includes errors thrown by the target
    module's top-level code, but not ImportError, which is wrapped.

    Args:
        type_name: Fully qualified type name in format "module.path.ClassName".
            Only the last dotted component is looked up as an attribute, so
            nested classes ("module.Outer.Inner") are not resolved.

    Returns:
        The imported type/class

    Raises:
        ValueError: If the FQN format is invalid (no dot, or an empty module
            or class component)
        ImportError: If the module cannot be imported
        AttributeError: If the type doesn't exist in the module
        TypeError: If the resolved object is not a type

    Example:
        >>> ProductModel = type_from_fqn("my_models.ProductInfo")
        >>> instance = ProductModel(name="Widget", price=9.99)
    """
    logger.trace("Loading type", type_name=type_name)

    error_prefix = f"Failed to load type '{type_name}'"

    # rpartition splits on the LAST dot: "a.b.C" -> ("a.b", ".", "C").
    module_path, _, class_name = type_name.rpartition(".")
    if not module_path or not class_name:
        raise ValueError(
            f"{error_prefix}: Invalid type name '{type_name}'. "
            "Expected format: 'module.path.ClassName'"
        )

    try:
        module = importlib.import_module(module_path)
    except ImportError as e:
        raise ImportError(
            f"{error_prefix}: Could not import module '{module_path}' "
            f"for type '{type_name}': {e}"
        ) from e

    try:
        resolved_type = getattr(module, class_name)
    except AttributeError as e:
        raise AttributeError(
            f"{error_prefix}: Module '{module_path}' does not have attribute "
            f"'{class_name}'"
        ) from e

    if not isinstance(resolved_type, type):
        raise TypeError(
            f"{error_prefix}: '{type_name}' resolved to {resolved_type}, "
            "which is not a type"
        )

    return resolved_type
299
+
300
+
def is_in_model_serving() -> bool:
    """Heuristically detect whether the process runs in Databricks Model Serving.

    Returns True as soon as any of these signals is present, checked in order:

    * ``IS_IN_DB_MODEL_SERVING_ENV`` equals ``"true"`` (case-insensitive).
    * ``DATABRICKS_MODEL_SERVING_ENV`` is set to a non-empty value.
    * ``DATABRICKS_CLUSTER_TYPE`` contains ``"model-serving"`` (case-insensitive).
    * The path ``/opt/conda/envs/mlflow-env`` exists.

    NOTE(review): the last check is a filesystem heuristic. It may also match
    other images that ship an ``mlflow-env`` conda environment; confirm it
    cannot produce false positives where this matters.
    """
    env = os.environ
    explicit_flag = env.get("IS_IN_DB_MODEL_SERVING_ENV", "false").lower() == "true"
    return (
        explicit_flag
        or bool(env.get("DATABRICKS_MODEL_SERVING_ENV"))
        or "model-serving" in env.get("DATABRICKS_CLUSTER_TYPE", "").lower()
        or os.path.exists("/opt/conda/envs/mlflow-env")
    )