dao-ai 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +77 -18
- dao_ai/config.py +226 -8
- dao_ai/graph.py +12 -2
- dao_ai/nodes.py +29 -20
- dao_ai/providers/databricks.py +545 -34
- dao_ai/tools/mcp.py +41 -13
- dao_ai/utils.py +56 -1
- {dao_ai-0.0.24.dist-info → dao_ai-0.0.26.dist-info}/METADATA +4 -2
- {dao_ai-0.0.24.dist-info → dao_ai-0.0.26.dist-info}/RECORD +12 -12
- {dao_ai-0.0.24.dist-info → dao_ai-0.0.26.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.24.dist-info → dao_ai-0.0.26.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.24.dist-info → dao_ai-0.0.26.dist-info}/licenses/LICENSE +0 -0
dao_ai/cli.py
CHANGED
@@ -460,6 +460,49 @@ def setup_logging(verbosity: int) -> None:
     logger.add(sys.stderr, level=level)


+def generate_bundle_from_template(config_path: Path, app_name: str) -> Path:
+    """
+    Generate an app-specific databricks.yaml from databricks.yaml.template.
+
+    This function:
+    1. Reads databricks.yaml.template (permanent template file)
+    2. Replaces __APP_NAME__ with the actual app name
+    3. Writes to databricks.yaml (overwrites if exists)
+    4. Returns the path to the generated file
+
+    The generated databricks.yaml is overwritten on each deployment and is not tracked in git.
+    Schema reference remains pointing to ./schemas/bundle_config_schema.json.
+
+    Args:
+        config_path: Path to the app config file
+        app_name: Normalized app name
+
+    Returns:
+        Path to the generated databricks.yaml file
+    """
+    cwd = Path.cwd()
+    template_path = cwd / "databricks.yaml.template"
+    output_path = cwd / "databricks.yaml"
+
+    if not template_path.exists():
+        logger.error(f"Template file {template_path} does not exist.")
+        sys.exit(1)
+
+    # Read template
+    with open(template_path, "r") as f:
+        template_content = f.read()
+
+    # Replace template variables
+    bundle_content = template_content.replace("__APP_NAME__", app_name)
+
+    # Write generated databricks.yaml (overwrite if exists)
+    with open(output_path, "w") as f:
+        f.write(bundle_content)
+
+    logger.info(f"Generated bundle configuration at {output_path} from template")
+    return output_path
+
+
 def run_databricks_command(
     command: list[str],
     profile: Optional[str] = None,
@@ -467,44 +510,55 @@ def run_databricks_command(
     target: Optional[str] = None,
     dry_run: bool = False,
 ) -> None:
-    """Execute a databricks CLI command with optional profile."""
+    """Execute a databricks CLI command with optional profile and target."""
+    config_path = Path(config) if config else None
+
+    if config_path and not config_path.exists():
+        logger.error(f"Configuration file {config_path} does not exist.")
+        sys.exit(1)
+
+    # Load app config and generate bundle from template
+    app_config: AppConfig = AppConfig.from_file(config_path) if config_path else None
+    normalized_name: str = normalize_name(app_config.app.name) if app_config else None
+
+    # Generate app-specific bundle from template (overwrites databricks.yaml temporarily)
+    if config_path and app_config:
+        generate_bundle_from_template(config_path, normalized_name)
+
+    # Use app name as target if not explicitly provided
+    # This ensures each app gets its own Terraform state in .databricks/bundle/<app-name>/
+    if not target and normalized_name:
+        target = normalized_name
+        logger.debug(f"Using app-specific target: {target}")
+
+    # Build databricks command (no -c flag needed, uses databricks.yaml in current dir)
     cmd = ["databricks"]
     if profile:
         cmd.extend(["--profile", profile])
+
     if target:
         cmd.extend(["--target", target])
-    cmd.extend(command)
-    if config:
-        config_path = Path(config)

-
-            logger.error(f"Configuration file {config_path} does not exist.")
-            sys.exit(1)
-
-        app_config: AppConfig = AppConfig.from_file(config_path)
+    cmd.extend(command)

-
-
+    # Add config_path variable for notebooks
+    if config_path and app_config:
+        # Calculate relative path from notebooks directory to config file
         config_abs = config_path.resolve()
         cwd = Path.cwd()
         notebooks_dir = cwd / "notebooks"

-        # Calculate relative path from notebooks directory to config file
         try:
             relative_config = config_abs.relative_to(notebooks_dir)
         except ValueError:
-            # Config file is outside notebooks directory, calculate relative path
-            # Use os.path.relpath to get the relative path from notebooks_dir to config file
             relative_config = Path(os.path.relpath(config_abs, notebooks_dir))

         cmd.append(f'--var="config_path={relative_config}"')

-        normalized_name: str = normalize_name(app_config.app.name)
-        cmd.append(f'--var="app_name={normalized_name}"')
-
     logger.debug(f"Executing command: {' '.join(cmd)}")

     if dry_run:
+        logger.info(f"[DRY RUN] Would execute: {' '.join(cmd)}")
         return

     try:
@@ -531,6 +585,9 @@ def run_databricks_command(
     except FileNotFoundError:
         logger.error("databricks CLI not found. Please install the Databricks CLI.")
         sys.exit(1)
+    except Exception as e:
+        logger.error(f"Command execution failed: {e}")
+        sys.exit(1)


 def handle_bundle_command(options: Namespace) -> None:
@@ -539,6 +596,7 @@ def handle_bundle_command(options: Namespace) -> None:
     config: Optional[str] = options.config
     target: Optional[str] = options.target
     dry_run: bool = options.dry_run
+
     if options.deploy:
         logger.info("Deploying DAO AI asset bundle...")
         run_databricks_command(
@@ -546,8 +604,9 @@ def handle_bundle_command(options: Namespace) -> None:
         )
     if options.run:
         logger.info("Running DAO AI system with current configuration...")
+        # Use static job resource key that matches databricks.yaml (resources.jobs.deploy_job)
         run_databricks_command(
-            ["bundle", "run", "
+            ["bundle", "run", "deploy_job"],
             profile,
             config,
             target,
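
Note: the net effect of the cli.py changes is that bundle commands are driven by a generated,
per-app databricks.yaml and an app-specific target. A minimal sketch of the equivalent flow
(the app name "my_agent" and the config path are hypothetical):

    from pathlib import Path

    app_name = "my_agent"  # hypothetical normalized app name
    template = Path("databricks.yaml.template").read_text()
    Path("databricks.yaml").write_text(template.replace("__APP_NAME__", app_name))
    # The CLI then invokes roughly:
    #   databricks --target my_agent bundle deploy --var="config_path=../config/my_agent.yaml"
    # so each app keeps its own Terraform state under .databricks/bundle/my_agent/.
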
dao_ai/config.py
CHANGED
@@ -30,12 +30,15 @@ from databricks_langchain import (
     DatabricksFunctionClient,
 )
 from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import BaseMessage, messages_from_dict
 from langchain_core.runnables.base import RunnableLike
 from langchain_openai import ChatOpenAI
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.graph.state import CompiledStateGraph
 from langgraph.store.base import BaseStore
 from loguru import logger
+from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
+from mlflow.genai.prompts import PromptVersion, load_prompt
 from mlflow.models import ModelConfig
 from mlflow.models.resources import (
     DatabricksFunction,
@@ -49,6 +52,9 @@ from mlflow.models.resources import (
     DatabricksVectorSearchIndex,
 )
 from mlflow.pyfunc import ChatModel, ResponsesAgent
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+)
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -324,6 +330,10 @@ class LLMModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ]

+    @property
+    def uri(self) -> str:
+        return f"databricks:/{self.name}"
+
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksServingEndpoint(
@@ -1181,17 +1191,32 @@ class PromptModel(BaseModel, HasFullName):
         from dao_ai.providers.databricks import DatabricksProvider

         provider: DatabricksProvider = DatabricksProvider()
-
-        return
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()

     @property
     def full_name(self) -> str:
+        prompt_name: str = self.name
         if self.schema_model:
-
-
-
-
-
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
+
+    @property
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
+
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
+
+        return prompt_uri
+
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version

     @model_validator(mode="after")
     def validate_mutually_exclusive(self):
@@ -1213,6 +1238,17 @@ class AgentModel(BaseModel):
     pre_agent_hook: Optional[FunctionHook] = None
     post_agent_hook: Optional[FunctionHook] = None

+    def as_runnable(self) -> RunnableLike:
+        from dao_ai.nodes import create_agent_node
+
+        return create_agent_node(self)
+
+    def as_responses_agent(self) -> ResponsesAgent:
+        from dao_ai.models import create_responses_agent
+
+        graph: CompiledStateGraph = self.as_runnable()
+        return create_responses_agent(graph)
+

 class SupervisorModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -1301,7 +1337,7 @@ class ChatPayload(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     input: Optional[list[Message]] = None
     messages: Optional[list[Message]] = None
-    custom_inputs: dict
+    custom_inputs: Optional[dict] = Field(default_factory=dict)

     @model_validator(mode="after")
     def validate_mutual_exclusion_and_alias(self) -> "ChatPayload":
@@ -1330,6 +1366,19 @@ class ChatPayload(BaseModel):

         return self

+    def as_messages(self) -> Sequence[BaseMessage]:
+        return messages_from_dict(
+            [{"type": m.role, "content": m.content} for m in self.messages]
+        )
+
+    def as_agent_request(self) -> ResponsesAgentRequest:
+        from mlflow.types.responses_helpers import Message as _Message
+
+        return ResponsesAgentRequest(
+            input=[_Message(role=m.role, content=m.content) for m in self.messages],
+            custom_inputs=self.custom_inputs,
+        )
+

 class ChatHistoryModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
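
Note: the PromptModel and ChatPayload helpers added above are small but load-bearing. A rough
sketch of how they resolve (all field values below are invented, and both models may require
more fields than shown here):

    # Alias wins, then an explicit version, otherwise "@latest"; a schema adds a
    # "<catalog>.<schema>." prefix to the name via full_name.
    PromptModel(name="retail_prompt", alias="champion").uri   # "prompts:/retail_prompt@champion"
    PromptModel(name="retail_prompt", version=2).uri          # "prompts:/retail_prompt/2"
    PromptModel(name="retail_prompt").uri                     # "prompts:/retail_prompt@latest"

    # ChatPayload can now feed either a LangChain graph or a ResponsesAgent:
    payload = ChatPayload(messages=[Message(role="user", content="hello")])
    payload.as_messages()       # list of BaseMessage built via messages_from_dict
    payload.as_agent_request()  # ResponsesAgentRequest(input=[...], custom_inputs={})
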
@@ -1459,6 +1508,174 @@ class EvaluationModel(BaseModel):
     guidelines: list[GuidelineModel] = Field(default_factory=list)


+class EvaluationDatasetExpectationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    expected_response: Optional[str] = None
+    expected_facts: Optional[list[str]] = None
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self):
+        if self.expected_response is not None and self.expected_facts is not None:
+            raise ValueError("Cannot specify both expected_response and expected_facts")
+        return self
+
+
+class EvaluationDatasetEntryModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    inputs: ChatPayload
+    expectations: EvaluationDatasetExpectationsModel
+
+    def to_mlflow_format(self) -> dict[str, Any]:
+        """
+        Convert to MLflow evaluation dataset format.
+
+        Flattens the expectations fields to the top level alongside inputs,
+        which is the format expected by MLflow's Correctness scorer.
+
+        Returns:
+            dict: Flattened dictionary with inputs and expectation fields at top level
+        """
+        result: dict[str, Any] = {"inputs": self.inputs.model_dump()}
+
+        # Flatten expectations to top level for MLflow compatibility
+        if self.expectations.expected_response is not None:
+            result["expected_response"] = self.expectations.expected_response
+        if self.expectations.expected_facts is not None:
+            result["expected_facts"] = self.expectations.expected_facts
+
+        return result
+
+
+class EvaluationDatasetModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    data: Optional[list[EvaluationDatasetEntryModel]] = Field(default_factory=list)
+    overwrite: Optional[bool] = False
+
+    def as_dataset(self, w: WorkspaceClient | None = None) -> EvaluationDataset:
+        evaluation_dataset: EvaluationDataset
+        needs_creation: bool = False
+
+        try:
+            evaluation_dataset = get_dataset(name=self.full_name)
+            if self.overwrite:
+                logger.warning(f"Overwriting dataset {self.full_name}")
+                workspace_client: WorkspaceClient = w if w else WorkspaceClient()
+                logger.debug(f"Dropping table: {self.full_name}")
+                workspace_client.tables.delete(full_name=self.full_name)
+                needs_creation = True
+        except Exception:
+            logger.warning(
+                f"Dataset {self.full_name} not found, will create new dataset"
+            )
+            needs_creation = True
+
+        # Create dataset if needed (either new or after overwrite)
+        if needs_creation:
+            evaluation_dataset = create_dataset(name=self.full_name)
+        if self.data:
+            logger.debug(
+                f"Merging {len(self.data)} entries into dataset {self.full_name}"
+            )
+            # Use to_mlflow_format() to flatten expectations for MLflow compatibility
+            evaluation_dataset.merge_records(
+                [e.to_mlflow_format() for e in self.data]
+            )
+
+        return evaluation_dataset
+
+    @property
+    def full_name(self) -> str:
+        if self.schema_model:
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
+        return self.name
+
+
+class PromptOptimizationModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    prompt: Optional[PromptModel] = None
+    agent: AgentModel
+    dataset: (
+        EvaluationDatasetModel | str
+    )  # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
+    reflection_model: Optional[LLMModel | str] = None
+    num_candidates: Optional[int] = 50
+    scorer_model: Optional[LLMModel | str] = None
+
+    def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
+        """
+        Optimize the prompt using MLflow's prompt optimization.
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            PromptModel: The optimized prompt model with new URI
+        """
+        from dao_ai.providers.base import ServiceProvider
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: ServiceProvider = DatabricksProvider(w=w)
+        optimized_prompt: PromptModel = provider.optimize_prompt(self)
+        return optimized_prompt
+
+    @model_validator(mode="after")
+    def set_defaults(self):
+        # If no prompt is specified, try to use the agent's prompt
+        if self.prompt is None:
+            if isinstance(self.agent.prompt, PromptModel):
+                self.prompt = self.agent.prompt
+            else:
+                raise ValueError(
+                    f"Prompt optimization '{self.name}' requires either an explicit prompt "
+                    f"or an agent with a prompt configured"
+                )
+
+        if self.reflection_model is None:
+            self.reflection_model = self.agent.model
+
+        if self.scorer_model is None:
+            self.scorer_model = self.agent.model
+
+        return self
+
+
+class OptimizationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    training_datasets: dict[str, EvaluationDatasetModel] = Field(default_factory=dict)
+    prompt_optimizations: dict[str, PromptOptimizationModel] = Field(
+        default_factory=dict
+    )
+
+    def optimize(self, w: WorkspaceClient | None = None) -> dict[str, PromptModel]:
+        """
+        Optimize all prompts in this configuration.
+
+        This method:
+        1. Ensures all training datasets are created/registered in MLflow
+        2. Runs each prompt optimization
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            dict[str, PromptModel]: Dictionary mapping optimization names to optimized prompts
+        """
+        # First, ensure all training datasets are created/registered in MLflow
+        logger.info(f"Ensuring {len(self.training_datasets)} training datasets exist")
+        for dataset_name, dataset_model in self.training_datasets.items():
+            logger.debug(f"Creating/updating dataset: {dataset_name}")
+            dataset_model.as_dataset()
+
+        # Run optimizations
+        results: dict[str, PromptModel] = {}
+        for name, optimization in self.prompt_optimizations.items():
+            results[name] = optimization.optimize(w)
+        return results
+
+
 class DatasetFormat(str, Enum):
     CSV = "csv"
     DELTA = "delta"
@@ -1537,6 +1754,7 @@ class AppConfig(BaseModel):
     agents: dict[str, AgentModel] = Field(default_factory=dict)
     app: Optional[AppModel] = None
     evaluation: Optional[EvaluationModel] = None
+    optimizations: Optional[OptimizationsModel] = None
     datasets: Optional[list[DatasetModel]] = Field(default_factory=list)
     unity_catalog_functions: Optional[list[UnityCatalogFunctionSqlModel]] = Field(
         default_factory=list
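
Note: the evaluation-dataset and optimization models above are plain Pydantic config. A hedged
usage sketch (the dataset contents and names are invented):

    entry = EvaluationDatasetEntryModel(
        inputs=ChatPayload(messages=[Message(role="user", content="What warranty applies?")]),
        expectations=EvaluationDatasetExpectationsModel(expected_facts=["90-day warranty"]),
    )
    entry.to_mlflow_format()
    # {"inputs": {...}, "expected_facts": ["90-day warranty"]}  (expectations flattened for the Correctness scorer)

    # AppConfig.optimizations.optimize() first materializes every training dataset via
    # as_dataset() (get_dataset/create_dataset + merge_records), then runs each
    # PromptOptimizationModel and returns {optimization_name: optimized PromptModel}.
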
dao_ai/graph.py
CHANGED
@@ -79,7 +79,12 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
     for registered_agent in config.app.agents:
         agents.append(
             create_agent_node(
-
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=[],
             )
         )
         tools.append(
@@ -169,7 +174,12 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
         )
         agents.append(
             create_agent_node(
-
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=handoff_tools,
            )
        )

dao_ai/nodes.py
CHANGED
@@ -19,9 +19,9 @@ from loguru import logger
 from dao_ai.config import (
     AgentModel,
     AppConfig,
-    AppModel,
     ChatHistoryModel,
     FunctionHook,
+    MemoryModel,
     ToolModel,
 )
 from dao_ai.guardrails import reflection_guardrail, with_guardrails
@@ -31,12 +31,18 @@ from dao_ai.state import Context, IncomingState, SharedState
 from dao_ai.tools import create_tools


-def summarization_node(
-
+def summarization_node(chat_history: ChatHistoryModel) -> RunnableLike:
+    """
+    Create a summarization node for managing chat history.
+
+    Args:
+        chat_history: ChatHistoryModel configuration for summarization
+
+    Returns:
+        RunnableLike: A summarization node that processes messages
+    """
     if chat_history is None:
-        raise ValueError(
-            "AppModel must have chat_history configured to use summarization"
-        )
+        raise ValueError("chat_history must be provided to use summarization")

     max_tokens: int = chat_history.max_tokens
     max_tokens_before_summary: int | None = chat_history.max_tokens_before_summary
@@ -93,23 +99,26 @@ def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLi


 def create_agent_node(
-    app: AppModel,
     agent: AgentModel,
+    memory: Optional[MemoryModel] = None,
+    chat_history: Optional[ChatHistoryModel] = None,
     additional_tools: Optional[Sequence[BaseTool]] = None,
 ) -> RunnableLike:
     """
     Factory function that creates a LangGraph node for a specialized agent.

-    This creates a node function that handles user requests using a specialized agent
-
-
+    This creates a node function that handles user requests using a specialized agent.
+    The function configures the agent with the appropriate model, prompt, tools, and guardrails.
+    If chat_history is provided, it creates a workflow with summarization node.

     Args:
-
-
+        agent: AgentModel configuration for the agent
+        memory: Optional MemoryModel for memory store configuration
+        chat_history: Optional ChatHistoryModel for chat history summarization
+        additional_tools: Optional sequence of additional tools to add to the agent

     Returns:
-        An agent
+        RunnableLike: An agent node that processes state and returns responses
     """
     logger.debug(f"Creating agent node for {agent.name}")

@@ -124,10 +133,10 @@ def create_agent_node(
         additional_tools = []
     tools: Sequence[BaseTool] = create_tools(tool_models) + additional_tools

-    if
+    if memory and memory.store:
         namespace: tuple[str, ...] = ("memory",)
-        if
-            namespace = namespace + (
+        if memory.store.namespace:
+            namespace = namespace + (memory.store.namespace,)
         logger.debug(f"Memory store namespace: {namespace}")

         tools += [
@@ -145,13 +154,15 @@ def create_agent_node(
     )
     logger.debug(f"post_agent_hook: {post_agent_hook}")

+    checkpointer: bool = memory and memory.checkpointer is not None
+
     compiled_agent: CompiledStateGraph = create_react_agent(
         name=agent.name,
         model=llm,
         prompt=make_prompt(agent.prompt),
         tools=tools,
         store=True,
-        checkpointer=
+        checkpointer=checkpointer,
         state_schema=SharedState,
         context_schema=Context,
         pre_model_hook=pre_agent_hook,
@@ -166,8 +177,6 @@ def create_agent_node(

     agent_node: CompiledStateGraph

-    chat_history: ChatHistoryModel = app.chat_history
-
     if chat_history is None:
         logger.debug("No chat history configured, using compiled agent directly")
         agent_node = compiled_agent
@@ -179,7 +188,7 @@ def create_agent_node(
         input=SharedState,
         output=SharedState,
     )
-    workflow.add_node("summarization", summarization_node(
+    workflow.add_node("summarization", summarization_node(chat_history))
     workflow.add_node(
         "agent",
         call_agent_with_summarized_messages(agent=compiled_agent),
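
Note: create_agent_node no longer receives the whole AppModel; the graph builders shown earlier
pass the pieces explicitly. A rough call-site sketch with assumed config objects, mirroring
dao_ai/graph.py:

    node = create_agent_node(
        agent=agent_model,                        # an AgentModel from AppConfig.agents
        memory=config.app.orchestration.memory    # drives the store namespace and checkpointer
        if config.app.orchestration
        else None,
        chat_history=config.app.chat_history,     # when set, wraps the agent in a summarization workflow
        additional_tools=[],                      # e.g. handoff tools in the swarm graph
    )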