dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +5 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1863 -338
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -228
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +261 -166
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +645 -172
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -295
- dao_ai/tools/mcp.py +220 -133
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +360 -40
- dao_ai/utils.py +218 -16
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.25.dist-info/METADATA +0 -1165
- dao_ai-0.0.25.dist-info/RECORD +0 -41
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/utils.py
CHANGED
|
@@ -3,9 +3,11 @@ import importlib.metadata
|
|
|
3
3
|
import os
|
|
4
4
|
import re
|
|
5
5
|
import site
|
|
6
|
-
from importlib.metadata import version
|
|
6
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
7
|
+
from pathlib import Path
|
|
7
8
|
from typing import Any, Callable, Sequence
|
|
8
9
|
|
|
10
|
+
from langchain_core.tools import BaseTool
|
|
9
11
|
from loguru import logger
|
|
10
12
|
|
|
11
13
|
import dao_ai
|
|
@@ -18,15 +20,15 @@ def is_lib_provided(lib_name: str, pip_requirements: Sequence[str]) -> bool:
|
|
|
18
20
|
)
|
|
19
21
|
|
|
20
22
|
|
|
21
|
-
def is_installed():
|
|
23
|
+
def is_installed() -> bool:
    """Return True when the dao_ai package lives under a site-packages directory.

    Compares the absolute path of the imported ``dao_ai`` module against every
    known site-packages location (system-wide and per-user). Running from a
    source checkout yields False.
    """
    package_file: str = os.path.abspath(dao_ai.__file__)

    # Collect every candidate installation root, including the user site dir.
    search_paths: list[str] = [os.path.abspath(p) for p in site.getsitepackages()]
    if site.getusersitepackages():
        search_paths.append(os.path.abspath(site.getusersitepackages()))

    installed: bool = any(package_file.startswith(root) for root in search_paths)
    logger.trace(
        "Checking if dao_ai is installed",
        is_installed=installed,
        current_file=package_file,
    )
    return installed
|
|
32
34
|
|
|
@@ -37,6 +39,114 @@ def normalize_name(name: str) -> str:
|
|
|
37
39
|
return normalized.strip("_")
|
|
38
40
|
|
|
39
41
|
|
|
42
|
+
def normalize_host(host: str | None) -> str | None:
|
|
43
|
+
"""Ensure host URL has https:// scheme.
|
|
44
|
+
|
|
45
|
+
The DATABRICKS_HOST environment variable should always include the https://
|
|
46
|
+
scheme, but some environments (e.g., Databricks Apps infrastructure) may
|
|
47
|
+
provide the host without it. This function normalizes the host to ensure
|
|
48
|
+
it has the proper scheme.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
host: The host URL, with or without scheme
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
The host URL with https:// scheme, or None if host is None/empty
|
|
55
|
+
"""
|
|
56
|
+
if not host:
|
|
57
|
+
return None
|
|
58
|
+
host = host.strip()
|
|
59
|
+
if not host:
|
|
60
|
+
return None
|
|
61
|
+
if not host.startswith("http://") and not host.startswith("https://"):
|
|
62
|
+
return f"https://{host}"
|
|
63
|
+
return host
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def get_default_databricks_host() -> str | None:
    """Resolve the default Databricks workspace host.

    Resolution order:
    1. DATABRICKS_HOST environment variable
    2. WorkspaceClient ambient authentication (e.g., from ~/.databrickscfg)

    Returns:
        The Databricks workspace host URL (with https:// scheme), or None if not available.
    """
    env_host: str | None = os.environ.get("DATABRICKS_HOST")
    if env_host:
        return normalize_host(env_host)

    # No env var: let the SDK resolve whatever ambient credentials provide.
    try:
        from databricks.sdk import WorkspaceClient

        client: WorkspaceClient = WorkspaceClient()
        return normalize_host(client.config.host)
    except Exception:
        # Best-effort: missing SDK or no ambient auth simply yields None.
        logger.trace("Could not get default Databricks host from WorkspaceClient")
        return None
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def dao_ai_version() -> str:
    """
    Get the dao-ai package version, with fallback for source installations.

    Resolution order:
    1. Installed package metadata for "dao-ai".
    2. The ``[project].version`` field of the repository's pyproject.toml.
    3. The literal string "dev" when neither source is available.

    Returns:
        str: The version string, or "dev" if version cannot be determined
    """
    try:
        # Installed wheel: package metadata is authoritative.
        return version("dao-ai")
    except PackageNotFoundError:
        logger.trace(
            "dao-ai package not installed, attempting to read version from pyproject.toml"
        )

    # Pick a TOML parser: stdlib tomllib on 3.11+, tomli backport otherwise.
    try:
        import tomllib  # Python 3.11+
    except ImportError:
        try:
            import tomli as tomllib  # Fallback for Python < 3.11
        except ImportError:
            logger.warning(
                "Cannot determine dao-ai version: package not installed and tomllib/tomli not available"
            )
            return "dev"

    try:
        # pyproject.toml sits two directories above this module in a checkout.
        pyproject_path = Path(__file__).parents[2] / "pyproject.toml"

        if not pyproject_path.exists():
            logger.warning(
                "Cannot determine dao-ai version: pyproject.toml not found",
                path=str(pyproject_path),
            )
            return "dev"

        with open(pyproject_path, "rb") as f:
            pyproject_data = tomllib.load(f)
        pkg_version = pyproject_data.get("project", {}).get("version", "dev")
        logger.trace(
            "Read version from pyproject.toml",
            version=pkg_version,
            path=str(pyproject_path),
        )
        return pkg_version
    except Exception as e:
        logger.warning(
            "Cannot determine dao-ai version from pyproject.toml", error=str(e)
        )
        return "dev"
|
|
148
|
+
|
|
149
|
+
|
|
40
150
|
def get_installed_packages() -> dict[str, str]:
|
|
41
151
|
"""Get all installed packages with versions"""
|
|
42
152
|
|
|
@@ -45,16 +155,14 @@ def get_installed_packages() -> dict[str, str]:
|
|
|
45
155
|
f"databricks-langchain=={version('databricks-langchain')}",
|
|
46
156
|
f"databricks-mcp=={version('databricks-mcp')}",
|
|
47
157
|
f"databricks-sdk[openai]=={version('databricks-sdk')}",
|
|
48
|
-
f"
|
|
158
|
+
f"ddgs=={version('ddgs')}",
|
|
159
|
+
f"flashrank=={version('flashrank')}",
|
|
49
160
|
f"langchain=={version('langchain')}",
|
|
50
161
|
f"langchain-mcp-adapters=={version('langchain-mcp-adapters')}",
|
|
51
162
|
f"langchain-openai=={version('langchain-openai')}",
|
|
52
163
|
f"langchain-tavily=={version('langchain-tavily')}",
|
|
53
164
|
f"langgraph=={version('langgraph')}",
|
|
54
165
|
f"langgraph-checkpoint-postgres=={version('langgraph-checkpoint-postgres')}",
|
|
55
|
-
f"langgraph-prebuilt=={version('langgraph-prebuilt')}",
|
|
56
|
-
f"langgraph-supervisor=={version('langgraph-supervisor')}",
|
|
57
|
-
f"langgraph-swarm=={version('langgraph-swarm')}",
|
|
58
166
|
f"langmem=={version('langmem')}",
|
|
59
167
|
f"loguru=={version('loguru')}",
|
|
60
168
|
f"mcp=={version('mcp')}",
|
|
@@ -65,6 +173,7 @@ def get_installed_packages() -> dict[str, str]:
|
|
|
65
173
|
f"psycopg[binary,pool]=={version('psycopg')}",
|
|
66
174
|
f"pydantic=={version('pydantic')}",
|
|
67
175
|
f"pyyaml=={version('pyyaml')}",
|
|
176
|
+
f"tomli=={version('tomli')}",
|
|
68
177
|
f"unitycatalog-ai[databricks]=={version('unitycatalog-ai')}",
|
|
69
178
|
f"unitycatalog-langchain[databricks]=={version('unitycatalog-langchain')}",
|
|
70
179
|
]
|
|
@@ -85,18 +194,18 @@ def load_function(function_name: str) -> Callable[..., Any]:
|
|
|
85
194
|
"module.submodule.function_name"
|
|
86
195
|
|
|
87
196
|
Returns:
|
|
88
|
-
The imported callable function
|
|
197
|
+
The imported callable function or langchain tool
|
|
89
198
|
|
|
90
199
|
Raises:
|
|
91
200
|
ImportError: If the module cannot be imported
|
|
92
201
|
AttributeError: If the function doesn't exist in the module
|
|
93
|
-
TypeError: If the resolved object is not callable
|
|
202
|
+
TypeError: If the resolved object is not callable or invocable
|
|
94
203
|
|
|
95
204
|
Example:
|
|
96
205
|
>>> func = callable_from_fqn("dao_ai.models.get_latest_model_version")
|
|
97
206
|
>>> version = func("my_model")
|
|
98
207
|
"""
|
|
99
|
-
logger.
|
|
208
|
+
logger.trace("Loading function", function_name=function_name)
|
|
100
209
|
|
|
101
210
|
try:
|
|
102
211
|
# Split the FQN into module path and function name
|
|
@@ -106,11 +215,16 @@ def load_function(function_name: str) -> Callable[..., Any]:
|
|
|
106
215
|
module = importlib.import_module(module_path)
|
|
107
216
|
|
|
108
217
|
# Get the function from the module
|
|
109
|
-
func = getattr(module, func_name)
|
|
218
|
+
func: Any = getattr(module, func_name)
|
|
219
|
+
|
|
220
|
+
# Verify that the resolved object is callable or is a LangChain tool
|
|
221
|
+
# In langchain 1.x, StructuredTool objects are not directly callable
|
|
222
|
+
# but have an invoke() method
|
|
223
|
+
is_callable: bool = callable(func)
|
|
224
|
+
is_langchain_tool: bool = isinstance(func, BaseTool)
|
|
110
225
|
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
raise TypeError(f"Function {func_name} is not callable.")
|
|
226
|
+
if not is_callable and not is_langchain_tool:
|
|
227
|
+
raise TypeError(f"Function {func_name} is not callable or invocable.")
|
|
114
228
|
|
|
115
229
|
return func
|
|
116
230
|
except (ImportError, AttributeError, TypeError) as e:
|
|
@@ -118,5 +232,93 @@ def load_function(function_name: str) -> Callable[..., Any]:
|
|
|
118
232
|
raise ImportError(f"Failed to import {function_name}: {e}")
|
|
119
233
|
|
|
120
234
|
|
|
235
|
+
def type_from_fqn(type_name: str) -> type:
    """
    Load a type from a fully qualified name (FQN).

    Dynamically imports the module portion of the FQN and returns the named
    attribute, verifying it is an actual type. Useful for loading Pydantic
    models, dataclasses, or any Python type specified as a string in
    configuration files.

    Args:
        type_name: Fully qualified type name in format "module.path.ClassName"

    Returns:
        The imported type/class

    Raises:
        ValueError: If the FQN format is invalid
        ImportError: If the module cannot be imported
        AttributeError: If the type doesn't exist in the module
        TypeError: If the resolved object is not a type

    Example:
        >>> ProductModel = type_from_fqn("my_models.ProductInfo")
        >>> instance = ProductModel(name="Widget", price=9.99)
    """
    logger.trace("Loading type", type_name=type_name)

    try:
        # "pkg.mod.Class" must split into a module path and a class name.
        pieces = type_name.rsplit(".", 1)
        if len(pieces) != 2:
            raise ValueError(
                f"Invalid type name '{type_name}'. "
                "Expected format: 'module.path.ClassName'"
            )
        module_path, class_name = pieces

        try:
            module = importlib.import_module(module_path)
        except ModuleNotFoundError as e:
            raise ImportError(
                f"Could not import module '{module_path}' for type '{type_name}': {e}"
            ) from e

        if not hasattr(module, class_name):
            raise AttributeError(
                f"Module '{module_path}' does not have attribute '{class_name}'"
            )

        candidate = getattr(module, class_name)

        # Reject functions, instances, and other non-type attributes.
        if not isinstance(candidate, type):
            raise TypeError(
                f"'{type_name}' resolved to {candidate}, which is not a type"
            )

        return candidate

    except (ValueError, ImportError, AttributeError, TypeError) as e:
        # Re-raise as the same exception class with the FQN as added context.
        raise type(e)(f"Failed to load type '{type_name}': {e}") from e
|
|
299
|
+
|
|
300
|
+
|
|
121
301
|
def is_in_model_serving() -> bool:
    """Check if running in Databricks Model Serving environment.

    Detects Model Serving by checking for environment variables that are
    typically set in that environment, plus a filesystem marker.
    """
    env = os.environ
    # Primary check - explicit Databricks Model Serving env var.
    explicit_flag = env.get("IS_IN_DB_MODEL_SERVING_ENV", "false").lower() == "true"
    # Secondary check - Model Serving sets this environment variable.
    serving_env = bool(env.get("DATABRICKS_MODEL_SERVING_ENV"))
    # Cluster type indicator.
    serving_cluster = "model-serving" in env.get("DATABRICKS_CLUSTER_TYPE", "").lower()
    # Model-serving-specific conda environment path.
    serving_path = os.path.exists("/opt/conda/envs/mlflow-env")
    return explicit_flag or serving_env or serving_cluster or serving_path
|