dao-ai 0.0.24__py3-none-any.whl → 0.0.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/cli.py CHANGED
@@ -460,6 +460,49 @@ def setup_logging(verbosity: int) -> None:
     logger.add(sys.stderr, level=level)
 
 
+def generate_bundle_from_template(config_path: Path, app_name: str) -> Path:
+    """
+    Generate an app-specific databricks.yaml from databricks.yaml.template.
+
+    This function:
+    1. Reads databricks.yaml.template (permanent template file)
+    2. Replaces __APP_NAME__ with the actual app name
+    3. Writes to databricks.yaml (overwrites if exists)
+    4. Returns the path to the generated file
+
+    The generated databricks.yaml is overwritten on each deployment and is not tracked in git.
+    Schema reference remains pointing to ./schemas/bundle_config_schema.json.
+
+    Args:
+        config_path: Path to the app config file
+        app_name: Normalized app name
+
+    Returns:
+        Path to the generated databricks.yaml file
+    """
+    cwd = Path.cwd()
+    template_path = cwd / "databricks.yaml.template"
+    output_path = cwd / "databricks.yaml"
+
+    if not template_path.exists():
+        logger.error(f"Template file {template_path} does not exist.")
+        sys.exit(1)
+
+    # Read template
+    with open(template_path, "r") as f:
+        template_content = f.read()
+
+    # Replace template variables
+    bundle_content = template_content.replace("__APP_NAME__", app_name)
+
+    # Write generated databricks.yaml (overwrite if exists)
+    with open(output_path, "w") as f:
+        f.write(bundle_content)
+
+    logger.info(f"Generated bundle configuration at {output_path} from template")
+    return output_path
+
+
 def run_databricks_command(
     command: list[str],
     profile: Optional[str] = None,
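The substitution this helper performs is plain string replacement. A minimal sketch, assuming a hypothetical one-line template (the real databricks.yaml.template ships with the project and is larger):

    from pathlib import Path

    # Hypothetical template; the real databricks.yaml.template has more keys.
    Path("databricks.yaml.template").write_text("bundle:\n  name: __APP_NAME__\n")

    content = Path("databricks.yaml.template").read_text()
    # Same replacement the helper applies before writing databricks.yaml:
    Path("databricks.yaml").write_text(content.replace("__APP_NAME__", "my_app"))

    print(Path("databricks.yaml").read_text())
    # bundle:
    #   name: my_app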
@@ -467,44 +510,55 @@ def run_databricks_command(
     target: Optional[str] = None,
     dry_run: bool = False,
 ) -> None:
-    """Execute a databricks CLI command with optional profile."""
+    """Execute a databricks CLI command with optional profile and target."""
+    config_path = Path(config) if config else None
+
+    if config_path and not config_path.exists():
+        logger.error(f"Configuration file {config_path} does not exist.")
+        sys.exit(1)
+
+    # Load app config and generate bundle from template
+    app_config: AppConfig = AppConfig.from_file(config_path) if config_path else None
+    normalized_name: str = normalize_name(app_config.app.name) if app_config else None
+
+    # Generate app-specific bundle from template (overwrites databricks.yaml temporarily)
+    if config_path and app_config:
+        generate_bundle_from_template(config_path, normalized_name)
+
+    # Use app name as target if not explicitly provided
+    # This ensures each app gets its own Terraform state in .databricks/bundle/<app-name>/
+    if not target and normalized_name:
+        target = normalized_name
+        logger.debug(f"Using app-specific target: {target}")
+
+    # Build databricks command (no -c flag needed, uses databricks.yaml in current dir)
     cmd = ["databricks"]
     if profile:
         cmd.extend(["--profile", profile])
+
     if target:
         cmd.extend(["--target", target])
-    cmd.extend(command)
-    if config:
-        config_path = Path(config)
 
-        if not config_path.exists():
-            logger.error(f"Configuration file {config_path} does not exist.")
-            sys.exit(1)
-
-        app_config: AppConfig = AppConfig.from_file(config_path)
+    cmd.extend(command)
 
-        # Always convert to path relative to notebooks directory
-        # Get absolute path of config file and current working directory
+    # Add config_path variable for notebooks
+    if config_path and app_config:
+        # Calculate relative path from notebooks directory to config file
         config_abs = config_path.resolve()
         cwd = Path.cwd()
         notebooks_dir = cwd / "notebooks"
 
-        # Calculate relative path from notebooks directory to config file
         try:
            relative_config = config_abs.relative_to(notebooks_dir)
         except ValueError:
-            # Config file is outside notebooks directory, calculate relative path
-            # Use os.path.relpath to get the relative path from notebooks_dir to config file
            relative_config = Path(os.path.relpath(config_abs, notebooks_dir))
 
        cmd.append(f'--var="config_path={relative_config}"')
 
-        normalized_name: str = normalize_name(app_config.app.name)
-        cmd.append(f'--var="app_name={normalized_name}"')
-
     logger.debug(f"Executing command: {' '.join(cmd)}")
 
     if dry_run:
+        logger.info(f"[DRY RUN] Would execute: {' '.join(cmd)}")
         return
 
     try:
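The fallback branch matters when the config file lives outside notebooks/: Path.relative_to raises ValueError there, and os.path.relpath produces a ..-style path instead. A standalone sketch with hypothetical POSIX paths:

    import os
    from pathlib import Path

    notebooks_dir = Path("/repo/notebooks")
    config_abs = Path("/repo/configs/app.yaml")  # hypothetical, outside notebooks/

    try:
        relative_config = config_abs.relative_to(notebooks_dir)
    except ValueError:
        # Fall back to a ..-style path from notebooks/ to the config file
        relative_config = Path(os.path.relpath(config_abs, notebooks_dir))

    print(relative_config)  # ../configs/app.yaml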
@@ -531,6 +585,9 @@ def run_databricks_command(
     except FileNotFoundError:
         logger.error("databricks CLI not found. Please install the Databricks CLI.")
         sys.exit(1)
+    except Exception as e:
+        logger.error(f"Command execution failed: {e}")
+        sys.exit(1)
 
 
 def handle_bundle_command(options: Namespace) -> None:
@@ -539,6 +596,7 @@ def handle_bundle_command(options: Namespace) -> None:
     config: Optional[str] = options.config
     target: Optional[str] = options.target
     dry_run: bool = options.dry_run
+
     if options.deploy:
         logger.info("Deploying DAO AI asset bundle...")
         run_databricks_command(
@@ -546,8 +604,9 @@ def handle_bundle_command(options: Namespace) -> None:
         )
     if options.run:
         logger.info("Running DAO AI system with current configuration...")
+        # Use static job resource key that matches databricks.yaml (resources.jobs.deploy_job)
         run_databricks_command(
-            ["bundle", "run", "deploy-end-to-end"],
+            ["bundle", "run", "deploy_job"],
             profile,
             config,
             target,
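Putting the cli.py changes together, the run path now assembles a command roughly like the following. This is a sketch of the logic above with hypothetical values, not output captured from the CLI:

    profile, target, normalized_name = "DEFAULT", None, "my_app"  # hypothetical

    cmd = ["databricks", "--profile", profile]
    cmd.extend(["--target", target or normalized_name])  # app name as default target
    cmd.extend(["bundle", "run", "deploy_job"])          # static job resource key
    cmd.append('--var="config_path=../configs/app.yaml"')

    print(" ".join(cmd))
    # databricks --profile DEFAULT --target my_app bundle run deploy_job --var="config_path=../configs/app.yaml"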
dao_ai/config.py CHANGED
@@ -30,12 +30,15 @@ from databricks_langchain import (
     DatabricksFunctionClient,
 )
 from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import BaseMessage, messages_from_dict
 from langchain_core.runnables.base import RunnableLike
 from langchain_openai import ChatOpenAI
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.graph.state import CompiledStateGraph
 from langgraph.store.base import BaseStore
 from loguru import logger
+from mlflow.genai.datasets import EvaluationDataset, create_dataset, get_dataset
+from mlflow.genai.prompts import PromptVersion, load_prompt
 from mlflow.models import ModelConfig
 from mlflow.models.resources import (
     DatabricksFunction,
@@ -49,6 +52,9 @@ from mlflow.models.resources import (
     DatabricksVectorSearchIndex,
 )
 from mlflow.pyfunc import ChatModel, ResponsesAgent
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+)
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -324,6 +330,10 @@ class LLMModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ]
 
+    @property
+    def uri(self) -> str:
+        return f"databricks:/{self.name}"
+
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksServingEndpoint(
@@ -1181,17 +1191,32 @@ class PromptModel(BaseModel, HasFullName):
         from dao_ai.providers.databricks import DatabricksProvider
 
         provider: DatabricksProvider = DatabricksProvider()
-        prompt: str = provider.get_prompt(self)
-        return prompt
+        prompt_version = provider.get_prompt(self)
+        return prompt_version.to_single_brace_format()
 
     @property
     def full_name(self) -> str:
+        prompt_name: str = self.name
         if self.schema_model:
-            name: str = ""
-            if self.name:
-                name = f".{self.name}"
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
-        return self.name
+            prompt_name = f"{self.schema_model.full_name}.{prompt_name}"
+        return prompt_name
+
+    @property
+    def uri(self) -> str:
+        prompt_uri: str = f"prompts:/{self.full_name}"
+
+        if self.alias:
+            prompt_uri = f"prompts:/{self.full_name}@{self.alias}"
+        elif self.version:
+            prompt_uri = f"prompts:/{self.full_name}/{self.version}"
+        else:
+            prompt_uri = f"prompts:/{self.full_name}@latest"
+
+        return prompt_uri
+
+    def as_prompt(self) -> PromptVersion:
+        prompt_version: PromptVersion = load_prompt(self.uri)
+        return prompt_version
 
     @model_validator(mode="after")
     def validate_mutually_exclusive(self):
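The new uri property maps alias, version, or neither onto the three prompt-URI shapes the code emits. The same resolution as plain string logic, for a hypothetical prompt name:

    full_name = "main.default.router_prompt"  # hypothetical catalog.schema.name

    for alias, version in [("production", None), (None, 3), (None, None)]:
        if alias:
            uri = f"prompts:/{full_name}@{alias}"
        elif version:
            uri = f"prompts:/{full_name}/{version}"
        else:
            uri = f"prompts:/{full_name}@latest"
        print(uri)
    # prompts:/main.default.router_prompt@production
    # prompts:/main.default.router_prompt/3
    # prompts:/main.default.router_prompt@latest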
@@ -1213,6 +1238,17 @@ class AgentModel(BaseModel):
     pre_agent_hook: Optional[FunctionHook] = None
     post_agent_hook: Optional[FunctionHook] = None
 
+    def as_runnable(self) -> RunnableLike:
+        from dao_ai.nodes import create_agent_node
+
+        return create_agent_node(self)
+
+    def as_responses_agent(self) -> ResponsesAgent:
+        from dao_ai.models import create_responses_agent
+
+        graph: CompiledStateGraph = self.as_runnable()
+        return create_responses_agent(graph)
+
 
 class SupervisorModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
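A usage sketch for the two new AgentModel helpers, assuming a config file with an agents entry (path and key are hypothetical):

    config = AppConfig.from_file("configs/app.yaml")  # hypothetical path
    agent_model = config.agents["general"]            # hypothetical agent key

    graph = agent_model.as_runnable()         # builds the node via create_agent_node
    agent = agent_model.as_responses_agent()  # wraps that graph as a ResponsesAgent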
@@ -1301,7 +1337,7 @@ class ChatPayload(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     input: Optional[list[Message]] = None
     messages: Optional[list[Message]] = None
-    custom_inputs: dict
+    custom_inputs: Optional[dict] = Field(default_factory=dict)
 
     @model_validator(mode="after")
     def validate_mutual_exclusion_and_alias(self) -> "ChatPayload":
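With the default factory, payloads that omit custom_inputs now validate instead of failing. A minimal sketch, assuming Message takes role and content:

    payload = ChatPayload(messages=[Message(role="user", content="hi")])
    print(payload.custom_inputs)  # {} rather than a validation error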
@@ -1330,6 +1366,19 @@ class ChatPayload(BaseModel):
 
         return self
 
+    def as_messages(self) -> Sequence[BaseMessage]:
+        return messages_from_dict(
+            [{"type": m.role, "content": m.content} for m in self.messages]
+        )
+
+    def as_agent_request(self) -> ResponsesAgentRequest:
+        from mlflow.types.responses_helpers import Message as _Message
+
+        return ResponsesAgentRequest(
+            input=[_Message(role=m.role, content=m.content) for m in self.messages],
+            custom_inputs=self.custom_inputs,
+        )
+
 
 class ChatHistoryModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
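A hedged sketch of the two converters on a one-message payload; the exact message classes returned depend on the installed langchain-core and mlflow versions:

    payload = ChatPayload(messages=[Message(role="user", content="What is DAO AI?")])

    lc_messages = payload.as_messages()   # LangChain BaseMessage objects
    request = payload.as_agent_request()  # ResponsesAgentRequest for the agent
    print(request.custom_inputs)          # {} (the new default carries through)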
@@ -1459,6 +1508,174 @@ class EvaluationModel(BaseModel):
     guidelines: list[GuidelineModel] = Field(default_factory=list)
 
 
+class EvaluationDatasetExpectationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    expected_response: Optional[str] = None
+    expected_facts: Optional[list[str]] = None
+
+    @model_validator(mode="after")
+    def validate_mutually_exclusive(self):
+        if self.expected_response is not None and self.expected_facts is not None:
+            raise ValueError("Cannot specify both expected_response and expected_facts")
+        return self
+
+
+class EvaluationDatasetEntryModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    inputs: ChatPayload
+    expectations: EvaluationDatasetExpectationsModel
+
+    def to_mlflow_format(self) -> dict[str, Any]:
+        """
+        Convert to MLflow evaluation dataset format.
+
+        Flattens the expectations fields to the top level alongside inputs,
+        which is the format expected by MLflow's Correctness scorer.
+
+        Returns:
+            dict: Flattened dictionary with inputs and expectation fields at top level
+        """
+        result: dict[str, Any] = {"inputs": self.inputs.model_dump()}
+
+        # Flatten expectations to top level for MLflow compatibility
+        if self.expectations.expected_response is not None:
+            result["expected_response"] = self.expectations.expected_response
+        if self.expectations.expected_facts is not None:
+            result["expected_facts"] = self.expectations.expected_facts
+
+        return result
+
+
+class EvaluationDatasetModel(BaseModel, HasFullName):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
+    name: str
+    data: Optional[list[EvaluationDatasetEntryModel]] = Field(default_factory=list)
+    overwrite: Optional[bool] = False
+
+    def as_dataset(self, w: WorkspaceClient | None = None) -> EvaluationDataset:
+        evaluation_dataset: EvaluationDataset
+        needs_creation: bool = False
+
+        try:
+            evaluation_dataset = get_dataset(name=self.full_name)
+            if self.overwrite:
+                logger.warning(f"Overwriting dataset {self.full_name}")
+                workspace_client: WorkspaceClient = w if w else WorkspaceClient()
+                logger.debug(f"Dropping table: {self.full_name}")
+                workspace_client.tables.delete(full_name=self.full_name)
+                needs_creation = True
+        except Exception:
+            logger.warning(
+                f"Dataset {self.full_name} not found, will create new dataset"
+            )
+            needs_creation = True
+
+        # Create dataset if needed (either new or after overwrite)
+        if needs_creation:
+            evaluation_dataset = create_dataset(name=self.full_name)
+        if self.data:
+            logger.debug(
+                f"Merging {len(self.data)} entries into dataset {self.full_name}"
+            )
+            # Use to_mlflow_format() to flatten expectations for MLflow compatibility
+            evaluation_dataset.merge_records(
+                [e.to_mlflow_format() for e in self.data]
+            )
+
+        return evaluation_dataset
+
+    @property
+    def full_name(self) -> str:
+        if self.schema_model:
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
+        return self.name
+
+
+class PromptOptimizationModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    name: str
+    prompt: Optional[PromptModel] = None
+    agent: AgentModel
+    dataset: (
+        EvaluationDatasetModel | str
+    )  # Reference to dataset name (looked up in OptimizationsModel.training_datasets or MLflow)
+    reflection_model: Optional[LLMModel | str] = None
+    num_candidates: Optional[int] = 50
+    scorer_model: Optional[LLMModel | str] = None
+
+    def optimize(self, w: WorkspaceClient | None = None) -> PromptModel:
+        """
+        Optimize the prompt using MLflow's prompt optimization.
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            PromptModel: The optimized prompt model with new URI
+        """
+        from dao_ai.providers.base import ServiceProvider
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: ServiceProvider = DatabricksProvider(w=w)
+        optimized_prompt: PromptModel = provider.optimize_prompt(self)
+        return optimized_prompt
+
+    @model_validator(mode="after")
+    def set_defaults(self):
+        # If no prompt is specified, try to use the agent's prompt
+        if self.prompt is None:
+            if isinstance(self.agent.prompt, PromptModel):
+                self.prompt = self.agent.prompt
+            else:
+                raise ValueError(
+                    f"Prompt optimization '{self.name}' requires either an explicit prompt "
+                    f"or an agent with a prompt configured"
+                )
+
+        if self.reflection_model is None:
+            self.reflection_model = self.agent.model
+
+        if self.scorer_model is None:
+            self.scorer_model = self.agent.model
+
+        return self
+
+
+class OptimizationsModel(BaseModel):
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
+    training_datasets: dict[str, EvaluationDatasetModel] = Field(default_factory=dict)
+    prompt_optimizations: dict[str, PromptOptimizationModel] = Field(
+        default_factory=dict
+    )
+
+    def optimize(self, w: WorkspaceClient | None = None) -> dict[str, PromptModel]:
+        """
+        Optimize all prompts in this configuration.
+
+        This method:
+        1. Ensures all training datasets are created/registered in MLflow
+        2. Runs each prompt optimization
+
+        Args:
+            w: Optional WorkspaceClient for Databricks operations
+
+        Returns:
+            dict[str, PromptModel]: Dictionary mapping optimization names to optimized prompts
+        """
+        # First, ensure all training datasets are created/registered in MLflow
+        logger.info(f"Ensuring {len(self.training_datasets)} training datasets exist")
+        for dataset_name, dataset_model in self.training_datasets.items():
+            logger.debug(f"Creating/updating dataset: {dataset_name}")
+            dataset_model.as_dataset()
+
+        # Run optimizations
+        results: dict[str, PromptModel] = {}
+        for name, optimization in self.prompt_optimizations.items():
+            results[name] = optimization.optimize(w)
+        return results
+
+
 class DatasetFormat(str, Enum):
     CSV = "csv"
     DELTA = "delta"
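A sketch of the flattening to_mlflow_format performs, with hypothetical values; the expectation fields land next to inputs rather than nested under an expectations key:

    entry = EvaluationDatasetEntryModel(
        inputs=ChatPayload(messages=[Message(role="user", content="Where is my order?")]),
        expectations=EvaluationDatasetExpectationsModel(
            expected_facts=["orders are tracked from the account page"],  # hypothetical
        ),
    )

    print(entry.to_mlflow_format())
    # {'inputs': {...}, 'expected_facts': ['orders are tracked from the account page']}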
@@ -1537,6 +1754,7 @@ class AppConfig(BaseModel):
     agents: dict[str, AgentModel] = Field(default_factory=dict)
     app: Optional[AppModel] = None
     evaluation: Optional[EvaluationModel] = None
+    optimizations: Optional[OptimizationsModel] = None
     datasets: Optional[list[DatasetModel]] = Field(default_factory=list)
     unity_catalog_functions: Optional[list[UnityCatalogFunctionSqlModel]] = Field(
         default_factory=list
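End to end, the new optimizations section is reachable from a loaded config. A minimal sketch with a hypothetical config path:

    config = AppConfig.from_file("configs/app.yaml")  # hypothetical path

    if config.optimizations:
        results = config.optimizations.optimize()  # dict[str, PromptModel]
        for name, prompt in results.items():
            print(name, prompt.uri)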
dao_ai/graph.py CHANGED
@@ -79,7 +79,12 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
     for registered_agent in config.app.agents:
         agents.append(
             create_agent_node(
-                app=config.app, agent=registered_agent, additional_tools=[]
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=[],
             )
         )
         tools.append(
@@ -169,7 +174,12 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
         )
         agents.append(
             create_agent_node(
-                app=config.app, agent=registered_agent, additional_tools=handoff_tools
+                agent=registered_agent,
+                memory=config.app.orchestration.memory
+                if config.app.orchestration
+                else None,
+                chat_history=config.app.chat_history,
+                additional_tools=handoff_tools,
             )
         )
 
dao_ai/nodes.py CHANGED
@@ -19,9 +19,9 @@ from loguru import logger
 from dao_ai.config import (
     AgentModel,
     AppConfig,
-    AppModel,
     ChatHistoryModel,
     FunctionHook,
+    MemoryModel,
     ToolModel,
 )
 from dao_ai.guardrails import reflection_guardrail, with_guardrails
@@ -31,12 +31,18 @@ from dao_ai.state import Context, IncomingState, SharedState
 from dao_ai.tools import create_tools
 
 
-def summarization_node(app_model: AppModel) -> RunnableLike:
-    chat_history: ChatHistoryModel | None = app_model.chat_history
+def summarization_node(chat_history: ChatHistoryModel) -> RunnableLike:
+    """
+    Create a summarization node for managing chat history.
+
+    Args:
+        chat_history: ChatHistoryModel configuration for summarization
+
+    Returns:
+        RunnableLike: A summarization node that processes messages
+    """
     if chat_history is None:
-        raise ValueError(
-            "AppModel must have chat_history configured to use summarization"
-        )
+        raise ValueError("chat_history must be provided to use summarization")
 
     max_tokens: int = chat_history.max_tokens
     max_tokens_before_summary: int | None = chat_history.max_tokens_before_summary
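Callers now hand over the ChatHistoryModel directly instead of the whole AppModel. A minimal sketch, assuming hypothetical token limits and defaults for any other ChatHistoryModel fields:

    chat_history = ChatHistoryModel(max_tokens=2048, max_tokens_before_summary=4096)
    node = summarization_node(chat_history)  # RunnableLike; None still raises ValueError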
@@ -93,23 +99,26 @@ def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLike:
 
 
 def create_agent_node(
-    app: AppModel,
     agent: AgentModel,
+    memory: Optional[MemoryModel] = None,
+    chat_history: Optional[ChatHistoryModel] = None,
     additional_tools: Optional[Sequence[BaseTool]] = None,
 ) -> RunnableLike:
     """
     Factory function that creates a LangGraph node for a specialized agent.
 
-    This creates a node function that handles user requests using a specialized agent
-    based on the provided agent_type. The function configures the agent with the
-    appropriate model, prompt, tools, and guardrails from the model_config.
+    This creates a node function that handles user requests using a specialized agent.
+    The function configures the agent with the appropriate model, prompt, tools, and guardrails.
+    If chat_history is provided, it creates a workflow with summarization node.
 
     Args:
-        model_config: Configuration containing models, prompts, tools, and guardrails
-        agent_type: Type of agent to create (e.g., "general", "product", "inventory")
+        agent: AgentModel configuration for the agent
+        memory: Optional MemoryModel for memory store configuration
+        chat_history: Optional ChatHistoryModel for chat history summarization
+        additional_tools: Optional sequence of additional tools to add to the agent
 
     Returns:
-        An agent callable function that processes state and returns responses
+        RunnableLike: An agent node that processes state and returns responses
     """
     logger.debug(f"Creating agent node for {agent.name}")
 
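The practical effect of the signature change is that a node no longer needs a full AppModel, which is exactly what AgentModel.as_runnable() in config.py relies on. A sketch with a hypothetical agent_model:

    node = create_agent_node(agent=agent_model)  # memory and chat_history default to None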
@@ -124,10 +133,10 @@ def create_agent_node(
         additional_tools = []
     tools: Sequence[BaseTool] = create_tools(tool_models) + additional_tools
 
-    if app.orchestration.memory and app.orchestration.memory.store:
+    if memory and memory.store:
         namespace: tuple[str, ...] = ("memory",)
-        if app.orchestration.memory.store.namespace:
-            namespace = namespace + (app.orchestration.memory.store.namespace,)
+        if memory.store.namespace:
+            namespace = namespace + (memory.store.namespace,)
         logger.debug(f"Memory store namespace: {namespace}")
 
         tools += [
@@ -145,13 +154,15 @@
         )
     logger.debug(f"post_agent_hook: {post_agent_hook}")
 
+    checkpointer: bool = memory and memory.checkpointer is not None
+
     compiled_agent: CompiledStateGraph = create_react_agent(
         name=agent.name,
         model=llm,
         prompt=make_prompt(agent.prompt),
         tools=tools,
         store=True,
-        checkpointer=True,
+        checkpointer=checkpointer,
         state_schema=SharedState,
         context_schema=Context,
         pre_model_hook=pre_agent_hook,
@@ -166,8 +177,6 @@
 
 
     agent_node: CompiledStateGraph
-    chat_history: ChatHistoryModel = app.chat_history
-
     if chat_history is None:
         logger.debug("No chat history configured, using compiled agent directly")
         agent_node = compiled_agent
@@ -179,7 +188,7 @@
             input=SharedState,
             output=SharedState,
         )
-        workflow.add_node("summarization", summarization_node(app))
+        workflow.add_node("summarization", summarization_node(chat_history))
         workflow.add_node(
             "agent",
             call_agent_with_summarized_messages(agent=compiled_agent),