dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1491 -370
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +245 -159
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +573 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/utils.py
CHANGED
|
@@ -7,6 +7,7 @@ from importlib.metadata import PackageNotFoundError, version
|
|
|
7
7
|
from pathlib import Path
|
|
8
8
|
from typing import Any, Callable, Sequence
|
|
9
9
|
|
|
10
|
+
from langchain_core.tools import BaseTool
|
|
10
11
|
from loguru import logger
|
|
11
12
|
|
|
12
13
|
import dao_ai
|
|
@@ -19,15 +20,15 @@ def is_lib_provided(lib_name: str, pip_requirements: Sequence[str]) -> bool:
|
|
|
19
20
|
)
|
|
20
21
|
|
|
21
22
|
|
|
22
|
-
def is_installed():
|
|
23
|
+
def is_installed() -> bool:
|
|
23
24
|
current_file = os.path.abspath(dao_ai.__file__)
|
|
24
25
|
site_packages = [os.path.abspath(path) for path in site.getsitepackages()]
|
|
25
26
|
if site.getusersitepackages():
|
|
26
27
|
site_packages.append(os.path.abspath(site.getusersitepackages()))
|
|
27
28
|
|
|
28
29
|
found: bool = any(current_file.startswith(pkg_path) for pkg_path in site_packages)
|
|
29
|
-
logger.
|
|
30
|
-
|
|
30
|
+
logger.trace(
|
|
31
|
+
"Checking if dao_ai is installed", is_installed=found, current_file=current_file
|
|
31
32
|
)
|
|
32
33
|
return found
|
|
33
34
|
|
|
@@ -38,6 +39,56 @@ def normalize_name(name: str) -> str:
|
|
|
38
39
|
return normalized.strip("_")
|
|
39
40
|
|
|
40
41
|
|
|
42
|
+
def normalize_host(host: str | None) -> str | None:
|
|
43
|
+
"""Ensure host URL has https:// scheme.
|
|
44
|
+
|
|
45
|
+
The DATABRICKS_HOST environment variable should always include the https://
|
|
46
|
+
scheme, but some environments (e.g., Databricks Apps infrastructure) may
|
|
47
|
+
provide the host without it. This function normalizes the host to ensure
|
|
48
|
+
it has the proper scheme.
|
|
49
|
+
|
|
50
|
+
Args:
|
|
51
|
+
host: The host URL, with or without scheme
|
|
52
|
+
|
|
53
|
+
Returns:
|
|
54
|
+
The host URL with https:// scheme, or None if host is None/empty
|
|
55
|
+
"""
|
|
56
|
+
if not host:
|
|
57
|
+
return None
|
|
58
|
+
host = host.strip()
|
|
59
|
+
if not host:
|
|
60
|
+
return None
|
|
61
|
+
if not host.startswith("http://") and not host.startswith("https://"):
|
|
62
|
+
return f"https://{host}"
|
|
63
|
+
return host
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def get_default_databricks_host() -> str | None:
|
|
67
|
+
"""Get the default Databricks workspace host.
|
|
68
|
+
|
|
69
|
+
Attempts to get the host from:
|
|
70
|
+
1. DATABRICKS_HOST environment variable
|
|
71
|
+
2. WorkspaceClient ambient authentication (e.g., from ~/.databrickscfg)
|
|
72
|
+
|
|
73
|
+
Returns:
|
|
74
|
+
The Databricks workspace host URL (with https:// scheme), or None if not available.
|
|
75
|
+
"""
|
|
76
|
+
# Try environment variable first
|
|
77
|
+
host: str | None = os.environ.get("DATABRICKS_HOST")
|
|
78
|
+
if host:
|
|
79
|
+
return normalize_host(host)
|
|
80
|
+
|
|
81
|
+
# Fall back to WorkspaceClient
|
|
82
|
+
try:
|
|
83
|
+
from databricks.sdk import WorkspaceClient
|
|
84
|
+
|
|
85
|
+
w: WorkspaceClient = WorkspaceClient()
|
|
86
|
+
return normalize_host(w.config.host)
|
|
87
|
+
except Exception:
|
|
88
|
+
logger.trace("Could not get default Databricks host from WorkspaceClient")
|
|
89
|
+
return None
|
|
90
|
+
|
|
91
|
+
|
|
41
92
|
def dao_ai_version() -> str:
|
|
42
93
|
"""
|
|
43
94
|
Get the dao-ai package version, with fallback for source installations.
|
|
@@ -54,7 +105,7 @@ def dao_ai_version() -> str:
|
|
|
54
105
|
return version("dao-ai")
|
|
55
106
|
except PackageNotFoundError:
|
|
56
107
|
# Package not installed, try reading from pyproject.toml
|
|
57
|
-
logger.
|
|
108
|
+
logger.trace(
|
|
58
109
|
"dao-ai package not installed, attempting to read version from pyproject.toml"
|
|
59
110
|
)
|
|
60
111
|
try:
|
|
@@ -75,19 +126,24 @@ def dao_ai_version() -> str:
|
|
|
75
126
|
|
|
76
127
|
if not pyproject_path.exists():
|
|
77
128
|
logger.warning(
|
|
78
|
-
|
|
129
|
+
"Cannot determine dao-ai version: pyproject.toml not found",
|
|
130
|
+
path=str(pyproject_path),
|
|
79
131
|
)
|
|
80
132
|
return "dev"
|
|
81
133
|
|
|
82
134
|
with open(pyproject_path, "rb") as f:
|
|
83
135
|
pyproject_data = tomllib.load(f)
|
|
84
136
|
pkg_version = pyproject_data.get("project", {}).get("version", "dev")
|
|
85
|
-
logger.
|
|
86
|
-
|
|
137
|
+
logger.trace(
|
|
138
|
+
"Read version from pyproject.toml",
|
|
139
|
+
version=pkg_version,
|
|
140
|
+
path=str(pyproject_path),
|
|
87
141
|
)
|
|
88
142
|
return pkg_version
|
|
89
143
|
except Exception as e:
|
|
90
|
-
logger.warning(
|
|
144
|
+
logger.warning(
|
|
145
|
+
"Cannot determine dao-ai version from pyproject.toml", error=str(e)
|
|
146
|
+
)
|
|
91
147
|
return "dev"
|
|
92
148
|
|
|
93
149
|
|
|
@@ -99,7 +155,7 @@ def get_installed_packages() -> dict[str, str]:
|
|
|
99
155
|
f"databricks-langchain=={version('databricks-langchain')}",
|
|
100
156
|
f"databricks-mcp=={version('databricks-mcp')}",
|
|
101
157
|
f"databricks-sdk[openai]=={version('databricks-sdk')}",
|
|
102
|
-
f"
|
|
158
|
+
f"ddgs=={version('ddgs')}",
|
|
103
159
|
f"flashrank=={version('flashrank')}",
|
|
104
160
|
f"langchain=={version('langchain')}",
|
|
105
161
|
f"langchain-mcp-adapters=={version('langchain-mcp-adapters')}",
|
|
@@ -107,9 +163,6 @@ def get_installed_packages() -> dict[str, str]:
|
|
|
107
163
|
f"langchain-tavily=={version('langchain-tavily')}",
|
|
108
164
|
f"langgraph=={version('langgraph')}",
|
|
109
165
|
f"langgraph-checkpoint-postgres=={version('langgraph-checkpoint-postgres')}",
|
|
110
|
-
f"langgraph-prebuilt=={version('langgraph-prebuilt')}",
|
|
111
|
-
f"langgraph-supervisor=={version('langgraph-supervisor')}",
|
|
112
|
-
f"langgraph-swarm=={version('langgraph-swarm')}",
|
|
113
166
|
f"langmem=={version('langmem')}",
|
|
114
167
|
f"loguru=={version('loguru')}",
|
|
115
168
|
f"mcp=={version('mcp')}",
|
|
@@ -141,18 +194,18 @@ def load_function(function_name: str) -> Callable[..., Any]:
|
|
|
141
194
|
"module.submodule.function_name"
|
|
142
195
|
|
|
143
196
|
Returns:
|
|
144
|
-
The imported callable function
|
|
197
|
+
The imported callable function or langchain tool
|
|
145
198
|
|
|
146
199
|
Raises:
|
|
147
200
|
ImportError: If the module cannot be imported
|
|
148
201
|
AttributeError: If the function doesn't exist in the module
|
|
149
|
-
TypeError: If the resolved object is not callable
|
|
202
|
+
TypeError: If the resolved object is not callable or invocable
|
|
150
203
|
|
|
151
204
|
Example:
|
|
152
205
|
>>> func = callable_from_fqn("dao_ai.models.get_latest_model_version")
|
|
153
206
|
>>> version = func("my_model")
|
|
154
207
|
"""
|
|
155
|
-
logger.
|
|
208
|
+
logger.trace("Loading function", function_name=function_name)
|
|
156
209
|
|
|
157
210
|
try:
|
|
158
211
|
# Split the FQN into module path and function name
|
|
@@ -162,11 +215,16 @@ def load_function(function_name: str) -> Callable[..., Any]:
|
|
|
162
215
|
module = importlib.import_module(module_path)
|
|
163
216
|
|
|
164
217
|
# Get the function from the module
|
|
165
|
-
func = getattr(module, func_name)
|
|
218
|
+
func: Any = getattr(module, func_name)
|
|
219
|
+
|
|
220
|
+
# Verify that the resolved object is callable or is a LangChain tool
|
|
221
|
+
# In langchain 1.x, StructuredTool objects are not directly callable
|
|
222
|
+
# but have an invoke() method
|
|
223
|
+
is_callable: bool = callable(func)
|
|
224
|
+
is_langchain_tool: bool = isinstance(func, BaseTool)
|
|
166
225
|
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
raise TypeError(f"Function {func_name} is not callable.")
|
|
226
|
+
if not is_callable and not is_langchain_tool:
|
|
227
|
+
raise TypeError(f"Function {func_name} is not callable or invocable.")
|
|
170
228
|
|
|
171
229
|
return func
|
|
172
230
|
except (ImportError, AttributeError, TypeError) as e:
|
|
@@ -174,5 +232,93 @@ def load_function(function_name: str) -> Callable[..., Any]:
|
|
|
174
232
|
raise ImportError(f"Failed to import {function_name}: {e}")
|
|
175
233
|
|
|
176
234
|
|
|
235
|
+
def type_from_fqn(type_name: str) -> type:
|
|
236
|
+
"""
|
|
237
|
+
Load a type from a fully qualified name (FQN).
|
|
238
|
+
|
|
239
|
+
Dynamically imports and returns a type (class) from a module using its
|
|
240
|
+
fully qualified name. Useful for loading Pydantic models, dataclasses,
|
|
241
|
+
or any Python type specified as a string in configuration files.
|
|
242
|
+
|
|
243
|
+
Args:
|
|
244
|
+
type_name: Fully qualified type name in format "module.path.ClassName"
|
|
245
|
+
|
|
246
|
+
Returns:
|
|
247
|
+
The imported type/class
|
|
248
|
+
|
|
249
|
+
Raises:
|
|
250
|
+
ValueError: If the FQN format is invalid
|
|
251
|
+
ImportError: If the module cannot be imported
|
|
252
|
+
AttributeError: If the type doesn't exist in the module
|
|
253
|
+
TypeError: If the resolved object is not a type
|
|
254
|
+
|
|
255
|
+
Example:
|
|
256
|
+
>>> ProductModel = type_from_fqn("my_models.ProductInfo")
|
|
257
|
+
>>> instance = ProductModel(name="Widget", price=9.99)
|
|
258
|
+
"""
|
|
259
|
+
logger.trace("Loading type", type_name=type_name)
|
|
260
|
+
|
|
261
|
+
try:
|
|
262
|
+
# Split the FQN into module path and class name
|
|
263
|
+
parts = type_name.rsplit(".", 1)
|
|
264
|
+
if len(parts) != 2:
|
|
265
|
+
raise ValueError(
|
|
266
|
+
f"Invalid type name '{type_name}'. "
|
|
267
|
+
"Expected format: 'module.path.ClassName'"
|
|
268
|
+
)
|
|
269
|
+
|
|
270
|
+
module_path, class_name = parts
|
|
271
|
+
|
|
272
|
+
# Dynamically import the module
|
|
273
|
+
try:
|
|
274
|
+
module = importlib.import_module(module_path)
|
|
275
|
+
except ModuleNotFoundError as e:
|
|
276
|
+
raise ImportError(
|
|
277
|
+
f"Could not import module '{module_path}' for type '{type_name}': {e}"
|
|
278
|
+
) from e
|
|
279
|
+
|
|
280
|
+
# Get the class from the module
|
|
281
|
+
if not hasattr(module, class_name):
|
|
282
|
+
raise AttributeError(
|
|
283
|
+
f"Module '{module_path}' does not have attribute '{class_name}'"
|
|
284
|
+
)
|
|
285
|
+
|
|
286
|
+
resolved_type = getattr(module, class_name)
|
|
287
|
+
|
|
288
|
+
# Verify it's actually a type
|
|
289
|
+
if not isinstance(resolved_type, type):
|
|
290
|
+
raise TypeError(
|
|
291
|
+
f"'{type_name}' resolved to {resolved_type}, which is not a type"
|
|
292
|
+
)
|
|
293
|
+
|
|
294
|
+
return resolved_type
|
|
295
|
+
|
|
296
|
+
except (ValueError, ImportError, AttributeError, TypeError) as e:
|
|
297
|
+
# Provide a detailed error message that includes the original exception
|
|
298
|
+
raise type(e)(f"Failed to load type '{type_name}': {e}") from e
|
|
299
|
+
|
|
300
|
+
|
|
177
301
|
def is_in_model_serving() -> bool:
|
|
178
|
-
|
|
302
|
+
"""Check if running in Databricks Model Serving environment.
|
|
303
|
+
|
|
304
|
+
Detects Model Serving by checking for environment variables that are
|
|
305
|
+
typically set in that environment.
|
|
306
|
+
"""
|
|
307
|
+
# Primary check - explicit Databricks Model Serving env var
|
|
308
|
+
if os.environ.get("IS_IN_DB_MODEL_SERVING_ENV", "false").lower() == "true":
|
|
309
|
+
return True
|
|
310
|
+
|
|
311
|
+
# Secondary check - Model Serving sets these environment variables
|
|
312
|
+
if os.environ.get("DATABRICKS_MODEL_SERVING_ENV"):
|
|
313
|
+
return True
|
|
314
|
+
|
|
315
|
+
# Check for cluster type indicator
|
|
316
|
+
cluster_type = os.environ.get("DATABRICKS_CLUSTER_TYPE", "")
|
|
317
|
+
if "model-serving" in cluster_type.lower():
|
|
318
|
+
return True
|
|
319
|
+
|
|
320
|
+
# Check for model serving specific paths
|
|
321
|
+
if os.path.exists("/opt/conda/envs/mlflow-env"):
|
|
322
|
+
return True
|
|
323
|
+
|
|
324
|
+
return False
|