dao-ai 0.0.25__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/config.py CHANGED
@@ -30,12 +30,15 @@ from databricks_langchain import (
     DatabricksFunctionClient,
 )
 from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import BaseMessage, messages_from_dict
 from langchain_core.runnables.base import RunnableLike
 from langchain_openai import ChatOpenAI
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.graph.state import CompiledStateGraph
 from langgraph.store.base import BaseStore
 from loguru import logger
+from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
+from mlflow.genai.prompts import PromptVersion, load_prompt
 from mlflow.models import ModelConfig
 from mlflow.models.resources import (
     DatabricksFunction,
@@ -49,6 +52,9 @@ from mlflow.models.resources import (
     DatabricksVectorSearchIndex,
 )
 from mlflow.pyfunc import ChatModel, ResponsesAgent
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+)
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -324,6 +330,10 @@ class LLMModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ]
 
+    @property
+    def uri(self) -> str:
+        return f"databricks:/{self.name}"
+
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksServingEndpoint(
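
The new LLMModel.uri property composes a Databricks model URI from the endpoint name. A minimal sketch of the expected behavior (the endpoint name is illustrative, and the bare construction assumes name is the only required field):

    llm = LLMModel(name="databricks-meta-llama-3-3-70b-instruct")  # hypothetical endpoint name
    assert llm.uri == "databricks:/databricks-meta-llama-3-3-70b-instruct"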
@@ -1181,17 +1191,32 @@ class PromptModel(BaseModel, HasFullName):
         from dao_ai.providers.databricks import DatabricksProvider
 
         provider: DatabricksProvider = DatabricksProvider()
-        prompt: str = provider.get_prompt(self)
-        return prompt
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()
 
     @property
     def full_name(self) -> str:
+        prompt_name: str = self.name
         if self.schema_model:
-            name: str = ""
-            if self.name:
-                name = f".{self.name}"
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
-        return self.name
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
+
+    @property
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
+
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
+
+        return prompt_uri
+
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version
 
     @model_validator(mode="after")
     def validate_mutually_exclusive(self):
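
The new uri property resolves an alias first, then a pinned version, and falls back to @latest. A sketch of the three cases (prompt names are illustrative; without a schema, full_name is just name):

    # alias="production"  -> prompts:/main.agents.greeting@production
    # version=3, no alias -> prompts:/main.agents.greeting/3
    # neither set         -> prompts:/main.agents.greeting@latest
    uri = PromptModel(name="main.agents.greeting", alias="production").uri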
@@ -1213,6 +1238,17 @@ class AgentModel(BaseModel):
     pre_agent_hook: Optional[FunctionHook] = None
     post_agent_hook: Optional[FunctionHook] = None
 
+    def as_runnable(self) -> RunnableLike:
+        from dao_ai.nodes import create_agent_node
+
+        return create_agent_node(self)
+
+    def as_responses_agent(self) -> ResponsesAgent:
+        from dao_ai.models import create_responses_agent
+
+        graph: CompiledStateGraph = self.as_runnable()
+        return create_responses_agent(graph)
+
 
 class SupervisorModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
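
The new AgentModel helpers go from a config entry to a servable agent in two calls: as_runnable compiles the LangGraph node, and as_responses_agent wraps it for MLflow serving. A sketch with the model construction elided:

    agent_model: AgentModel = ...  # placeholder: normally loaded from AppConfig.agents
    graph = agent_model.as_runnable()
    responses_agent = agent_model.as_responses_agent()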
@@ -1330,6 +1366,19 @@ class ChatPayload(BaseModel):
 
         return self
 
+    def as_messages(self) -> Sequence[BaseMessage]:
+        return messages_from_dict(
+            [{"type": m.role, "content": m.content} for m in self.messages]
+        )
+
+    def as_agent_request(self) -> ResponsesAgentRequest:
+        from mlflow.types.responses_helpers import Message as _Message
+
+        return ResponsesAgentRequest(
+            input=[_Message(role=m.role, content=m.content) for m in self.messages],
+            custom_inputs=self.custom_inputs,
+        )
+
 
 class ChatHistoryModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
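
ChatPayload can now be converted to either LangChain messages or an MLflow ResponsesAgentRequest. A sketch, assuming pydantic coerces role/content dicts into the payload's message model:

    payload = ChatPayload(messages=[{"role": "user", "content": "Hello"}])  # hypothetical construction
    lc_messages = payload.as_messages()   # Sequence[BaseMessage]
    request = payload.as_agent_request()  # ResponsesAgentRequest, including custom_inputs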
@@ -1459,6 +1508,174 @@ class EvaluationModel(BaseModel):
     guidelines: list[GuidelineModel] = Field(default_factory=list)
 
 
+class EvaluationDatasetExpectationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    expected_response: Optional[str] = None
+    expected_facts: Optional[list[str]] = None
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self):
+        if self.expected_response is not None and self.expected_facts is not None:
+            raise ValueError("Cannot specify both expected_response and expected_facts")
+        return self
+
+
+class EvaluationDatasetEntryModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    inputs: ChatPayload
+    expectations: EvaluationDatasetExpectationsModel
+
+    def to_mlflow_format(self) -> dict[str, Any]:
+        """
+        Convert to MLflow evaluation dataset format.
+
+        Flattens the expectations fields to the top level alongside inputs,
+        which is the format expected by MLflow's Correctness scorer.
+
+        Returns:
+            dict: Flattened dictionary with inputs and expectation fields at top level
+        """
+        result: dict[str, Any] = {"inputs": self.inputs.model_dump()}
+
+        # Flatten expectations to top level for MLflow compatibility
+        if self.expectations.expected_response is not None:
+            result["expected_response"] = self.expectations.expected_response
+        if self.expectations.expected_facts is not None:
+            result["expected_facts"] = self.expectations.expected_facts
+
+        return result
+
+
+class EvaluationDatasetModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    data: Optional[list[EvaluationDatasetEntryModel]] = Field(default_factory=list)
+    overwrite: Optional[bool] = False
+
+    def as_dataset(self, w: WorkspaceClient | None = None) -> EvaluationDataset:
+        evaluation_dataset: EvaluationDataset
+        needs_creation: bool = False
+
+        try:
+            evaluation_dataset = get_dataset(name=self.full_name)
+            if self.overwrite:
+                logger.warning(f"Overwriting dataset {self.full_name}")
+                workspace_client: WorkspaceClient = w if w else WorkspaceClient()
+                logger.debug(f"Dropping table: {self.full_name}")
+                workspace_client.tables.delete(full_name=self.full_name)
+                needs_creation = True
+        except Exception:
+            logger.warning(
+                f"Dataset {self.full_name} not found, will create new dataset"
+            )
+            needs_creation = True
+
+        # Create dataset if needed (either new or after overwrite)
+        if needs_creation:
+            evaluation_dataset = create_dataset(name=self.full_name)
+            if self.data:
+                logger.debug(
+                    f"Merging {len(self.data)} entries into dataset {self.full_name}"
+                )
+                # Use to_mlflow_format() to flatten expectations for MLflow compatibility
+                evaluation_dataset.merge_records(
+                    [e.to_mlflow_format() for e in self.data]
+                )
+
+        return evaluation_dataset
+
+    @property
+    def full_name(self) -> str:
+        if self.schema_model:
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
+        return self.name
+
+
+class PromptOptimizationModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    prompt: Optional[PromptModel] = None
+    agent: AgentModel
+    dataset: (
+        EvaluationDatasetModel | str
+    )  # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
+    reflection_model: Optional[LLMModel | str] = None
+    num_candidates: Optional[int] = 50
+    scorer_model: Optional[LLMModel | str] = None
+
+    def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
+        """
+        Optimize the prompt using MLflow's prompt optimization.
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            PromptModel: The optimized prompt model with new URI
+        """
+        from dao_ai.providers.base import ServiceProvider
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: ServiceProvider = DatabricksProvider(w=w)
+        optimized_prompt: PromptModel = provider.optimize_prompt(self)
+        return optimized_prompt
+
+    @model_validator(mode="after")
+    def set_defaults(self):
+        # If no prompt is specified, try to use the agent's prompt
+        if self.prompt is None:
+            if isinstance(self.agent.prompt, PromptModel):
+                self.prompt = self.agent.prompt
+            else:
+                raise ValueError(
+                    f"Prompt optimization '{self.name}' requires either an explicit prompt "
+                    f"or an agent with a prompt configured"
+                )
+
+        if self.reflection_model is None:
+            self.reflection_model = self.agent.model
+
+        if self.scorer_model is None:
+            self.scorer_model = self.agent.model
+
+        return self
+
+
+class OptimizationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    training_datasets: dict[str, EvaluationDatasetModel] = Field(default_factory=dict)
+    prompt_optimizations: dict[str, PromptOptimizationModel] = Field(
+        default_factory=dict
+    )
+
+    def optimize(self, w: WorkspaceClient | None = None) -> dict[str, PromptModel]:
+        """
+        Optimize all prompts in this configuration.
+
+        This method:
+        1. Ensures all training datasets are created/registered in MLflow
+        2. Runs each prompt optimization
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            dict[str, PromptModel]: Dictionary mapping optimization names to optimized prompts
+        """
+        # First, ensure all training datasets are created/registered in MLflow
+        logger.info(f"Ensuring {len(self.training_datasets)} training datasets exist")
+        for dataset_name, dataset_model in self.training_datasets.items():
+            logger.debug(f"Creating/updating dataset: {dataset_name}")
+            dataset_model.as_dataset()
+
+        # Run optimizations
+        results: dict[str, PromptModel] = {}
+        for name, optimization in self.prompt_optimizations.items():
+            results[name] = optimization.optimize(w)
+        return results
+
+
 class DatasetFormat(str, Enum):
     CSV = "csv"
     DELTA = "delta"
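
Together these classes form a small optimization pipeline: training datasets are materialized as MLflow evaluation datasets via as_dataset(), then each PromptOptimizationModel runs against them. A sketch of the intended flow (names are illustrative, and the optimizations map is left empty):

    dataset = EvaluationDatasetModel(name="faq_eval")  # hypothetical; full_name gains catalog.schema when a schema is set
    dataset.as_dataset()                               # get-or-create in MLflow, honoring overwrite
    results = OptimizationsModel(
        training_datasets={"faq": dataset},
        prompt_optimizations={},                       # one PromptOptimizationModel per prompt in practice
    ).optimize()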
@@ -1537,6 +1754,7 @@ class AppConfig(BaseModel):
     agents: dict[str, AgentModel] = Field(default_factory=dict)
     app: Optional[AppModel] = None
     evaluation: Optional[EvaluationModel] = None
+    optimizations: Optional[OptimizationsModel] = None
     datasets: Optional[list[DatasetModel]] = Field(default_factory=list)
     unity_catalog_functions: Optional[list[UnityCatalogFunctionSqlModel]] = Field(
         default_factory=list
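
With the new optimizations field, the pipeline is reachable from a loaded AppConfig. A minimal sketch, assuming default construction validates:

    config = AppConfig()  # normally loaded from a ModelConfig file
    if config.optimizations:
        config.optimizations.optimize()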
dao_ai/graph.py CHANGED
@@ -79,7 +79,12 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
     for registered_agent in config.app.agents:
         agents.append(
             create_agent_node(
-                app=config.app, agent=registered_agent, additional_tools=[]
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=[],
             )
         )
         tools.append(
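
Both orchestrators (supervisor above, swarm below) now thread memory and chat history into create_agent_node explicitly instead of passing the whole AppModel; the inline conditional guards configs without an orchestration block. Equivalently:

    memory = config.app.orchestration.memory if config.app.orchestration else None
    node = create_agent_node(
        agent=registered_agent,
        memory=memory,
        chat_history=config.app.chat_history,
        additional_tools=[],
    )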
@@ -169,7 +174,12 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
         )
         agents.append(
             create_agent_node(
-                app=config.app, agent=registered_agent, additional_tools=handoff_tools
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=handoff_tools,
             )
         )
 
dao_ai/nodes.py CHANGED
@@ -19,9 +19,9 @@ from loguru import logger
 from dao_ai.config import (
     AgentModel,
     AppConfig,
-    AppModel,
     ChatHistoryModel,
     FunctionHook,
+    MemoryModel,
     ToolModel,
 )
 from dao_ai.guardrails import reflection_guardrail, with_guardrails
@@ -31,12 +31,18 @@ from dao_ai.state import Context, IncomingState, SharedState
 from dao_ai.tools import create_tools
 
 
-def summarization_node(app_model: AppModel) -> RunnableLike:
-    chat_history: ChatHistoryModel | None = app_model.chat_history
+def summarization_node(chat_history: ChatHistoryModel) -> RunnableLike:
+    """
+    Create a summarization node for managing chat history.
+
+    Args:
+        chat_history: ChatHistoryModel configuration for summarization
+
+    Returns:
+        RunnableLike: A summarization node that processes messages
+    """
     if chat_history is None:
-        raise ValueError(
-            "AppModel must have chat_history configured to use summarization"
-        )
+        raise ValueError("chat_history must be provided to use summarization")
 
     max_tokens: int = chat_history.max_tokens
     max_tokens_before_summary: int | None = chat_history.max_tokens_before_summary
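
summarization_node now takes the ChatHistoryModel directly, so callers no longer need an AppModel wrapper. A sketch, assuming max_tokens is the only required field:

    history = ChatHistoryModel(max_tokens=4096)  # hypothetical construction
    node = summarization_node(history)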
@@ -93,23 +99,26 @@ def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLike:
 
 
 def create_agent_node(
-    app: AppModel,
     agent: AgentModel,
+    memory: Optional[MemoryModel] = None,
+    chat_history: Optional[ChatHistoryModel] = None,
     additional_tools: Optional[Sequence[BaseTool]] = None,
 ) -> RunnableLike:
     """
     Factory function that creates a LangGraph node for a specialized agent.
 
-    This creates a node function that handles user requests using a specialized agent
-    based on the provided agent_type. The function configures the agent with the
-    appropriate model, prompt, tools, and guardrails from the model_config.
+    This creates a node function that handles user requests using a specialized agent.
+    The function configures the agent with the appropriate model, prompt, tools, and guardrails.
+    If chat_history is provided, it creates a workflow with summarization node.
 
     Args:
-        model_config: Configuration containing models, prompts, tools, and guardrails
-        agent_type: Type of agent to create (e.g., "general", "product", "inventory")
+        agent: AgentModel configuration for the agent
+        memory: Optional MemoryModel for memory store configuration
+        chat_history: Optional ChatHistoryModel for chat history summarization
+        additional_tools: Optional sequence of additional tools to add to the agent
 
     Returns:
-        An agent callable function that processes state and returns responses
+        RunnableLike: An agent node that processes state and returns responses
     """
     logger.debug(f"Creating agent node for {agent.name}")
 
@@ -124,10 +133,10 @@ def create_agent_node(
         additional_tools = []
     tools: Sequence[BaseTool] = create_tools(tool_models) + additional_tools
 
-    if app.orchestration.memory and app.orchestration.memory.store:
+    if memory and memory.store:
         namespace: tuple[str, ...] = ("memory",)
-        if app.orchestration.memory.store.namespace:
-            namespace = namespace + (app.orchestration.memory.store.namespace,)
+        if memory.store.namespace:
+            namespace = namespace + (memory.store.namespace,)
         logger.debug(f"Memory store namespace: {namespace}")
 
         tools += [
@@ -145,13 +154,15 @@ def create_agent_node(
         )
     logger.debug(f"post_agent_hook: {post_agent_hook}")
 
+    checkpointer: bool = memory and memory.checkpointer is not None
+
     compiled_agent: CompiledStateGraph = create_react_agent(
         name=agent.name,
         model=llm,
         prompt=make_prompt(agent.prompt),
         tools=tools,
         store=True,
-        checkpointer=True,
+        checkpointer=checkpointer,
         state_schema=SharedState,
         context_schema=Context,
         pre_model_hook=pre_agent_hook,
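
Checkpointing is now opt-in via the memory configuration rather than always on. One subtlety worth noting: despite the bool annotation, the expression yields None (falsy) when no memory is configured, because `and` short-circuits:

    memory = None
    checkpointer = memory and memory.checkpointer is not None
    print(checkpointer)  # None, not False: `and` returns the first falsy operand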
@@ -166,8 +177,6 @@ def create_agent_node(
 
 
     agent_node: CompiledStateGraph
 
-    chat_history: ChatHistoryModel = app.chat_history
-
     if chat_history is None:
@@ -179,7 +188,7 @@ def create_agent_node(
 
             input=SharedState,
         )
-        workflow.add_node("summarization", summarization_node(app))
+        workflow.add_node("summarization", summarization_node(chat_history))
         workflow.add_node(
             "agent",
             call_agent_with_summarized_messages(agent=compiled_agent),