dao-ai 0.0.17__py3-none-any.whl → 0.0.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/agent_as_code.py CHANGED
@@ -3,7 +3,7 @@ import sys
 import mlflow
 from loguru import logger
 from mlflow.models import ModelConfig
-from mlflow.pyfunc import ChatModel
+from mlflow.pyfunc import ResponsesAgent
 
 from dao_ai.config import AppConfig
 
@@ -17,6 +17,6 @@ log_level: str = config.app.log_level
 logger.remove()
 logger.add(sys.stderr, level=log_level)
 
-app: ChatModel = config.as_chat_model()
+app: ResponsesAgent = config.as_responses_agent()
 
 mlflow.models.set_model(app)
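
Note: with this change the logged model speaks the MLflow Responses API rather than the chat-completions API. A minimal sketch of calling the deployed agent (the model URI is hypothetical; the payload shape follows mlflow.types.responses):

    import mlflow

    # Hypothetical model URI for a logged dao-ai agent.
    model = mlflow.pyfunc.load_model("models:/dao_ai_agent/1")
    response = model.predict({"input": [{"role": "user", "content": "Hello"}]})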
dao_ai/config.py CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path
 from typing import (
     Any,
     Callable,
+    Iterator,
     Literal,
     Optional,
     Sequence,
@@ -21,6 +22,7 @@ from databricks.sdk.credentials_provider import (
     CredentialsStrategy,
     ModelServingUserCredentials,
 )
+from databricks.sdk.service.catalog import FunctionInfo, TableInfo
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
@@ -44,8 +46,14 @@ from mlflow.models.resources import (
     DatabricksUCConnection,
     DatabricksVectorSearchIndex,
 )
-from mlflow.pyfunc import ChatModel
-from pydantic import BaseModel, ConfigDict, Field, field_serializer, model_validator
+from mlflow.pyfunc import ChatModel, ResponsesAgent
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    field_serializer,
+    model_validator,
+)
 
 
 class HasValue(ABC):
@@ -69,7 +77,7 @@ class IsDatabricksResource(ABC):
     on_behalf_of_user: Optional[bool] = False
 
     @abstractmethod
-    def as_resource(self) -> DatabricksResource: ...
+    def as_resources(self) -> Sequence[DatabricksResource]: ...
 
     @property
     @abstractmethod
@@ -235,22 +243,68 @@ class SchemaModel(BaseModel, HasFullName):
 class TableModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    name: str
+    name: Optional[str] = None
+
+    @model_validator(mode="after")
+    def validate_name_or_schema_required(self) -> "TableModel":
+        if not self.name and not self.schema_model:
+            raise ValueError(
+                "Either 'name' or 'schema_model' must be provided for TableModel"
+            )
+        return self
 
     @property
     def full_name(self) -> str:
         if self.schema_model:
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
+            name: str = ""
+            if self.name:
+                name = f".{self.name}"
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
        return self.name
 
     @property
     def api_scopes(self) -> Sequence[str]:
         return []
 
-    def as_resource(self) -> DatabricksResource:
-        return DatabricksTable(
-            table_name=self.full_name, on_behalf_of_user=self.on_behalf_of_user
-        )
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        resources: list[DatabricksResource] = []
+
+        excluded_suffixes: Sequence[str] = [
+            "_payload",
+            "_assessment_logs",
+            "_request_logs",
+        ]
+
+        excluded_prefixes: Sequence[str] = [
+            "trace_logs_"
+        ]
+
+        if self.name:
+            resources.append(
+                DatabricksTable(
+                    table_name=self.full_name, on_behalf_of_user=self.on_behalf_of_user
+                )
+            )
+        else:
+            w: WorkspaceClient = self.workspace_client
+            schema_full_name: str = self.schema_model.full_name
+            tables: Iterator[TableInfo] = w.tables.list(
+                catalog_name=self.schema_model.catalog_name,
+                schema_name=self.schema_model.schema_name,
+            )
+            resources.extend(
+                [
+                    DatabricksTable(
+                        table_name=f"{schema_full_name}.{table.name}",
+                        on_behalf_of_user=self.on_behalf_of_user,
+                    )
+                    for table in tables
+                    if not any(table.name.endswith(suffix) for suffix in excluded_suffixes)
+                    and not any(table.name.startswith(prefix) for prefix in excluded_prefixes)
+                ]
+            )
+
+        return resources
 
 
 class LLMModel(BaseModel, IsDatabricksResource):
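
Note: a TableModel may now reference an entire schema instead of a single table. A brief sketch of the new form (catalog and schema names are illustrative, attribute names are taken from how as_resources reads them, and workspace_client must resolve to a live Databricks workspace):

    from dao_ai.config import SchemaModel, TableModel

    # Expands to one DatabricksTable resource per table in main.retail, skipping
    # agent logging tables (*_payload, *_assessment_logs, *_request_logs, trace_logs_*).
    table = TableModel(schema=SchemaModel(catalog_name="main", schema_name="retail"))
    resources = table.as_resources()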
@@ -266,10 +320,12 @@ class LLMModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ]
 
-    def as_resource(self) -> DatabricksResource:
-        return DatabricksServingEndpoint(
-            endpoint_name=self.name, on_behalf_of_user=self.on_behalf_of_user
-        )
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksServingEndpoint(
+                endpoint_name=self.name, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]
 
     def as_chat_model(self) -> LanguageModelLike:
         # Retrieve langchain chat client from workspace client to enable OBO
@@ -345,10 +401,12 @@ class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
             return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
         return self.name
 
-    def as_resource(self) -> DatabricksResource:
-        return DatabricksVectorSearchIndex(
-            index_name=self.full_name, on_behalf_of_user=self.on_behalf_of_user
-        )
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksVectorSearchIndex(
+                index_name=self.full_name, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]
 
 
 class GenieRoomModel(BaseModel, IsDatabricksResource):
@@ -363,10 +421,12 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
         "dashboards.genie",
     ]
 
-    def as_resource(self) -> DatabricksResource:
-        return DatabricksGenieSpace(
-            genie_space_id=self.space_id, on_behalf_of_user=self.on_behalf_of_user
-        )
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksGenieSpace(
+                genie_space_id=self.space_id, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]
 
 
 class VolumeModel(BaseModel, HasFullName):
@@ -394,7 +454,7 @@ class VolumePathModel(BaseModel, HasFullName):
     path: Optional[str] = None
 
     @model_validator(mode="after")
-    def validate_path_or_volume(self):
+    def validate_path_or_volume(self) -> "VolumePathModel":
         if not self.volume and not self.path:
             raise ValueError("Either 'volume' or 'path' must be provided")
         return self
@@ -502,8 +562,8 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ] + self.index.api_scopes
 
-    def as_resource(self) -> DatabricksResource:
-        return self.index.as_resource()
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return self.index.as_resources()
 
     def as_index(self, vsc: VectorSearchClient | None = None) -> VectorSearchIndex:
         from dao_ai.providers.base import ServiceProvider
@@ -524,18 +584,52 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
 class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict()
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    name: str
+    name: Optional[str] = None
+
+    @model_validator(mode="after")
+    def validate_name_or_schema_required(self) -> "FunctionModel":
+        if not self.name and not self.schema_model:
+            raise ValueError(
+                "Either 'name' or 'schema_model' must be provided for FunctionModel"
+            )
+        return self
 
     @property
     def full_name(self) -> str:
         if self.schema_model:
-            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
+            name: str = ""
+            if self.name:
+                name = f".{self.name}"
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
         return self.name
 
-    def as_resource(self) -> DatabricksResource:
-        return DatabricksFunction(
-            function_name=self.full_name, on_behalf_of_user=self.on_behalf_of_user
-        )
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        resources: list[DatabricksResource] = []
+        if self.name:
+            resources.append(
+                DatabricksFunction(
+                    function_name=self.full_name,
+                    on_behalf_of_user=self.on_behalf_of_user,
+                )
+            )
+        else:
+            w: WorkspaceClient = self.workspace_client
+            schema_full_name: str = self.schema_model.full_name
+            functions: Iterator[FunctionInfo] = w.functions.list(
+                catalog_name=self.schema_model.catalog_name,
+                schema_name=self.schema_model.schema_name,
+            )
+            resources.extend(
+                [
+                    DatabricksFunction(
+                        function_name=f"{schema_full_name}.{function.name}",
+                        on_behalf_of_user=self.on_behalf_of_user,
+                    )
+                    for function in functions
+                ]
+            )
+
+        return resources
 
     @property
     def api_scopes(self) -> Sequence[str]:
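
Note: FunctionModel gets the same schema-only expansion, and omitting both fields is now rejected by the validator. A brief sketch (names illustrative):

    from dao_ai.config import FunctionModel, SchemaModel

    fn = FunctionModel(schema=SchemaModel(catalog_name="main", schema_name="retail"))
    resources = fn.as_resources()  # one DatabricksFunction per function in the schema

    FunctionModel()  # raises ValueError: 'name' or 'schema_model' must be provided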
@@ -556,10 +650,12 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
         "catalog.connections",
     ]
 
-    def as_resource(self) -> DatabricksResource:
-        return DatabricksUCConnection(
-            connection_name=self.name, on_behalf_of_user=self.on_behalf_of_user
-        )
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksUCConnection(
+                connection_name=self.name, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]
 
 
 class WarehouseModel(BaseModel, IsDatabricksResource):
@@ -575,10 +671,12 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
         "sql.statement-execution",
     ]
 
-    def as_resource(self) -> DatabricksResource:
-        return DatabricksSQLWarehouse(
-            warehouse_id=self.warehouse_id, on_behalf_of_user=self.on_behalf_of_user
-        )
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksSQLWarehouse(
+                warehouse_id=self.warehouse_id, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]
 
 
 class DatabaseModel(BaseModel):
@@ -1034,9 +1132,37 @@ class Message(BaseModel):
 
 class ChatPayload(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    messages: list[Message]
+    input: Optional[list[Message]] = None
+    messages: Optional[list[Message]] = None
     custom_inputs: dict
 
+    @model_validator(mode="after")
+    def validate_mutual_exclusion_and_alias(self) -> "ChatPayload":
+        """Handle dual field support with automatic aliasing."""
+        # If both fields are provided and they're the same, that's okay (redundant but valid)
+        if self.input is not None and self.messages is not None:
+            # Allow if they're identical (redundant specification)
+            if self.input == self.messages:
+                return self
+            # If they're different, prefer input and copy to messages
+            else:
+                self.messages = self.input
+                return self
+
+        # If neither field is provided, that's an error
+        if self.input is None and self.messages is None:
+            raise ValueError("Must specify either 'input' or 'messages' field.")
+
+        # Create alias: copy messages to input if input is None
+        if self.input is None and self.messages is not None:
+            self.input = self.messages
+
+        # Create alias: copy input to messages if messages is None
+        elif self.messages is None and self.input is not None:
+            self.messages = self.input
+
+        return self
+
 
 class ChatHistoryModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
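
Note: the payload now accepts either the Responses-style input key or the legacy messages key and mirrors one into the other. A sketch of the behavior (Message fields assumed to be role/content):

    from dao_ai.config import ChatPayload, Message

    p = ChatPayload(messages=[Message(role="user", content="hi")], custom_inputs={})
    assert p.input == p.messages  # 'messages' mirrored into 'input'

    p = ChatPayload(input=[Message(role="user", content="hi")], custom_inputs={})
    assert p.messages == p.input  # 'input' mirrored into 'messages'

    ChatPayload(custom_inputs={})  # raises ValueError: neither field provided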
@@ -1374,3 +1500,10 @@ class AppConfig(BaseModel):
         graph: CompiledStateGraph = self.as_graph()
         app: ChatModel = create_agent(graph)
         return app
+
+    def as_responses_agent(self) -> ResponsesAgent:
+        from dao_ai.models import create_responses_agent
+
+        graph: CompiledStateGraph = self.as_graph()
+        app: ResponsesAgent = create_responses_agent(graph)
+        return app
dao_ai/messages.py CHANGED
@@ -11,6 +11,9 @@ from langchain_core.messages import (
 )
 from langchain_core.messages.modifier import RemoveMessage
 from mlflow.types.llm import ChatMessage
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+)
 
 
 def remove_messages(
@@ -96,6 +99,10 @@ def has_mlflow_messages(messages: ChatMessage | Sequence[ChatMessage]) -> bool:
     return any(isinstance(m, ChatMessage) for m in messages)
 
 
+def has_mlflow_responses_messages(messages: ResponsesAgentRequest) -> bool:
+    return isinstance(messages, ResponsesAgentRequest)
+
+
 def has_image(messages: BaseMessage | Sequence[BaseMessage]) -> bool:
     """
     Check if a message contains an image.
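
Note: the new predicate is a plain isinstance check, used for dispatch in dao_ai.models:

    from mlflow.types.responses import ResponsesAgentRequest

    from dao_ai.messages import has_mlflow_responses_messages

    request = ResponsesAgentRequest(input=[{"role": "user", "content": "hi"}])
    has_mlflow_responses_messages(request)  # True
    has_mlflow_responses_messages([])       # False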
dao_ai/models.py CHANGED
@@ -1,13 +1,14 @@
 import uuid
 from os import PathLike
 from pathlib import Path
-from typing import Any, Generator, Optional, Sequence
+from typing import Any, Generator, Optional, Sequence, Union
 
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langgraph.graph.state import CompiledStateGraph
 from loguru import logger
 from mlflow import MlflowClient
-from mlflow.pyfunc import ChatAgent, ChatModel
+from mlflow.pyfunc import ChatAgent, ChatModel, ResponsesAgent
+from mlflow.types.agent import ChatContext
 from mlflow.types.llm import (
     ChatChoice,
     ChatChoiceDelta,
@@ -17,8 +18,21 @@ from mlflow.types.llm import (
     ChatMessage,
     ChatParams,
 )
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+    ResponsesAgentResponse,
+    ResponsesAgentStreamEvent,
+)
+from mlflow.types.responses_helpers import (
+    Message,
+    ResponseInputTextParam,
+)
 
-from dao_ai.messages import has_langchain_messages, has_mlflow_messages
+from dao_ai.messages import (
+    has_langchain_messages,
+    has_mlflow_messages,
+    has_mlflow_responses_messages,
+)
 from dao_ai.state import Context
 
 
@@ -185,6 +199,232 @@ class LanggraphChatModel(ChatModel):
         return [m.to_dict() for m in messages]
 
 
+class LanggraphResponsesAgent(ResponsesAgent):
+    """
+    ResponsesAgent that delegates requests to a LangGraph CompiledStateGraph.
+
+    This is the modern replacement for LanggraphChatModel, providing better
+    support for streaming, tool calling, and async execution.
+    """
+
+    def __init__(self, graph: CompiledStateGraph) -> None:
+        self.graph = graph
+
+    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
+        """
+        Process a ResponsesAgentRequest and return a ResponsesAgentResponse.
+        """
+        logger.debug(f"ResponsesAgent request: {request}")
+
+        # Convert ResponsesAgent input to LangChain messages
+        messages = self._convert_request_to_langchain_messages(request)
+
+        # Prepare context
+        context: Context = self._convert_request_to_context(request)
+        custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
+
+        # Use async ainvoke internally for parallel execution
+        import asyncio
+
+        async def _async_invoke():
+            return await self.graph.ainvoke(
+                {"messages": messages}, context=context, config=custom_inputs
+            )
+
+        loop = asyncio.get_event_loop()
+        response: dict[str, Sequence[BaseMessage]] = loop.run_until_complete(
+            _async_invoke()
+        )
+
+        # Convert response to ResponsesAgent format
+        last_message: BaseMessage = response["messages"][-1]
+
+        output_item = self.create_text_output_item(
+            text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
+        )
+
+        return ResponsesAgentResponse(
+            output=[output_item], custom_outputs=request.custom_inputs
+        )
+
+    def predict_stream(
+        self, request: ResponsesAgentRequest
+    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
+        """
+        Process a ResponsesAgentRequest and yield ResponsesAgentStreamEvent objects.
+        """
+        logger.debug(f"ResponsesAgent stream request: {request}")
+
+        # Convert ResponsesAgent input to LangChain messages
+        messages: list[BaseMessage] = self._convert_request_to_langchain_messages(
+            request
+        )
+
+        # Prepare context
+        context: Context = self._convert_request_to_context(request)
+        custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
+
+        # Use async astream internally for parallel execution
+        import asyncio
+
+        async def _async_stream():
+            item_id = f"msg_{uuid.uuid4().hex[:8]}"
+            accumulated_content = ""
+
+            async for nodes, stream_mode, messages_batch in self.graph.astream(
+                {"messages": messages},
+                context=context,
+                config=custom_inputs,
+                stream_mode=["messages", "custom"],
+                subgraphs=True,
+            ):
+                nodes: tuple[str, ...]
+                stream_mode: str
+                messages_batch: Sequence[BaseMessage]
+
+                for message in messages_batch:
+                    if (
+                        isinstance(
+                            message,
+                            (
+                                AIMessageChunk,
+                                AIMessage,
+                            ),
+                        )
+                        and message.content
+                        and "summarization" not in nodes
+                    ):
+                        content = message.content
+                        accumulated_content += content
+
+                        # Yield streaming delta
+                        yield ResponsesAgentStreamEvent(
+                            **self.create_text_delta(delta=content, item_id=item_id)
+                        )
+
+            # Yield final output item
+            yield ResponsesAgentStreamEvent(
+                type="response.output_item.done",
+                item=self.create_text_output_item(text=accumulated_content, id=item_id),
+                custom_outputs=request.custom_inputs,
+            )
+
+        # Convert async generator to sync generator
+        loop = asyncio.get_event_loop()
+        async_gen = _async_stream()
+
+        try:
+            while True:
+                try:
+                    item = loop.run_until_complete(async_gen.__anext__())
+                    yield item
+                except StopAsyncIteration:
+                    break
+        finally:
+            loop.run_until_complete(async_gen.aclose())
+
+    def _extract_text_from_content(
+        self,
+        content: Union[str, list[Union[ResponseInputTextParam, str, dict[str, Any]]]],
+    ) -> str:
+        """Extract text content from various MLflow content formats.
+
+        MLflow ResponsesAgent supports multiple content formats:
+        - str: Simple text content
+        - list[ResponseInputTextParam]: Structured text objects with .text attribute
+        - list[dict]: Dictionaries with "text" key
+        - Mixed lists of the above types
+
+        This method normalizes all formats to a single concatenated string.
+
+        Args:
+            content: The content to extract text from
+
+        Returns:
+            Concatenated text string from all content items
+        """
+        if isinstance(content, str):
+            return content
+        elif isinstance(content, list):
+            text_parts = []
+            for content_item in content:
+                if isinstance(content_item, ResponseInputTextParam):
+                    text_parts.append(content_item.text)
+                elif isinstance(content_item, str):
+                    text_parts.append(content_item)
+                elif isinstance(content_item, dict) and "text" in content_item:
+                    text_parts.append(content_item["text"])
+            return "".join(text_parts)
+        else:
+            # Fallback for unknown types - try to extract text attribute
+            return getattr(content, "text", str(content))
+
+    def _convert_request_to_langchain_messages(
+        self, request: ResponsesAgentRequest
+    ) -> list[dict[str, Any]]:
+        """Convert ResponsesAgent input to LangChain message format."""
+        messages = []
+
+        for input_item in request.input:
+            if isinstance(input_item, Message):
+                # Handle MLflow Message objects
+                content = self._extract_text_from_content(input_item.content)
+                messages.append({"role": input_item.role, "content": content})
+            elif isinstance(input_item, dict):
+                # Handle dict format
+                if "role" in input_item and "content" in input_item:
+                    content = self._extract_text_from_content(input_item["content"])
+                    messages.append({"role": input_item["role"], "content": content})
+            else:
+                # Fallback for other object types with role/content attributes
+                role = getattr(input_item, "role", "user")
+                content = self._extract_text_from_content(
+                    getattr(input_item, "content", "")
+                )
+                messages.append({"role": role, "content": content})
+
+        return messages
+
+    def _convert_request_to_context(self, request: ResponsesAgentRequest) -> Context:
+        """Convert ResponsesAgent context to internal Context."""
+
+        logger.debug(f"request.context: {request.context}")
+        logger.debug(f"request.custom_inputs: {request.custom_inputs}")
+
+        configurable: dict[str, Any] = {}
+
+        # Process context values first (lower priority)
+        # Use strong typing with forward-declared type hints instead of hasattr checks
+        chat_context: Optional[ChatContext] = request.context
+        if chat_context is not None:
+            conversation_id: Optional[str] = chat_context.conversation_id
+            user_id: Optional[str] = chat_context.user_id
+
+            if conversation_id is not None:
+                configurable["conversation_id"] = conversation_id
+                configurable["thread_id"] = conversation_id
+
+            if user_id is not None:
+                configurable["user_id"] = user_id
+
+        # Process custom_inputs after context so they can override context values (higher priority)
+        if request.custom_inputs:
+            if "configurable" in request.custom_inputs:
+                configurable.update(request.custom_inputs.pop("configurable"))
+
+            configurable.update(request.custom_inputs)
+
+        if "user_id" in configurable:
+            configurable["user_id"] = configurable["user_id"].replace(".", "_")
+
+        if "thread_id" not in configurable:
+            configurable["thread_id"] = str(uuid.uuid4())
+
+        logger.debug(f"Creating context from: {configurable}")
+
+        return Context(**configurable)
+
+
 def create_agent(graph: CompiledStateGraph) -> ChatAgent:
     """
     Create an MLflow-compatible ChatAgent from a LangGraph state machine.
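
Note: a sketch of driving the new agent directly, assuming a compiled graph is already in hand (the real code obtains one via AppConfig.as_graph()):

    from mlflow.types.responses import ResponsesAgentRequest

    agent = LanggraphResponsesAgent(graph)  # graph: a CompiledStateGraph (assumed)
    request = ResponsesAgentRequest(input=[{"role": "user", "content": "Hello"}])

    for event in agent.predict_stream(request):
        print(event)  # text deltas, then one final response.output_item.done event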
@@ -201,6 +441,22 @@ def create_agent(graph: CompiledStateGraph) -> ChatAgent:
     return LanggraphChatModel(graph)
 
 
+def create_responses_agent(graph: CompiledStateGraph) -> ResponsesAgent:
+    """
+    Create an MLflow-compatible ResponsesAgent from a LangGraph state machine.
+
+    Factory function that wraps a compiled LangGraph in the LanggraphResponsesAgent
+    class to make it deployable through MLflow.
+
+    Args:
+        graph: A compiled LangGraph state machine
+
+    Returns:
+        An MLflow-compatible ResponsesAgent instance
+    """
+    return LanggraphResponsesAgent(graph)
+
+
 def _process_langchain_messages(
     app: LanggraphChatModel | CompiledStateGraph,
     messages: Sequence[BaseMessage],
@@ -288,6 +544,14 @@ def _process_mlflow_messages(
     return app.predict(None, messages, custom_inputs)
 
 
+def _process_mlflow_response_messages(
+    app: ResponsesAgent,
+    messages: ResponsesAgentRequest,
+) -> ResponsesAgentResponse:
+    """Process MLflow ResponsesAgent request in batch mode."""
+    return app.predict(messages)
+
+
 def _process_mlflow_messages_stream(
     app: ChatModel,
     messages: Sequence[ChatMessage],
@@ -298,32 +562,68 @@
         yield event
 
 
+def _process_mlflow_response_messages_stream(
+    app: ResponsesAgent,
+    messages: ResponsesAgentRequest,
+) -> Generator[ResponsesAgentStreamEvent, None, None]:
+    """Process MLflow ResponsesAgent request in streaming mode."""
+    for event in app.predict_stream(messages):
+        event: ResponsesAgentStreamEvent
+        yield event
+
+
 def _process_config_messages(
-    app: ChatModel,
+    app: LanggraphChatModel | LanggraphResponsesAgent,
     messages: dict[str, Any],
     custom_inputs: Optional[dict[str, Any]] = None,
-) -> ChatCompletionResponse:
-    messages: Sequence[ChatMessage] = [ChatMessage(**m) for m in messages]
-    params: ChatParams = ChatParams(**{"custom_inputs": custom_inputs})
-
-    return _process_mlflow_messages(app, messages, params)
+) -> ChatCompletionResponse | ResponsesAgentResponse:
+    if isinstance(app, LanggraphChatModel):
+        messages: Sequence[ChatMessage] = [ChatMessage(**m) for m in messages]
+        params: ChatParams = ChatParams(**{"custom_inputs": custom_inputs})
+        return _process_mlflow_messages(app, messages, params)
+
+    elif isinstance(app, LanggraphResponsesAgent):
+        input_messages: list[Message] = [Message(**m) for m in messages]
+        request = ResponsesAgentRequest(
+            input=input_messages, custom_inputs=custom_inputs
+        )
+        return _process_mlflow_response_messages(app, request)
 
 
 def _process_config_messages_stream(
-    app: ChatModel, messages: dict[str, Any], custom_inputs: dict[str, Any]
-) -> Generator[ChatCompletionChunk, None, None]:
-    messages: Sequence[ChatMessage] = [ChatMessage(**m) for m in messages]
-    params: ChatParams = ChatParams(**{"custom_inputs": custom_inputs})
+    app: LanggraphChatModel | LanggraphResponsesAgent,
+    messages: dict[str, Any],
+    custom_inputs: dict[str, Any],
+) -> Generator[ChatCompletionChunk | ResponsesAgentStreamEvent, None, None]:
+    if isinstance(app, LanggraphChatModel):
+        messages: Sequence[ChatMessage] = [ChatMessage(**m) for m in messages]
+        params: ChatParams = ChatParams(**{"custom_inputs": custom_inputs})
 
-    for event in _process_mlflow_messages_stream(app, messages, custom_inputs=params):
-        yield event
+        for event in _process_mlflow_messages_stream(
+            app, messages, custom_inputs=params
+        ):
+            yield event
+
+    elif isinstance(app, LanggraphResponsesAgent):
+        input_messages: list[Message] = [Message(**m) for m in messages]
+        request = ResponsesAgentRequest(
+            input=input_messages, custom_inputs=custom_inputs
+        )
+
+        for event in _process_mlflow_response_messages_stream(app, request):
+            yield event
 
 
 def process_messages_stream(
-    app: LanggraphChatModel,
-    messages: Sequence[BaseMessage] | Sequence[ChatMessage] | dict[str, Any],
+    app: LanggraphChatModel | LanggraphResponsesAgent,
+    messages: Sequence[BaseMessage]
+    | Sequence[ChatMessage]
+    | ResponsesAgentRequest
+    | dict[str, Any],
     custom_inputs: Optional[dict[str, Any]] = None,
-) -> Generator[ChatCompletionChunk | AIMessageChunk, None, None]:
+) -> Generator[
+    ChatCompletionChunk | ResponsesAgentStreamEvent | AIMessageChunk, None, None
+]:
     """
     Process messages through a ChatAgent in streaming mode.
 
@@ -338,7 +638,10 @@
         Individual message chunks from the streaming response
     """
 
-    if has_mlflow_messages(messages):
+    if has_mlflow_responses_messages(messages):
+        for event in _process_mlflow_response_messages_stream(app, messages):
+            yield event
+    elif has_mlflow_messages(messages):
         for event in _process_mlflow_messages_stream(app, messages, custom_inputs):
             yield event
     elif has_langchain_messages(messages):
@@ -350,10 +653,13 @@
 
 
 def process_messages(
-    app: LanggraphChatModel,
-    messages: Sequence[BaseMessage] | Sequence[ChatMessage] | dict[str, Any],
+    app: LanggraphChatModel | LanggraphResponsesAgent,
+    messages: Sequence[BaseMessage]
+    | Sequence[ChatMessage]
+    | ResponsesAgentRequest
+    | dict[str, Any],
     custom_inputs: Optional[dict[str, Any]] = None,
-) -> ChatCompletionResponse | dict[str, Any] | Any:
+) -> ChatCompletionResponse | ResponsesAgentResponse | dict[str, Any] | Any:
     """
     Process messages through a ChatAgent in batch mode.
 
@@ -368,7 +674,9 @@
         Complete response from the agent
     """
 
-    if has_mlflow_messages(messages):
+    if has_mlflow_responses_messages(messages):
+        return _process_mlflow_response_messages(app, messages)
+    elif has_mlflow_messages(messages):
         return _process_mlflow_messages(app, messages, custom_inputs)
     elif has_langchain_messages(messages):
         return _process_langchain_messages(app, messages, custom_inputs)
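
Note: dispatch now checks for a ResponsesAgentRequest before MLflow ChatMessage sequences, which in turn are checked before LangChain messages. A sketch (the agent and chat_model instances stand in for LanggraphResponsesAgent and LanggraphChatModel respectively):

    from mlflow.types.llm import ChatMessage
    from mlflow.types.responses import ResponsesAgentRequest

    from dao_ai.models import process_messages

    process_messages(agent, ResponsesAgentRequest(input=[{"role": "user", "content": "hi"}]))
    # -> ResponsesAgentResponse
    process_messages(chat_model, [ChatMessage(role="user", content="hi")])
    # -> ChatCompletionResponse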
dao_ai/providers/databricks.py CHANGED
@@ -235,12 +235,16 @@ class DatabricksProvider(ServiceProvider):
             + connections
         )
 
-        # all_resources: Sequence[DatabricksResource] = [
-        #     r.as_resource() for r in resources
-        # ]
+        # Flatten all resources from all models into a single list
+        all_resources: list[DatabricksResource] = []
+        for r in resources:
+            all_resources.extend(r.as_resources())
 
         system_resources: Sequence[DatabricksResource] = [
-            r.as_resource() for r in resources if not r.on_behalf_of_user
+            resource
+            for r in resources
+            for resource in r.as_resources()
+            if not r.on_behalf_of_user
         ]
         logger.debug(f"system_resources: {[r.name for r in system_resources]}")
 
{dao_ai-0.0.17 → dao_ai-0.0.18}.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.0.17
+Version: 0.0.18
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai
{dao_ai-0.0.17 → dao_ai-0.0.18}.dist-info/RECORD CHANGED
@@ -1,13 +1,13 @@
 dao_ai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dao_ai/agent_as_code.py,sha256=rMWMC0nabtABHPD5H-Yy9ej7QNebLnXfvlZPiwrthoE,457
+dao_ai/agent_as_code.py,sha256=kPSeDz2-1jRaed1TMs4LA3VECoyqe9_Ed2beRLB9gXQ,472
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
 dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
 dao_ai/cli.py,sha256=Aez2TQW3Q8Ho1IaIkRggt0NevDxAAVPjXkePC5GPJF0,20429
-dao_ai/config.py,sha256=VyHqkW1UMQJ0fzyme1fV_3fi_6wDmKRQeCrx881fDQ4,45173
+dao_ai/config.py,sha256=JlYC8N_7UL8VVkdSepiCUnR9NA5OsCVAigLjse7dMFM,49922
 dao_ai/graph.py,sha256=kXaGLGFVekDWqm-AHzti6LmrXnyi99VQ-AdCGuNb_xM,7831
 dao_ai/guardrails.py,sha256=-Qh0f_2Db9t4Nbrrx9FM7tnpqShjMoyxepZ0HByItfU,4027
-dao_ai/messages.py,sha256=tRZQTeb5YFKu8cm1xeaCkKhidq-0tdzncNEzVePvits,6806
-dao_ai/models.py,sha256=Zf5Rqus5xcdpxSvuLlDau4JM1j1fF9v_MnQ7HW4BXU4,13862
+dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
+dao_ai/models.py,sha256=h_xFMK5FHQwPApEAYhvrt69y7ZUljmqThHTjp-yde_o,25368
 dao_ai/nodes.py,sha256=SSuFNTXOdFaKg_aX-yUkQO7fM9wvNGu14lPXKDapU1U,8461
 dao_ai/prompts.py,sha256=vpmIbWs_szXUgNNDs5Gh2LcxKZti5pHDKSfoClUcgX0,1289
 dao_ai/state.py,sha256=GwbMbd1TWZx1T5iQrEOX6_rpxOitlmyeJ8dMr2o_pag,1031
@@ -22,7 +22,7 @@ dao_ai/memory/core.py,sha256=K45iCEFbqJCVxMi4m3vmBJi4c6TQ-UtKGzyugDTkPP0,4141
 dao_ai/memory/postgres.py,sha256=YILzA7xtqawPAOLFaGG_i17zW7cQxXTzTD8yd-ipe8k,12480
 dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
-dao_ai/providers/databricks.py,sha256=XRPOqwF5SeA9rPAOWMg2gSMC7lw31BI5VI_4K0KIOqo,27931
+dao_ai/providers/databricks.py,sha256=KLYrLccOA3Uws9nWJcJUZTbMz-MdR_onhlQeztbplCM,28073
 dao_ai/tools/__init__.py,sha256=ye6MHaJY7tUnJ8336YJiLxuZr55zDPNdOw6gm7j5jlc,1103
 dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
 dao_ai/tools/core.py,sha256=Kei33S8vrmvPOAyrFNekaWmV2jqZ-IPS1QDSvU7RZF0,1984
@@ -33,8 +33,8 @@ dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
 dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
 dao_ai/tools/unity_catalog.py,sha256=PXfLj2EgyQgaXq4Qq3t25AmTC4KyVCF_-sCtg6enens,1404
 dao_ai/tools/vector_search.py,sha256=EDYQs51zIPaAP0ma1D81wJT77GQ-v-cjb2XrFVWfWdg,2621
-dao_ai-0.0.17.dist-info/METADATA,sha256=G50pidDVsQt5j4T2NhygeeYJyMOdL7fSI_zyMaWIGUo,41378
-dao_ai-0.0.17.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-dao_ai-0.0.17.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
-dao_ai-0.0.17.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
-dao_ai-0.0.17.dist-info/RECORD,,
+dao_ai-0.0.18.dist-info/METADATA,sha256=9lTAXjEqQHxl6dmRMyiqUnYT1Nh_wJpSeJXRG8bGZGg,41378
+dao_ai-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dao_ai-0.0.18.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.0.18.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.0.18.dist-info/RECORD,,