dao-ai 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/memory/postgres.py CHANGED
@@ -3,6 +3,7 @@ import atexit
 import threading
 from typing import Any, Optional
 
+from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebasePool
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.checkpoint.postgres import ShallowPostgresSaver
 from langgraph.checkpoint.postgres.aio import AsyncShallowPostgresSaver
@@ -86,13 +87,22 @@ async def _create_async_pool(
 
 
 class AsyncPostgresPoolManager:
+    """
+    Asynchronous PostgreSQL connection pool manager that shares pools
+    based on database configuration.
+
+    For Lakebase connections (when instance_name is provided), uses AsyncLakebasePool
+    from databricks_ai_bridge which handles automatic token rotation and host resolution.
+    For standard PostgreSQL connections, uses psycopg_pool.AsyncConnectionPool.
+    """
+
     _pools: dict[str, AsyncConnectionPool] = {}
+    _lakebase_pools: dict[str, AsyncLakebasePool] = {}
     _lock: asyncio.Lock = asyncio.Lock()
 
     @classmethod
     async def get_pool(cls, database: DatabaseModel) -> AsyncConnectionPool:
         connection_key: str = database.name
-        connection_params: dict[str, Any] = database.connection_params
 
         async with cls._lock:
             if connection_key in cls._pools:
@@ -103,19 +113,43 @@ class AsyncPostgresPoolManager:
 
             logger.debug("Creating new async PostgreSQL pool", database=database.name)
 
-            kwargs: dict[str, Any] = {
-                "row_factory": dict_row,
-                "autocommit": True,
-            } | database.connection_kwargs or {}
-
-            # Create connection pool
-            pool: AsyncConnectionPool = await _create_async_pool(
-                connection_params=connection_params,
-                database_name=database.name,
-                max_pool_size=database.max_pool_size,
-                timeout_seconds=database.timeout_seconds,
-                kwargs=kwargs,
-            )
+            if database.is_lakebase:
+                # Use AsyncLakebasePool for Lakebase connections
+                # AsyncLakebasePool handles automatic token rotation and host resolution
+                lakebase_pool = AsyncLakebasePool(
+                    instance_name=database.instance_name,
+                    workspace_client=database.workspace_client,
+                    min_size=1,
+                    max_size=database.max_pool_size,
+                    timeout=float(database.timeout_seconds),
+                )
+                # Open the async pool
+                await lakebase_pool.open()
+                # Store the AsyncLakebasePool for proper cleanup
+                cls._lakebase_pools[connection_key] = lakebase_pool
+                # Get the underlying AsyncConnectionPool
+                pool = lakebase_pool.pool
+                logger.success(
+                    "Async Lakebase connection pool created",
+                    database=database.name,
+                    instance_name=database.instance_name,
+                    pool_size=database.max_pool_size,
+                )
+            else:
+                # Use standard async PostgreSQL pool for non-Lakebase connections
+                connection_params: dict[str, Any] = database.connection_params
+                kwargs: dict[str, Any] = {
+                    "row_factory": dict_row,
+                    "autocommit": True,
+                } | database.connection_kwargs or {}
+
+                pool = await _create_async_pool(
+                    connection_params=connection_params,
+                    database_name=database.name,
+                    max_pool_size=database.max_pool_size,
+                    timeout_seconds=database.timeout_seconds,
+                    kwargs=kwargs,
+                )
 
             cls._pools[connection_key] = pool
             return pool
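
A minimal usage sketch for the new Lakebase-aware async manager (not part of the package; the DatabaseModel import location and constructor arguments below are assumptions, while get_pool, close_all_pools, is_lakebase, and instance_name come from the hunks above):

    import asyncio

    from dao_ai.config import DatabaseModel  # assumed location of DatabaseModel
    from dao_ai.memory.postgres import AsyncPostgresPoolManager

    async def main() -> None:
        # Hypothetical config; when instance_name is set, the Lakebase branch is taken
        database = DatabaseModel(name="checkpoints", instance_name="my-lakebase-instance")
        # Pools are shared per database.name; the Lakebase pool rotates tokens internally
        pool = await AsyncPostgresPoolManager.get_pool(database)
        async with pool.connection() as conn:
            await conn.execute("SELECT 1")
        await AsyncPostgresPoolManager.close_all_pools()

    asyncio.run(main())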
@@ -125,7 +159,13 @@
         connection_key: str = database.name
 
         async with cls._lock:
-            if connection_key in cls._pools:
+            # Close AsyncLakebasePool if it exists (handles underlying pool cleanup)
+            if connection_key in cls._lakebase_pools:
+                lakebase_pool = cls._lakebase_pools.pop(connection_key)
+                await lakebase_pool.close()
+                cls._pools.pop(connection_key, None)
+                logger.debug("Async Lakebase pool closed", database=database.name)
+            elif connection_key in cls._pools:
                 pool = cls._pools.pop(connection_key)
                 await pool.close()
                 logger.debug("Async PostgreSQL pool closed", database=database.name)
@@ -133,9 +173,32 @@
     @classmethod
     async def close_all_pools(cls):
         async with cls._lock:
+            # Close all AsyncLakebasePool instances first
+            for connection_key, lakebase_pool in cls._lakebase_pools.items():
+                try:
+                    await asyncio.wait_for(lakebase_pool.close(), timeout=2.0)
+                    logger.debug("Async Lakebase pool closed", pool=connection_key)
+                except asyncio.TimeoutError:
+                    logger.warning(
+                        "Timeout closing async Lakebase pool, forcing closure",
+                        pool=connection_key,
+                    )
+                except asyncio.CancelledError:
+                    logger.warning(
+                        "Async Lakebase pool closure cancelled (shutdown in progress)",
+                        pool=connection_key,
+                    )
+                except Exception as e:
+                    logger.error(
+                        "Error closing async Lakebase pool",
+                        pool=connection_key,
+                        error=str(e),
+                    )
+            cls._lakebase_pools.clear()
+
+            # Close any remaining standard async PostgreSQL pools
             for connection_key, pool in cls._pools.items():
                 try:
-                    # Use a short timeout to avoid blocking on pool closure
                     await asyncio.wait_for(pool.close(), timeout=2.0)
                     logger.debug("Async PostgreSQL pool closed", pool=connection_key)
                 except asyncio.TimeoutError:
@@ -309,15 +372,19 @@ class PostgresPoolManager:
     """
     Synchronous PostgreSQL connection pool manager that shares pools
     based on database configuration.
+
+    For Lakebase connections (when instance_name is provided), uses LakebasePool
+    from databricks_ai_bridge which handles automatic token rotation and host resolution.
+    For standard PostgreSQL connections, uses psycopg_pool.ConnectionPool.
     """
 
     _pools: dict[str, ConnectionPool] = {}
+    _lakebase_pools: dict[str, LakebasePool] = {}
     _lock: threading.Lock = threading.Lock()
 
     @classmethod
     def get_pool(cls, database: DatabaseModel) -> ConnectionPool:
         connection_key: str = str(database.name)
-        connection_params: dict[str, Any] = database.connection_params
 
         with cls._lock:
             if connection_key in cls._pools:
@@ -326,19 +393,41 @@ class PostgresPoolManager:
 
             logger.debug("Creating new PostgreSQL pool", database=database.name)
 
-            kwargs: dict[str, Any] = {
-                "row_factory": dict_row,
-                "autocommit": True,
-            } | database.connection_kwargs or {}
-
-            # Create connection pool
-            pool: ConnectionPool = _create_pool(
-                connection_params=connection_params,
-                database_name=database.name,
-                max_pool_size=database.max_pool_size,
-                timeout_seconds=database.timeout_seconds,
-                kwargs=kwargs,
-            )
+            if database.is_lakebase:
+                # Use LakebasePool for Lakebase connections
+                # LakebasePool handles automatic token rotation and host resolution
+                lakebase_pool = LakebasePool(
+                    instance_name=database.instance_name,
+                    workspace_client=database.workspace_client,
+                    min_size=1,
+                    max_size=database.max_pool_size,
+                    timeout=float(database.timeout_seconds),
+                )
+                # Store the LakebasePool for proper cleanup
+                cls._lakebase_pools[connection_key] = lakebase_pool
+                # Get the underlying ConnectionPool
+                pool = lakebase_pool.pool
+                logger.success(
+                    "Lakebase connection pool created",
+                    database=database.name,
+                    instance_name=database.instance_name,
+                    pool_size=database.max_pool_size,
+                )
+            else:
+                # Use standard PostgreSQL pool for non-Lakebase connections
+                connection_params: dict[str, Any] = database.connection_params
+                kwargs: dict[str, Any] = {
+                    "row_factory": dict_row,
+                    "autocommit": True,
+                } | database.connection_kwargs or {}
+
+                pool = _create_pool(
+                    connection_params=connection_params,
+                    database_name=database.name,
+                    max_pool_size=database.max_pool_size,
+                    timeout_seconds=database.timeout_seconds,
+                    kwargs=kwargs,
+                )
 
             cls._pools[connection_key] = pool
             return pool
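
The synchronous manager mirrors the async one: under the same assumed DatabaseModel as in the async sketch above, get_pool returns a LakebasePool-backed ConnectionPool when database.is_lakebase is set, and close_all_pools tears down Lakebase pools before any remaining standard pools.

    # Sketch only; reuses the assumed `database` object from the async example.
    from dao_ai.memory.postgres import PostgresPoolManager

    pool = PostgresPoolManager.get_pool(database)
    with pool.connection() as conn:
        conn.execute("SELECT 1")
    PostgresPoolManager.close_all_pools()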
@@ -348,7 +437,13 @@
         connection_key: str = database.name
 
         with cls._lock:
-            if connection_key in cls._pools:
+            # Close LakebasePool if it exists (handles underlying pool cleanup)
+            if connection_key in cls._lakebase_pools:
+                lakebase_pool = cls._lakebase_pools.pop(connection_key)
+                lakebase_pool.close()
+                cls._pools.pop(connection_key, None)
+                logger.debug("Lakebase pool closed", database=database.name)
+            elif connection_key in cls._pools:
                 pool = cls._pools.pop(connection_key)
                 pool.close()
                 logger.debug("PostgreSQL pool closed", database=database.name)
@@ -356,16 +451,32 @@
     @classmethod
     def close_all_pools(cls):
         with cls._lock:
-            for connection_key, pool in cls._pools.items():
+            # Close all LakebasePool instances first
+            for connection_key, lakebase_pool in cls._lakebase_pools.items():
                 try:
-                    pool.close()
-                    logger.debug("PostgreSQL pool closed", pool=connection_key)
+                    lakebase_pool.close()
+                    logger.debug("Lakebase pool closed", pool=connection_key)
                 except Exception as e:
                     logger.error(
-                        "Error closing PostgreSQL pool",
+                        "Error closing Lakebase pool",
                         pool=connection_key,
                         error=str(e),
                     )
+            cls._lakebase_pools.clear()
+
+            # Close any remaining standard PostgreSQL pools
+            for connection_key, pool in cls._pools.items():
+                # Skip if already closed via LakebasePool
+                if connection_key not in cls._lakebase_pools:
+                    try:
+                        pool.close()
+                        logger.debug("PostgreSQL pool closed", pool=connection_key)
+                    except Exception as e:
+                        logger.error(
+                            "Error closing PostgreSQL pool",
+                            pool=connection_key,
+                            error=str(e),
+                        )
             cls._pools.clear()
 
 
@@ -9,7 +9,7 @@ This module provides the foundational utilities for multi-agent orchestration:
 - Main orchestration graph factory
 """
 
-from typing import Awaitable, Callable, Literal
+from typing import Any, Awaitable, Callable, Literal
 from langchain.tools import ToolRuntime, tool
 from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
@@ -179,8 +179,16 @@ def create_agent_node_handler(
             "messages": filtered_messages,
         }
 
-        # Invoke the agent
-        result: AgentState = await agent.ainvoke(agent_state, context=runtime.context)
+        # Build config with configurable from context for langmem compatibility
+        # langmem tools expect user_id to be in config.configurable
+        config: dict[str, Any] = {}
+        if runtime.context:
+            config = {"configurable": runtime.context.model_dump()}
+
+        # Invoke the agent with both context and config
+        result: AgentState = await agent.ainvoke(
+            agent_state, context=runtime.context, config=config
+        )
 
         # Extract agent response based on output mode
        result_messages = result.get("messages", [])
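
To illustrate the intent of this change outside the diff: the runtime context is mirrored into config["configurable"] so that langmem-style tools that read user_id from the config can still find it. The Context fields below are illustrative stand-ins, not taken from dao_ai.state.

    from pydantic import BaseModel

    class Context(BaseModel):  # stand-in; the real fields live in dao_ai.state.Context
        user_id: str
        thread_id: str | None = None

    context = Context(user_id="u-123", thread_id="t-1")
    config = {"configurable": context.model_dump()}
    # -> {"configurable": {"user_id": "u-123", "thread_id": "t-1"}}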
@@ -227,15 +235,31 @@ def create_handoff_tool(
         tool_call_id: str = runtime.tool_call_id
         logger.debug("Handoff to agent", target_agent=target_agent_name)
 
+        # Get the AIMessage that triggered this handoff (required for tool_use/tool_result pairing)
+        # LLMs expect tool calls to be paired with their responses, so we must include both
+        # the AIMessage containing the tool call and the ToolMessage acknowledging it.
+        messages: list[BaseMessage] = runtime.state.get("messages", [])
+        last_ai_message: AIMessage | None = None
+        for msg in reversed(messages):
+            if isinstance(msg, AIMessage) and msg.tool_calls:
+                last_ai_message = msg
+                break
+
+        # Build message list with proper pairing
+        update_messages: list[BaseMessage] = []
+        if last_ai_message:
+            update_messages.append(last_ai_message)
+        update_messages.append(
+            ToolMessage(
+                content=f"Transferred to {target_agent_name}",
+                tool_call_id=tool_call_id,
+            )
+        )
+
         return Command(
             update={
                 "active_agent": target_agent_name,
-                "messages": [
-                    ToolMessage(
-                        content=f"Transferred to {target_agent_name}",
-                        tool_call_id=tool_call_id,
-                    )
-                ],
+                "messages": update_messages,
             },
             goto=target_agent_name,
             graph=Command.PARENT,
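
The pairing requirement this change enforces, shown as a standalone sketch: chat LLM APIs reject a tool result whose originating tool call is not in the visible history, so the state update now carries the triggering AIMessage together with its ToolMessage. The tool name and ids below are illustrative.

    from langchain_core.messages import AIMessage, ToolMessage

    ai = AIMessage(
        content="",
        tool_calls=[{"name": "transfer_to_product_agent", "args": {}, "id": "call_1"}],
    )
    update_messages = [
        ai,
        ToolMessage(content="Transferred to product_agent", tool_call_id="call_1"),
    ]
    # tool_call_id matches the id inside ai.tool_calls, so the pair forms valid history.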
@@ -13,7 +13,7 @@ from langchain.agents import create_agent
 from langchain.agents.middleware import AgentMiddleware as LangchainAgentMiddleware
 from langchain.tools import ToolRuntime, tool
 from langchain_core.language_models import LanguageModelLike
-from langchain_core.messages import ToolMessage
+from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
 from langchain_core.tools import BaseTool
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.graph import StateGraph
@@ -75,15 +75,30 @@ def _create_handoff_back_to_supervisor_tool() -> BaseTool:
         tool_call_id: str = runtime.tool_call_id
         logger.debug("Agent handing back to supervisor", summary_preview=summary[:100])
 
+        # Get the AIMessage that triggered this handoff (required for tool_use/tool_result pairing)
+        # LLMs expect tool calls to be paired with their responses, so we must include both
+        # the AIMessage containing the tool call and the ToolMessage acknowledging it.
+        messages: list[BaseMessage] = runtime.state.get("messages", [])
+        last_ai_message: AIMessage | None = None
+        for msg in reversed(messages):
+            if isinstance(msg, AIMessage) and msg.tool_calls:
+                last_ai_message = msg
+                break
+
+        # Build message list with proper pairing
+        update_messages: list[BaseMessage] = []
+        if last_ai_message:
+            update_messages.append(last_ai_message)
+        update_messages.append(
+            ToolMessage(
+                content=f"Task completed: {summary}",
+                tool_call_id=tool_call_id,
+            )
+        )
+
         return Command(
             update={
-                "active_agent": None,
-                "messages": [
-                    ToolMessage(
-                        content=f"Task completed: {summary}",
-                        tool_call_id=tool_call_id,
-                    )
-                ],
+                "messages": update_messages,
             },
             goto=SUPERVISOR_NODE,
             graph=Command.PARENT,
@@ -2,9 +2,11 @@
 Prompt utilities for DAO AI agents.
 
 This module provides utilities for creating dynamic prompts using
-LangChain v1's @dynamic_prompt middleware decorator pattern.
+LangChain v1's @dynamic_prompt middleware decorator pattern, as well as
+paths to prompt template files.
 """
 
+from pathlib import Path
 from typing import Any, Optional
 
 from langchain.agents.middleware import (
@@ -18,6 +20,13 @@ from loguru import logger
 from dao_ai.config import PromptModel
 from dao_ai.state import Context
 
+PROMPTS_DIR = Path(__file__).parent
+
+
+def get_prompt_path(name: str) -> Path:
+    """Get the path to a prompt template file."""
+    return PROMPTS_DIR / name
+
 
 
 def make_prompt(
@@ -0,0 +1,58 @@
+name: instructed_retriever_decomposition
+description: Decomposes user queries into multiple search queries with metadata filters
+
+template: |
+  You are a search query decomposition expert. Your task is to break down a user query into one or more focused search queries with appropriate metadata filters. Respond with a JSON object.
+
+  ## Current Time
+  {current_time}
+
+  ## Database Schema
+  {schema_description}
+
+  ## Constraints
+  {constraints}
+
+  ## Few-Shot Examples
+  {examples}
+
+  ## Instructions
+  1. Analyze the user query and identify distinct search intents
+  2. For each intent, create a focused search query text
+  3. Extract metadata filters from the query using the exact filter syntax above
+  4. Resolve relative time references (e.g., "last month", "past year") using the current time
+  5. Generate at most {max_subqueries} search queries
+  6. If no filters apply, set filters to null
+
+  ## User Query
+  {query}
+
+  Generate search queries that together capture all aspects of the user's information need.
+
+variables:
+  - current_time
+  - schema_description
+  - constraints
+  - examples
+  - max_subqueries
+  - query
+
+output_format: |
+  The output must be a JSON object with a "queries" field containing an array of search query objects.
+  Each search query object has:
+  - "text": The search query string
+  - "filters": An array of filter objects, each with "key" (column + optional operator) and "value", or null if no filters
+
+  Supported filter operators (append to column name):
+  - Equality: {"key": "column", "value": "val"} or {"key": "column", "value": ["val1", "val2"]}
+  - Exclusion: {"key": "column NOT", "value": "val"}
+  - Comparison: {"key": "column <", "value": 100}, also <=, >, >=
+  - Token match: {"key": "column LIKE", "value": "word"}
+  - Exclude token: {"key": "column NOT LIKE", "value": "word"}
+
+  Examples:
+  - [{"key": "brand_name", "value": "MILWAUKEE"}]
+  - [{"key": "price <", "value": 100}]
+  - [{"key": "brand_name NOT", "value": "DEWALT"}]
+  - [{"key": "brand_name", "value": ["MILWAUKEE", "DEWALT"]}]
+  - [{"key": "description LIKE", "value": "cordless"}]
1
+ name: instruction_aware_reranking
2
+ version: "1.1"
3
+ description: Rerank documents based on user instructions and constraints
4
+
5
+ template: |
6
+ Rerank these search results for the query "{query}".
7
+
8
+ {instructions}
9
+
10
+ ## Documents
11
+
12
+ {documents}
13
+
14
+ Score each document 0.0-1.0 based on relevance to the query and instructions. Return results sorted by score (highest first). Only include documents scoring > 0.1.
@@ -0,0 +1,37 @@
+name: router_query_classification
+version: "1.0"
+description: Classify query to determine execution mode (standard vs instructed)
+
+template: |
+  You are a query classification system. Your task is to determine the best execution mode for a search query.
+
+  ## Execution Modes
+
+  **standard**: Use for simple keyword or product searches without specific constraints.
+  - General questions about products
+  - Simple keyword searches
+  - Broad category browsing
+
+  **instructed**: Use for queries with explicit constraints that require metadata filtering.
+  - Price constraints ("under $100", "between $50 and $200")
+  - Brand preferences ("Milwaukee", "not DeWalt", "excluding Makita")
+  - Category filters ("power tools", "paint supplies")
+  - Time/recency constraints ("recent", "from last month", "updated this year")
+  - Comparison queries ("compare X and Y")
+  - Multiple combined constraints
+
+  ## Available Schema for Filtering
+
+  {schema_description}
+
+  ## Query to Classify
+
+  "{query}"
+
+  ## Instructions
+
+  Analyze the query and determine:
+  1. Does it contain explicit constraints that can be translated to metadata filters?
+  2. Would the query benefit from being decomposed into subqueries?
+
+  Return your classification as a JSON object with a single field "mode" set to either "standard" or "instructed".
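
Worked classifications under these rules, with invented queries:

    # "what paint rollers do you carry"      -> {"mode": "standard"}
    # "Milwaukee impact driver under $150"   -> {"mode": "instructed"}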
@@ -0,0 +1,46 @@
+name: result_verification
+version: "1.0"
+description: Verify search results satisfy user constraints
+
+template: |
+  You are a result verification system. Your task is to determine whether search results satisfy the user's query constraints.
+
+  ## User Query
+
+  "{query}"
+
+  ## Schema Information
+
+  {schema_description}
+
+  ## Constraints to Verify
+
+  {constraints}
+
+  ## Retrieved Results (Top {num_results})
+
+  {results_summary}
+
+  ## Previous Attempt Feedback (if retry)
+
+  {previous_feedback}
+
+  ## Instructions
+
+  Analyze whether the results satisfy the user's explicit and implicit constraints:
+
+  1. **Intent Match**: Do the results address what the user is looking for?
+  2. **Explicit Constraints**: Are price, brand, category, date constraints met?
+  3. **Relevance**: Are the results actually useful for the user's needs?
+
+  If results do NOT satisfy constraints, suggest specific filter relaxations:
+  - Use "REMOVE" to drop a filter entirely
+  - Use "BROADEN" to widen a range (e.g., price < 100 -> price < 150)
+  - Use specific values to change a filter
+
+  Return a JSON object with:
+  - passed: boolean (true if results are satisfactory)
+  - confidence: float (0.0-1.0, your confidence in the assessment)
+  - feedback: string (brief explanation of issues, if any)
+  - suggested_filter_relaxation: object (filter changes for retry, e.g., {{"brand_name": "REMOVE"}})
+  - unmet_constraints: array of strings (list of constraints not satisfied)
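
A payload shaped per this schema might look like the following (illustrative values only):

    verification = {
        "passed": False,
        "confidence": 0.8,
        "feedback": "Results meet the price cap but none of the top hits are cordless.",
        "suggested_filter_relaxation": {"brand_name": "REMOVE"},
        "unmet_constraints": ["cordless only"],
    }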
@@ -397,6 +397,8 @@ class DatabricksProvider(ServiceProvider):
 
         pip_requirements += get_installed_packages()
 
+        code_paths = list(dict.fromkeys(code_paths))
+
         logger.trace("Pip requirements prepared", count=len(pip_requirements))
         logger.trace("Code paths prepared", count=len(code_paths))
 
@@ -434,19 +436,38 @@
             pip_packages_count=len(pip_requirements),
         )
 
-        with mlflow.start_run(run_name=run_name):
-            mlflow.set_tag("type", "agent")
-            mlflow.set_tag("dao_ai", dao_ai_version())
-            logged_agent_info: ModelInfo = mlflow.pyfunc.log_model(
-                python_model=model_path.as_posix(),
-                code_paths=code_paths,
-                model_config=config.model_dump(mode="json", by_alias=True),
-                name="agent",
-                conda_env=conda_env,
-                input_example=input_example,
-                # resources=all_resources,
-                auth_policy=auth_policy,
+        # End any stale runs before starting to ensure clean state on retry
+        if mlflow.active_run():
+            logger.warning(
+                "Ending stale MLflow run before creating new agent",
+                run_id=mlflow.active_run().info.run_id,
+            )
+            mlflow.end_run()
+
+        try:
+            with mlflow.start_run(run_name=run_name):
+                mlflow.set_tag("type", "agent")
+                mlflow.set_tag("dao_ai", dao_ai_version())
+                logged_agent_info: ModelInfo = mlflow.pyfunc.log_model(
+                    python_model=model_path.as_posix(),
+                    code_paths=code_paths,
+                    model_config=config.model_dump(mode="json", by_alias=True),
+                    name="agent",
+                    conda_env=conda_env,
+                    input_example=input_example,
+                    # resources=all_resources,
+                    auth_policy=auth_policy,
+                )
+        except Exception as e:
+            # Ensure run is ended on failure to prevent stale state on retry
+            if mlflow.active_run():
+                mlflow.end_run(status="FAILED")
+            logger.error(
+                "Failed to log model",
+                run_name=run_name,
+                error=str(e),
             )
+            raise
 
         registered_model_name: str = config.app.registered_model.full_name