dao-ai 0.0.16__py3-none-any.whl → 0.0.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/agent_as_code.py +2 -2
- dao_ai/config.py +172 -39
- dao_ai/graph.py +1 -1
- dao_ai/messages.py +7 -0
- dao_ai/models.py +445 -81
- dao_ai/nodes.py +7 -3
- dao_ai/providers/databricks.py +8 -4
- dao_ai/tools/agent.py +3 -3
- dao_ai/tools/genie.py +4 -2
- dao_ai/tools/human_in_the_loop.py +3 -3
- dao_ai/tools/mcp.py +11 -12
- {dao_ai-0.0.16.dist-info → dao_ai-0.0.18.dist-info}/METADATA +1 -1
- {dao_ai-0.0.16.dist-info → dao_ai-0.0.18.dist-info}/RECORD +16 -16
- {dao_ai-0.0.16.dist-info → dao_ai-0.0.18.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.16.dist-info → dao_ai-0.0.18.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.16.dist-info → dao_ai-0.0.18.dist-info}/licenses/LICENSE +0 -0
dao_ai/agent_as_code.py
CHANGED
@@ -3,7 +3,7 @@ import sys
 import mlflow
 from loguru import logger
 from mlflow.models import ModelConfig
-from mlflow.pyfunc import …
+from mlflow.pyfunc import ResponsesAgent

 from dao_ai.config import AppConfig

@@ -17,6 +17,6 @@ log_level: str = config.app.log_level
 logger.remove()
 logger.add(sys.stderr, level=log_level)

-app: …
+app: ResponsesAgent = config.as_responses_agent()

 mlflow.models.set_model(app)
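For context, agent_as_code.py is an MLflow "models from code" entry point: it now builds a ResponsesAgent from the YAML configuration and registers it via mlflow.models.set_model. A rough sketch of how such a file is typically logged and queried follows; the script path, config file, and artifact name are illustrative assumptions, not values taken from this package.

import mlflow

# Hypothetical paths/names for illustration only.
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        artifact_path="agent",                   # where the model is stored in the run
        python_model="dao_ai/agent_as_code.py",  # the models-from-code entry point above
        model_config="model_config.yml",         # surfaced through mlflow.models.ModelConfig
    )

loaded = mlflow.pyfunc.load_model(model_info.model_uri)
print(loaded.predict({"input": [{"role": "user", "content": "hello"}]}))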
dao_ai/config.py
CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path
 from typing import (
     Any,
     Callable,
+    Iterator,
     Literal,
     Optional,
     Sequence,
@@ -21,6 +22,7 @@ from databricks.sdk.credentials_provider import (
     CredentialsStrategy,
     ModelServingUserCredentials,
 )
+from databricks.sdk.service.catalog import FunctionInfo, TableInfo
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
@@ -44,8 +46,14 @@ from mlflow.models.resources import (
     DatabricksUCConnection,
     DatabricksVectorSearchIndex,
 )
-from mlflow.pyfunc import ChatModel
-from pydantic import …
+from mlflow.pyfunc import ChatModel, ResponsesAgent
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    field_serializer,
+    model_validator,
+)


 class HasValue(ABC):
@@ -69,7 +77,7 @@ class IsDatabricksResource(ABC):
     on_behalf_of_user: Optional[bool] = False

     @abstractmethod
-    def …
+    def as_resources(self) -> Sequence[DatabricksResource]: ...

     @property
     @abstractmethod
@@ -235,22 +243,68 @@ class SchemaModel(BaseModel, HasFullName):
 class TableModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    name: str
+    name: Optional[str] = None
+
+    @model_validator(mode="after")
+    def validate_name_or_schema_required(self) -> "TableModel":
+        if not self.name and not self.schema_model:
+            raise ValueError(
+                "Either 'name' or 'schema_model' must be provided for TableModel"
+            )
+        return self

     @property
     def full_name(self) -> str:
         if self.schema_model:
-…
+            name: str = ""
+            if self.name:
+                name = f".{self.name}"
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
         return self.name

     @property
     def api_scopes(self) -> Sequence[str]:
         return []

-    def …
-…
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        resources: list[DatabricksResource] = []
+
+        excluded_suffixes: Sequence[str] = [
+            "_payload",
+            "_assessment_logs",
+            "_request_logs",
+        ]
+
+        excluded_prefixes: Sequence[str] = [
+            "trace_logs_"
+        ]
+
+        if self.name:
+            resources.append(
+                DatabricksTable(
+                    table_name=self.full_name, on_behalf_of_user=self.on_behalf_of_user
+                )
+            )
+        else:
+            w: WorkspaceClient = self.workspace_client
+            schema_full_name: str = self.schema_model.full_name
+            tables: Iterator[TableInfo] = w.tables.list(
+                catalog_name=self.schema_model.catalog_name,
+                schema_name=self.schema_model.schema_name,
+            )
+            resources.extend(
+                [
+                    DatabricksTable(
+                        table_name=f"{schema_full_name}.{table.name}",
+                        on_behalf_of_user=self.on_behalf_of_user,
+                    )
+                    for table in tables
+                    if not any(table.name.endswith(suffix) for suffix in excluded_suffixes)
+                    and not any(table.name.startswith(prefix) for prefix in excluded_prefixes)
+                ]
+            )
+
+        return resources


 class LLMModel(BaseModel, IsDatabricksResource):
@@ -266,10 +320,12 @@ class LLMModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ]

-    def …
-        return …
-…
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksServingEndpoint(
+                endpoint_name=self.name, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]

     def as_chat_model(self) -> LanguageModelLike:
         # Retrieve langchain chat client from workspace client to enable OBO
@@ -345,10 +401,12 @@ class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
             return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
         return self.name

-    def …
-        return …
-…
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksVectorSearchIndex(
+                index_name=self.full_name, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]


 class GenieRoomModel(BaseModel, IsDatabricksResource):
@@ -363,10 +421,12 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
         "dashboards.genie",
     ]

-    def …
-        return …
-…
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksGenieSpace(
+                genie_space_id=self.space_id, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]


 class VolumeModel(BaseModel, HasFullName):
@@ -394,7 +454,7 @@ class VolumePathModel(BaseModel, HasFullName):
     path: Optional[str] = None

     @model_validator(mode="after")
-    def validate_path_or_volume(self):
+    def validate_path_or_volume(self) -> "VolumePathModel":
         if not self.volume and not self.path:
             raise ValueError("Either 'volume' or 'path' must be provided")
         return self
@@ -502,8 +562,8 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ] + self.index.api_scopes

-    def …
-        return self.index.…
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return self.index.as_resources()

     def as_index(self, vsc: VectorSearchClient | None = None) -> VectorSearchIndex:
         from dao_ai.providers.base import ServiceProvider
@@ -524,18 +584,52 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
 class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict()
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    name: str
+    name: Optional[str] = None
+
+    @model_validator(mode="after")
+    def validate_name_or_schema_required(self) -> "FunctionModel":
+        if not self.name and not self.schema_model:
+            raise ValueError(
+                "Either 'name' or 'schema_model' must be provided for FunctionModel"
+            )
+        return self

     @property
     def full_name(self) -> str:
         if self.schema_model:
-…
+            name: str = ""
+            if self.name:
+                name = f".{self.name}"
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
         return self.name

-    def …
-…
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        resources: list[DatabricksResource] = []
+        if self.name:
+            resources.append(
+                DatabricksFunction(
+                    function_name=self.full_name,
+                    on_behalf_of_user=self.on_behalf_of_user,
+                )
+            )
+        else:
+            w: WorkspaceClient = self.workspace_client
+            schema_full_name: str = self.schema_model.full_name
+            functions: Iterator[FunctionInfo] = w.functions.list(
+                catalog_name=self.schema_model.catalog_name,
+                schema_name=self.schema_model.schema_name,
+            )
+            resources.extend(
+                [
+                    DatabricksFunction(
+                        function_name=f"{schema_full_name}.{function.name}",
+                        on_behalf_of_user=self.on_behalf_of_user,
+                    )
+                    for function in functions
+                ]
+            )
+
+        return resources

     @property
     def api_scopes(self) -> Sequence[str]:
@@ -556,10 +650,12 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
         "catalog.connections",
     ]

-    def …
-        return …
-…
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksUCConnection(
+                connection_name=self.name, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]


 class WarehouseModel(BaseModel, IsDatabricksResource):
@@ -575,10 +671,12 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
         "sql.statement-execution",
     ]

-    def …
-        return …
-…
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksSQLWarehouse(
+                warehouse_id=self.warehouse_id, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]


 class DatabaseModel(BaseModel):
@@ -1034,9 +1132,37 @@ class Message(BaseModel):

 class ChatPayload(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-…
+    input: Optional[list[Message]] = None
+    messages: Optional[list[Message]] = None
     custom_inputs: dict

+    @model_validator(mode="after")
+    def validate_mutual_exclusion_and_alias(self) -> "ChatPayload":
+        """Handle dual field support with automatic aliasing."""
+        # If both fields are provided and they're the same, that's okay (redundant but valid)
+        if self.input is not None and self.messages is not None:
+            # Allow if they're identical (redundant specification)
+            if self.input == self.messages:
+                return self
+            # If they're different, prefer input and copy to messages
+            else:
+                self.messages = self.input
+                return self
+
+        # If neither field is provided, that's an error
+        if self.input is None and self.messages is None:
+            raise ValueError("Must specify either 'input' or 'messages' field.")
+
+        # Create alias: copy messages to input if input is None
+        if self.input is None and self.messages is not None:
+            self.input = self.messages
+
+        # Create alias: copy input to messages if messages is None
+        elif self.messages is None and self.input is not None:
+            self.messages = self.input
+
+        return self
+

 class ChatHistoryModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -1374,3 +1500,10 @@ class AppConfig(BaseModel):
         graph: CompiledStateGraph = self.as_graph()
         app: ChatModel = create_agent(graph)
         return app
+
+    def as_responses_agent(self) -> ResponsesAgent:
+        from dao_ai.models import create_responses_agent
+
+        graph: CompiledStateGraph = self.as_graph()
+        app: ResponsesAgent = create_responses_agent(graph)
+        return app
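On the TableModel/FunctionModel change above: name is now optional, and when only a schema is given, as_resources() enumerates every table (or function) in that schema and filters out Databricks inference/assessment and trace-log tables. An illustrative, self-contained sketch of that enumeration pattern (not part of the package; catalog and schema names are placeholders):

from databricks.sdk import WorkspaceClient
from mlflow.models.resources import DatabricksTable

excluded_suffixes = ("_payload", "_assessment_logs", "_request_logs")
excluded_prefixes = ("trace_logs_",)

w = WorkspaceClient()
resources = [
    DatabricksTable(table_name=f"my_catalog.my_schema.{table.name}")
    for table in w.tables.list(catalog_name="my_catalog", schema_name="my_schema")
    if not table.name.endswith(excluded_suffixes)
    and not table.name.startswith(excluded_prefixes)
]
print([resource.to_dict() for resource in resources])  # resources passed to mlflow when logging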
dao_ai/graph.py
CHANGED
@@ -221,7 +221,7 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:

         return swarm_node

-    #return workflow.compile(checkpointer=checkpointer, store=store)
+    # return workflow.compile(checkpointer=checkpointer, store=store)


 def create_dao_ai_graph(config: AppConfig) -> CompiledStateGraph:
dao_ai/messages.py
CHANGED
@@ -11,6 +11,9 @@ from langchain_core.messages import (
 )
 from langchain_core.messages.modifier import RemoveMessage
 from mlflow.types.llm import ChatMessage
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+)


 def remove_messages(
@@ -96,6 +99,10 @@ def has_mlflow_messages(messages: ChatMessage | Sequence[ChatMessage]) -> bool:
     return any(isinstance(m, ChatMessage) for m in messages)


+def has_mlflow_responses_messages(messages: ResponsesAgentRequest) -> bool:
+    return isinstance(messages, ResponsesAgentRequest)
+
+
 def has_image(messages: BaseMessage | Sequence[BaseMessage]) -> bool:
     """
     Check if a message contains an image.
dao_ai/models.py
CHANGED
@@ -1,13 +1,14 @@
 import uuid
 from os import PathLike
 from pathlib import Path
-from typing import Any, Generator, Optional, Sequence
+from typing import Any, Generator, Optional, Sequence, Union

 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langgraph.graph.state import CompiledStateGraph
 from loguru import logger
 from mlflow import MlflowClient
-from mlflow.pyfunc import ChatAgent, ChatModel
+from mlflow.pyfunc import ChatAgent, ChatModel, ResponsesAgent
+from mlflow.types.agent import ChatContext
 from mlflow.types.llm import (
     ChatChoice,
     ChatChoiceDelta,
@@ -17,8 +18,21 @@ from mlflow.types.llm import (
     ChatMessage,
     ChatParams,
 )
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+    ResponsesAgentResponse,
+    ResponsesAgentStreamEvent,
+)
+from mlflow.types.responses_helpers import (
+    Message,
+    ResponseInputTextParam,
+)

-from dao_ai.messages import …
+from dao_ai.messages import (
+    has_langchain_messages,
+    has_mlflow_messages,
+    has_mlflow_responses_messages,
+)
 from dao_ai.state import Context


@@ -65,9 +79,19 @@ class LanggraphChatModel(ChatModel):
         context: Context = self._convert_to_context(params)
         custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}

-…
+        # Use async ainvoke internally for parallel execution
+        import asyncio
+
+        async def _async_invoke():
+            return await self.graph.ainvoke(
+                request, context=context, config=custom_inputs
+            )
+
+        loop = asyncio.get_event_loop()
+        response: dict[str, Sequence[BaseMessage]] = loop.run_until_complete(
+            _async_invoke()
         )
+
         logger.trace(f"response: {response}")

         last_message: BaseMessage = response["messages"][-1]
@@ -114,33 +138,51 @@ class LanggraphChatModel(ChatModel):
         context: Context = self._convert_to_context(params)
         custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}

-…
-stream_mode…
-…
+        # Use async astream internally for parallel execution
+        import asyncio
+
+        async def _async_stream():
+            async for nodes, stream_mode, messages_batch in self.graph.astream(
+                request,
+                context=context,
+                config=custom_inputs,
+                stream_mode=["messages", "custom"],
+                subgraphs=True,
+            ):
+                nodes: tuple[str, ...]
+                stream_mode: str
+                messages_batch: Sequence[BaseMessage]
+                logger.trace(
+                    f"nodes: {nodes}, stream_mode: {stream_mode}, messages: {messages_batch}"
+                )
+                for message in messages_batch:
+                    if (
+                        isinstance(
+                            message,
+                            (
+                                AIMessageChunk,
+                                AIMessage,
+                            ),
+                        )
+                        and message.content
+                        and "summarization" not in nodes
+                    ):
+                        content = message.content
+                        yield self._create_chat_completion_chunk(content)
+
+        # Convert async generator to sync generator
+        loop = asyncio.get_event_loop()
+        async_gen = _async_stream()
+
+        try:
+            while True:
+                try:
+                    item = loop.run_until_complete(async_gen.__anext__())
+                    yield item
+                except StopAsyncIteration:
+                    break
+        finally:
+            loop.run_until_complete(async_gen.aclose())

     def _create_chat_completion_chunk(self, content: str) -> ChatCompletionChunk:
         return ChatCompletionChunk(
@@ -157,6 +199,232 @@ class LanggraphChatModel(ChatModel):
         return [m.to_dict() for m in messages]


+class LanggraphResponsesAgent(ResponsesAgent):
+    """
+    ResponsesAgent that delegates requests to a LangGraph CompiledStateGraph.
+
+    This is the modern replacement for LanggraphChatModel, providing better
+    support for streaming, tool calling, and async execution.
+    """
+
+    def __init__(self, graph: CompiledStateGraph) -> None:
+        self.graph = graph
+
+    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
+        """
+        Process a ResponsesAgentRequest and return a ResponsesAgentResponse.
+        """
+        logger.debug(f"ResponsesAgent request: {request}")
+
+        # Convert ResponsesAgent input to LangChain messages
+        messages = self._convert_request_to_langchain_messages(request)
+
+        # Prepare context
+        context: Context = self._convert_request_to_context(request)
+        custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
+
+        # Use async ainvoke internally for parallel execution
+        import asyncio
+
+        async def _async_invoke():
+            return await self.graph.ainvoke(
+                {"messages": messages}, context=context, config=custom_inputs
+            )
+
+        loop = asyncio.get_event_loop()
+        response: dict[str, Sequence[BaseMessage]] = loop.run_until_complete(
+            _async_invoke()
+        )
+
+        # Convert response to ResponsesAgent format
+        last_message: BaseMessage = response["messages"][-1]
+
+        output_item = self.create_text_output_item(
+            text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
+        )
+
+        return ResponsesAgentResponse(
+            output=[output_item], custom_outputs=request.custom_inputs
+        )
+
+    def predict_stream(
+        self, request: ResponsesAgentRequest
+    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
+        """
+        Process a ResponsesAgentRequest and yield ResponsesAgentStreamEvent objects.
+        """
+        logger.debug(f"ResponsesAgent stream request: {request}")
+
+        # Convert ResponsesAgent input to LangChain messages
+        messages: list[BaseMessage] = self._convert_request_to_langchain_messages(
+            request
+        )
+
+        # Prepare context
+        context: Context = self._convert_request_to_context(request)
+        custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
+
+        # Use async astream internally for parallel execution
+        import asyncio
+
+        async def _async_stream():
+            item_id = f"msg_{uuid.uuid4().hex[:8]}"
+            accumulated_content = ""
+
+            async for nodes, stream_mode, messages_batch in self.graph.astream(
+                {"messages": messages},
+                context=context,
+                config=custom_inputs,
+                stream_mode=["messages", "custom"],
+                subgraphs=True,
+            ):
+                nodes: tuple[str, ...]
+                stream_mode: str
+                messages_batch: Sequence[BaseMessage]
+
+                for message in messages_batch:
+                    if (
+                        isinstance(
+                            message,
+                            (
+                                AIMessageChunk,
+                                AIMessage,
+                            ),
+                        )
+                        and message.content
+                        and "summarization" not in nodes
+                    ):
+                        content = message.content
+                        accumulated_content += content
+
+                        # Yield streaming delta
+                        yield ResponsesAgentStreamEvent(
+                            **self.create_text_delta(delta=content, item_id=item_id)
+                        )
+
+            # Yield final output item
+            yield ResponsesAgentStreamEvent(
+                type="response.output_item.done",
+                item=self.create_text_output_item(text=accumulated_content, id=item_id),
+                custom_outputs=request.custom_inputs,
+            )
+
+        # Convert async generator to sync generator
+        loop = asyncio.get_event_loop()
+        async_gen = _async_stream()
+
+        try:
+            while True:
+                try:
+                    item = loop.run_until_complete(async_gen.__anext__())
+                    yield item
+                except StopAsyncIteration:
+                    break
+        finally:
+            loop.run_until_complete(async_gen.aclose())
+
+    def _extract_text_from_content(
+        self,
+        content: Union[str, list[Union[ResponseInputTextParam, str, dict[str, Any]]]],
+    ) -> str:
+        """Extract text content from various MLflow content formats.
+
+        MLflow ResponsesAgent supports multiple content formats:
+        - str: Simple text content
+        - list[ResponseInputTextParam]: Structured text objects with .text attribute
+        - list[dict]: Dictionaries with "text" key
+        - Mixed lists of the above types
+
+        This method normalizes all formats to a single concatenated string.
+
+        Args:
+            content: The content to extract text from
+
+        Returns:
+            Concatenated text string from all content items
+        """
+        if isinstance(content, str):
+            return content
+        elif isinstance(content, list):
+            text_parts = []
+            for content_item in content:
+                if isinstance(content_item, ResponseInputTextParam):
+                    text_parts.append(content_item.text)
+                elif isinstance(content_item, str):
+                    text_parts.append(content_item)
+                elif isinstance(content_item, dict) and "text" in content_item:
+                    text_parts.append(content_item["text"])
+            return "".join(text_parts)
+        else:
+            # Fallback for unknown types - try to extract text attribute
+            return getattr(content, "text", str(content))
+
+    def _convert_request_to_langchain_messages(
+        self, request: ResponsesAgentRequest
+    ) -> list[dict[str, Any]]:
+        """Convert ResponsesAgent input to LangChain message format."""
+        messages = []
+
+        for input_item in request.input:
+            if isinstance(input_item, Message):
+                # Handle MLflow Message objects
+                content = self._extract_text_from_content(input_item.content)
+                messages.append({"role": input_item.role, "content": content})
+            elif isinstance(input_item, dict):
+                # Handle dict format
+                if "role" in input_item and "content" in input_item:
+                    content = self._extract_text_from_content(input_item["content"])
+                    messages.append({"role": input_item["role"], "content": content})
+            else:
+                # Fallback for other object types with role/content attributes
+                role = getattr(input_item, "role", "user")
+                content = self._extract_text_from_content(
+                    getattr(input_item, "content", "")
+                )
+                messages.append({"role": role, "content": content})
+
+        return messages
+
+    def _convert_request_to_context(self, request: ResponsesAgentRequest) -> Context:
+        """Convert ResponsesAgent context to internal Context."""
+
+        logger.debug(f"request.context: {request.context}")
+        logger.debug(f"request.custom_inputs: {request.custom_inputs}")
+
+        configurable: dict[str, Any] = {}
+
+        # Process context values first (lower priority)
+        # Use strong typing with forward-declared type hints instead of hasattr checks
+        chat_context: Optional[ChatContext] = request.context
+        if chat_context is not None:
+            conversation_id: Optional[str] = chat_context.conversation_id
+            user_id: Optional[str] = chat_context.user_id
+
+            if conversation_id is not None:
+                configurable["conversation_id"] = conversation_id
+                configurable["thread_id"] = conversation_id
+
+            if user_id is not None:
+                configurable["user_id"] = user_id
+
+        # Process custom_inputs after context so they can override context values (higher priority)
+        if request.custom_inputs:
+            if "configurable" in request.custom_inputs:
+                configurable.update(request.custom_inputs.pop("configurable"))
+
+            configurable.update(request.custom_inputs)
+
+        if "user_id" in configurable:
+            configurable["user_id"] = configurable["user_id"].replace(".", "_")
+
+        if "thread_id" not in configurable:
+            configurable["thread_id"] = str(uuid.uuid4())
+
+        logger.debug(f"Creating context from: {configurable}")
+
+        return Context(**configurable)
+
+
 def create_agent(graph: CompiledStateGraph) -> ChatAgent:
     """
     Create an MLflow-compatible ChatAgent from a LangGraph state machine.
@@ -173,14 +441,39 @@ def create_agent(graph: CompiledStateGraph) -> ChatAgent:
     return LanggraphChatModel(graph)


+def create_responses_agent(graph: CompiledStateGraph) -> ResponsesAgent:
+    """
+    Create an MLflow-compatible ResponsesAgent from a LangGraph state machine.
+
+    Factory function that wraps a compiled LangGraph in the LanggraphResponsesAgent
+    class to make it deployable through MLflow.
+
+    Args:
+        graph: A compiled LangGraph state machine
+
+    Returns:
+        An MLflow-compatible ResponsesAgent instance
+    """
+    return LanggraphResponsesAgent(graph)
+
+
 def _process_langchain_messages(
     app: LanggraphChatModel | CompiledStateGraph,
     messages: Sequence[BaseMessage],
     custom_inputs: Optional[dict[str, Any]] = None,
 ) -> dict[str, Any] | Any:
+    """Process LangChain messages using async LangGraph calls internally."""
+    import asyncio
+
     if isinstance(app, LanggraphChatModel):
         app = app.graph
-…
+
+    # Use async ainvoke internally for parallel execution
+    async def _async_invoke():
+        return await app.ainvoke({"messages": messages}, config=custom_inputs)
+
+    loop = asyncio.get_event_loop()
+    return loop.run_until_complete(_async_invoke())


 def _process_langchain_messages_stream(
@@ -188,6 +481,9 @@ def _process_langchain_messages_stream(
     messages: Sequence[BaseMessage],
     custom_inputs: Optional[dict[str, Any]] = None,
 ) -> Generator[AIMessageChunk, None, None]:
+    """Process LangChain messages in streaming mode using async LangGraph calls internally."""
+    import asyncio
+
     if isinstance(app, LanggraphChatModel):
         app = app.graph

@@ -196,32 +492,48 @@ def _process_langchain_messages_stream(
     custom_inputs = custom_inputs.get("configurable", custom_inputs or {})
     context: Context = Context(**custom_inputs)

-…
+    # Use async astream internally for parallel execution
+    async def _async_stream():
+        async for nodes, stream_mode, stream_messages in app.astream(
+            {"messages": messages},
+            context=context,
+            config=custom_inputs,
+            stream_mode=["messages", "custom"],
+            subgraphs=True,
+        ):
+            nodes: tuple[str, ...]
+            stream_mode: str
+            stream_messages: Sequence[BaseMessage]
+            logger.trace(
+                f"nodes: {nodes}, stream_mode: {stream_mode}, messages: {stream_messages}"
+            )
+            for message in stream_messages:
+                if (
+                    isinstance(
+                        message,
+                        (
+                            AIMessageChunk,
+                            AIMessage,
+                        ),
+                    )
+                    and message.content
+                    and "summarization" not in nodes
+                ):
+                    yield message
+
+    # Convert async generator to sync generator
+    loop = asyncio.get_event_loop()
+    async_gen = _async_stream()
+
+    try:
+        while True:
+            try:
+                item = loop.run_until_complete(async_gen.__anext__())
+                yield item
+            except StopAsyncIteration:
+                break
+    finally:
+        loop.run_until_complete(async_gen.aclose())


 def _process_mlflow_messages(
@@ -232,6 +544,14 @@ def _process_mlflow_messages(
     return app.predict(None, messages, custom_inputs)


+def _process_mlflow_response_messages(
+    app: ResponsesAgent,
+    messages: ResponsesAgentRequest,
+) -> ResponsesAgentResponse:
+    """Process MLflow ResponsesAgent request in batch mode."""
+    return app.predict(messages)
+
+
 def _process_mlflow_messages_stream(
     app: ChatModel,
     messages: Sequence[ChatMessage],
@@ -242,37 +562,73 @@ def _process_mlflow_messages_stream(
         yield event


+def _process_mlflow_response_messages_stream(
+    app: ResponsesAgent,
+    messages: ResponsesAgentRequest,
+) -> Generator[ResponsesAgentStreamEvent, None, None]:
+    """Process MLflow ResponsesAgent request in streaming mode."""
+    for event in app.predict_stream(messages):
+        event: ResponsesAgentStreamEvent
+        yield event
+
+
 def _process_config_messages(
-    app: …
+    app: LanggraphChatModel | LanggraphResponsesAgent,
     messages: dict[str, Any],
     custom_inputs: Optional[dict[str, Any]] = None,
-) -> ChatCompletionResponse:
-…
+) -> ChatCompletionResponse | ResponsesAgentResponse:
+    if isinstance(app, LanggraphChatModel):
+        messages: Sequence[ChatMessage] = [ChatMessage(**m) for m in messages]
+        params: ChatParams = ChatParams(**{"custom_inputs": custom_inputs})
+        return _process_mlflow_messages(app, messages, params)
+
+    elif isinstance(app, LanggraphResponsesAgent):
+        input_messages: list[Message] = [Message(**m) for m in messages]
+        request = ResponsesAgentRequest(
+            input=input_messages, custom_inputs=custom_inputs
+        )
+        return _process_mlflow_response_messages(app, request)


 def _process_config_messages_stream(
-    app: …
-…
+    app: LanggraphChatModel | LanggraphResponsesAgent,
+    messages: dict[str, Any],
+    custom_inputs: dict[str, Any],
+) -> Generator[ChatCompletionChunk | ResponsesAgentStreamEvent, None, None]:
+    if isinstance(app, LanggraphChatModel):
+        messages: Sequence[ChatMessage] = [ChatMessage(**m) for m in messages]
+        params: ChatParams = ChatParams(**{"custom_inputs": custom_inputs})

-…
+        for event in _process_mlflow_messages_stream(
+            app, messages, custom_inputs=params
+        ):
+            yield event
+
+    elif isinstance(app, LanggraphResponsesAgent):
+        input_messages: list[Message] = [Message(**m) for m in messages]
+        request = ResponsesAgentRequest(
+            input=input_messages, custom_inputs=custom_inputs
+        )
+
+        for event in _process_mlflow_response_messages_stream(app, request):
+            yield event


 def process_messages_stream(
-    app: LanggraphChatModel,
-    messages: Sequence[BaseMessage]…
+    app: LanggraphChatModel | LanggraphResponsesAgent,
+    messages: Sequence[BaseMessage]
+    | Sequence[ChatMessage]
+    | ResponsesAgentRequest
+    | dict[str, Any],
     custom_inputs: Optional[dict[str, Any]] = None,
-) -> Generator[…
+) -> Generator[
+    ChatCompletionChunk | ResponsesAgentStreamEvent | AIMessageChunk, None, None
+]:
     """
     Process messages through a ChatAgent in streaming mode.

     Utility function that normalizes message input formats and
-    streams the agent's responses as they're generated.
+    streams the agent's responses as they're generated using async LangGraph calls internally.

     Args:
         app: The ChatAgent to process messages with
@@ -282,7 +638,10 @@ def process_messages_stream(
         Individual message chunks from the streaming response
     """

-    if …
+    if has_mlflow_responses_messages(messages):
+        for event in _process_mlflow_response_messages_stream(app, messages):
+            yield event
+    elif has_mlflow_messages(messages):
         for event in _process_mlflow_messages_stream(app, messages, custom_inputs):
             yield event
     elif has_langchain_messages(messages):
@@ -294,15 +653,18 @@ def process_messages_stream(


 def process_messages(
-    app: LanggraphChatModel,
-    messages: Sequence[BaseMessage]…
+    app: LanggraphChatModel | LanggraphResponsesAgent,
+    messages: Sequence[BaseMessage]
+    | Sequence[ChatMessage]
+    | ResponsesAgentRequest
+    | dict[str, Any],
     custom_inputs: Optional[dict[str, Any]] = None,
-) -> ChatCompletionResponse | dict[str, Any] | Any:
+) -> ChatCompletionResponse | ResponsesAgentResponse | dict[str, Any] | Any:
     """
     Process messages through a ChatAgent in batch mode.

     Utility function that normalizes message input formats and
-    returns the complete response from the agent.
+    returns the complete response from the agent using async LangGraph calls internally.

     Args:
         app: The ChatAgent to process messages with
@@ -312,7 +674,9 @@ def process_messages(
         Complete response from the agent
     """

-    if …
+    if has_mlflow_responses_messages(messages):
+        return _process_mlflow_response_messages(app, messages)
+    elif has_mlflow_messages(messages):
         return _process_mlflow_messages(app, messages, custom_inputs)
     elif has_langchain_messages(messages):
         return _process_langchain_messages(app, messages, custom_inputs)
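Several of the new code paths above (predict_stream on both model classes, _process_langchain_messages_stream) share one pattern: an async generator driven by graph.astream is drained with loop.run_until_complete(agen.__anext__()) so that synchronous callers can consume it as a plain iterator. A standalone illustration of that bridge, using a fresh event loop so it runs on its own:

import asyncio


async def numbers():
    # Stand-in for awaiting chunks from graph.astream(...)
    for i in range(3):
        await asyncio.sleep(0)
        yield i


def as_sync_iter(agen):
    # Drive the async generator to completion from synchronous code.
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                yield loop.run_until_complete(agen.__anext__())
            except StopAsyncIteration:
                break
    finally:
        loop.run_until_complete(agen.aclose())
        loop.close()


print(list(as_sync_iter(numbers())))  # [0, 1, 2]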
dao_ai/nodes.py
CHANGED
@@ -69,7 +69,7 @@ def summarization_node(app_model: AppModel) -> RunnableLike:


 def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLike:
-    def call_agent(state: SharedState, runtime: Runtime[Context]) -> SharedState:
+    async def call_agent(state: SharedState, runtime: Runtime[Context]) -> SharedState:
         logger.debug(f"Calling agent {agent.name} with summarized messages")

         # Get the summarized messages from the summarization node
@@ -81,7 +81,9 @@ def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLike:
             "messages": messages,
         }

-        response: dict[str, Any] = agent.…
+        response: dict[str, Any] = await agent.ainvoke(
+            input=input, context=runtime.context
+        )
         response_messages = response.get("messages", [])
         logger.debug(f"Agent returned {len(response_messages)} messages")

@@ -193,7 +195,9 @@ def message_hook_node(config: AppConfig) -> RunnableLike:
     message_hooks: Sequence[Callable[..., Any]] = create_hooks(config.app.message_hooks)

     @mlflow.trace()
-    def message_hook(…
+    async def message_hook(
+        state: IncomingState, runtime: Runtime[Context]
+    ) -> SharedState:
         logger.debug("Running message validation")
         response: dict[str, Any] = {"is_valid": True, "message_error": None}
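The nodes above (call_agent, message_hook) are now coroutines; LangGraph awaits async node functions when the graph is run through ainvoke or astream. A minimal, self-contained sketch of that pattern, where the state schema and node body are placeholders rather than this package's types:

import asyncio
from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    messages: list[str]


async def echo_node(state: State) -> State:
    # An async node can await other coroutines, e.g. a sub-agent's ainvoke().
    await asyncio.sleep(0)
    return {"messages": state["messages"] + ["ok"]}


graph = StateGraph(State)
graph.add_node("echo", echo_node)
graph.add_edge(START, "echo")
graph.add_edge("echo", END)
app = graph.compile()

print(asyncio.run(app.ainvoke({"messages": []})))  # {'messages': ['ok']}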
dao_ai/providers/databricks.py
CHANGED
@@ -235,12 +235,16 @@ class DatabricksProvider(ServiceProvider):
             + connections
         )

-        # …
-…
+        # Flatten all resources from all models into a single list
+        all_resources: list[DatabricksResource] = []
+        for r in resources:
+            all_resources.extend(r.as_resources())

         system_resources: Sequence[DatabricksResource] = [
-…
+            resource
+            for r in resources
+            for resource in r.as_resources()
+            if not r.on_behalf_of_user
         ]
         logger.debug(f"system_resources: {[r.name for r in system_resources]}")

dao_ai/tools/agent.py
CHANGED
@@ -40,16 +40,16 @@ def create_agent_endpoint_tool(

     doc: str = description + "\n" + doc_signature

-    def agent_endpoint(prompt: str) -> AIMessage:
+    async def agent_endpoint(prompt: str) -> AIMessage:
         model: LanguageModelLike = llm.as_chat_model()
         messages: Sequence[BaseMessage] = [HumanMessage(content=prompt)]
-        response: AIMessage = model.…
+        response: AIMessage = await model.ainvoke(messages)
         return response

     name: str = name if name else agent_endpoint.__name__

     structured_tool: StructuredTool = StructuredTool.from_function(
-…
+        coroutine=agent_endpoint, name=name, description=doc, parse_docstring=False
     )

     return structured_tool
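The tool above is now registered through StructuredTool.from_function(coroutine=...), so LangChain runs it with ainvoke and the body can await the chat model. A small self-contained sketch of that registration pattern, with a stand-in tool body:

import asyncio

from langchain_core.tools import StructuredTool


async def shout(text: str) -> str:
    """Return the input text uppercased."""
    return text.upper()


tool = StructuredTool.from_function(
    coroutine=shout, name="shout", description="Uppercase the given text."
)

print(asyncio.run(tool.ainvoke({"text": "hello"})))  # HELLO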
dao_ai/tools/genie.py
CHANGED
@@ -62,14 +62,16 @@ def create_genie_tool(

     doc: str = description + "\n" + doc_signature

-    def genie_tool(question: str) -> GenieResponse:
+    async def genie_tool(question: str) -> GenieResponse:
+        # Use sync API for now since Genie doesn't support async yet
+        # Can be easily updated to await when Genie gets async support
         response: GenieResponse = genie.ask_question(question)
         return response

     name: str = name if name else genie_tool.__name__

     structured_tool: StructuredTool = StructuredTool.from_function(
-…
+        coroutine=genie_tool, name=name, description=doc, parse_docstring=False
     )

     return structured_tool
dao_ai/tools/human_in_the_loop.py
CHANGED

@@ -50,7 +50,7 @@ def add_human_in_the_loop(
     logger.debug(f"Wrapping tool {tool} with human-in-the-loop functionality")

     @create_tool(tool.name, description=tool.description, args_schema=tool.args_schema)
-    def call_tool_with_interrupt(config: RunnableConfig, **tool_input) -> Any:
+    async def call_tool_with_interrupt(config: RunnableConfig, **tool_input) -> Any:
         logger.debug(f"call_tool_with_interrupt: {tool.name} with input: {tool_input}")
         request: HumanInterrupt = {
             "action_request": {
@@ -66,10 +66,10 @@ def add_human_in_the_loop(
         logger.debug(f"Human interrupt response: {response}")

         if response["type"] == "accept":
-            tool_response = tool.…
+            tool_response = await tool.ainvoke(tool_input, config=config)
         elif response["type"] == "edit":
             tool_input = response["args"]["args"]
-            tool_response = tool.…
+            tool_response = await tool.ainvoke(tool_input, config=config)
         elif response["type"] == "response":
             user_feedback = response["args"]
             tool_response = user_feedback
dao_ai/tools/mcp.py
CHANGED
@@ -77,6 +77,8 @@ def create_mcp_tools(
         logger.error(f"Failed to list MCP tools: {e}")
         return []

+    # Note: This still needs to run sync during tool creation/registration
+    # The actual tool execution will be async
     try:
         mcp_tools: list | ListToolsResult = asyncio.run(_list_mcp_tools())
         if isinstance(mcp_tools, ListToolsResult):
@@ -96,22 +98,19 @@ def create_mcp_tools(
             description=mcp_tool.description or f"MCP tool: {mcp_tool.name}",
             args_schema=mcp_tool.inputSchema,
         )
-        def tool_wrapper(**kwargs):
+        async def tool_wrapper(**kwargs):
             """Execute MCP tool with fresh session and authentication."""
             logger.debug(f"Invoking MCP tool {mcp_tool.name} with fresh session")

-…
-            client = MultiServerMCPClient({function.name: connection})
+            connection = _create_fresh_connection()
+            client = MultiServerMCPClient({function.name: connection})

-…
-            return asyncio.run(_invoke())
+            try:
+                async with client.session(function.name) as session:
+                    return await session.call_tool(mcp_tool.name, kwargs)
+            except Exception as e:
+                logger.error(f"MCP tool {mcp_tool.name} failed: {e}")
+                raise

         return as_human_in_the_loop(tool_wrapper, function)

{dao_ai-0.0.16.dist-info → dao_ai-0.0.18.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.0.16
+Version: 0.0.18
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai
{dao_ai-0.0.16.dist-info → dao_ai-0.0.18.dist-info}/RECORD
CHANGED

@@ -1,14 +1,14 @@
 dao_ai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dao_ai/agent_as_code.py,sha256=…
+dao_ai/agent_as_code.py,sha256=kPSeDz2-1jRaed1TMs4LA3VECoyqe9_Ed2beRLB9gXQ,472
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
 dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
 dao_ai/cli.py,sha256=Aez2TQW3Q8Ho1IaIkRggt0NevDxAAVPjXkePC5GPJF0,20429
-dao_ai/config.py,sha256=…
-dao_ai/graph.py,sha256=…
+dao_ai/config.py,sha256=JlYC8N_7UL8VVkdSepiCUnR9NA5OsCVAigLjse7dMFM,49922
+dao_ai/graph.py,sha256=kXaGLGFVekDWqm-AHzti6LmrXnyi99VQ-AdCGuNb_xM,7831
 dao_ai/guardrails.py,sha256=-Qh0f_2Db9t4Nbrrx9FM7tnpqShjMoyxepZ0HByItfU,4027
-dao_ai/messages.py,sha256=…
-dao_ai/models.py,sha256=…
-dao_ai/nodes.py,sha256=…
+dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
+dao_ai/models.py,sha256=h_xFMK5FHQwPApEAYhvrt69y7ZUljmqThHTjp-yde_o,25368
+dao_ai/nodes.py,sha256=SSuFNTXOdFaKg_aX-yUkQO7fM9wvNGu14lPXKDapU1U,8461
 dao_ai/prompts.py,sha256=vpmIbWs_szXUgNNDs5Gh2LcxKZti5pHDKSfoClUcgX0,1289
 dao_ai/state.py,sha256=GwbMbd1TWZx1T5iQrEOX6_rpxOitlmyeJ8dMr2o_pag,1031
 dao_ai/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -22,19 +22,19 @@ dao_ai/memory/core.py,sha256=K45iCEFbqJCVxMi4m3vmBJi4c6TQ-UtKGzyugDTkPP0,4141
 dao_ai/memory/postgres.py,sha256=YILzA7xtqawPAOLFaGG_i17zW7cQxXTzTD8yd-ipe8k,12480
 dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
-dao_ai/providers/databricks.py,sha256=…
+dao_ai/providers/databricks.py,sha256=KLYrLccOA3Uws9nWJcJUZTbMz-MdR_onhlQeztbplCM,28073
 dao_ai/tools/__init__.py,sha256=ye6MHaJY7tUnJ8336YJiLxuZr55zDPNdOw6gm7j5jlc,1103
-dao_ai/tools/agent.py,sha256=…
+dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
 dao_ai/tools/core.py,sha256=Kei33S8vrmvPOAyrFNekaWmV2jqZ-IPS1QDSvU7RZF0,1984
-dao_ai/tools/genie.py,sha256=…
-dao_ai/tools/human_in_the_loop.py,sha256=…
-dao_ai/tools/mcp.py,sha256=…
+dao_ai/tools/genie.py,sha256=GzV5lfDYKmzW_lSLxAsPaTwnzX6GxQOB1UcLaTDqpfY,2787
+dao_ai/tools/human_in_the_loop.py,sha256=IBmQJmpxkdDxnBNyABc_-dZhhsQlTNTkPyUXgkHKIgY,3466
+dao_ai/tools/mcp.py,sha256=auEt_dwv4J26fr5AgLmwmnAsI894-cyuvkvjItzAUxs,4419
 dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
 dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
 dao_ai/tools/unity_catalog.py,sha256=PXfLj2EgyQgaXq4Qq3t25AmTC4KyVCF_-sCtg6enens,1404
 dao_ai/tools/vector_search.py,sha256=EDYQs51zIPaAP0ma1D81wJT77GQ-v-cjb2XrFVWfWdg,2621
-dao_ai-0.0.16.dist-info/…
-dao_ai-0.0.16.dist-info/…
-dao_ai-0.0.16.dist-info/…
-dao_ai-0.0.16.dist-info/…
-dao_ai-0.0.16.dist-info/…
+dao_ai-0.0.18.dist-info/METADATA,sha256=9lTAXjEqQHxl6dmRMyiqUnYT1Nh_wJpSeJXRG8bGZGg,41378
+dao_ai-0.0.18.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dao_ai-0.0.18.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.0.18.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.0.18.dist-info/RECORD,,
{dao_ai-0.0.16.dist-info → dao_ai-0.0.18.dist-info}/WHEEL
File without changes

{dao_ai-0.0.16.dist-info → dao_ai-0.0.18.dist-info}/entry_points.txt
File without changes

{dao_ai-0.0.16.dist-info → dao_ai-0.0.18.dist-info}/licenses/LICENSE
File without changes