dao-ai 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/agent_as_code.py +2 -2
- dao_ai/config.py +235 -45
- dao_ai/graph.py +4 -4
- dao_ai/memory/core.py +2 -2
- dao_ai/memory/postgres.py +134 -2
- dao_ai/messages.py +7 -0
- dao_ai/models.py +373 -24
- dao_ai/providers/databricks.py +11 -4
- {dao_ai-0.0.17.dist-info → dao_ai-0.0.19.dist-info}/METADATA +8 -8
- {dao_ai-0.0.17.dist-info → dao_ai-0.0.19.dist-info}/RECORD +13 -13
- {dao_ai-0.0.17.dist-info → dao_ai-0.0.19.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.17.dist-info → dao_ai-0.0.19.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.17.dist-info → dao_ai-0.0.19.dist-info}/licenses/LICENSE +0 -0
dao_ai/agent_as_code.py
CHANGED
@@ -3,7 +3,7 @@ import sys
 import mlflow
 from loguru import logger
 from mlflow.models import ModelConfig
-from mlflow.pyfunc import
+from mlflow.pyfunc import ResponsesAgent

 from dao_ai.config import AppConfig

@@ -17,6 +17,6 @@ log_level: str = config.app.log_level
 logger.remove()
 logger.add(sys.stderr, level=log_level)

-app:
+app: ResponsesAgent = config.as_responses_agent()

 mlflow.models.set_model(app)
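Below is a minimal editor-added sketch (not part of the diff) of how an agent-as-code entry point like this is typically logged and smoke-tested with MLflow's models-from-code flow; the model name, script path, and config path are illustrative assumptions.

```python
# Minimal sketch, assuming MLflow >= 3 models-from-code logging; the model
# name, script path, and YAML path below are illustrative, not from the diff.
import mlflow

with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        name="dao_ai_agent",                      # hypothetical model name
        python_model="dao_ai/agent_as_code.py",   # the entry point shown above
        model_config="config/app.yaml",           # hypothetical ModelConfig YAML
    )

# Reload the logged ResponsesAgent and send one Responses-style request.
agent = mlflow.pyfunc.load_model(model_info.model_uri)
print(agent.predict({"input": [{"role": "user", "content": "hello"}]}))
```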
dao_ai/config.py
CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path
 from typing import (
     Any,
     Callable,
+    Iterator,
     Literal,
     Optional,
     Sequence,
@@ -21,6 +22,7 @@ from databricks.sdk.credentials_provider import (
     CredentialsStrategy,
     ModelServingUserCredentials,
 )
+from databricks.sdk.service.catalog import FunctionInfo, TableInfo
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
@@ -37,6 +39,7 @@ from mlflow.models import ModelConfig
 from mlflow.models.resources import (
     DatabricksFunction,
     DatabricksGenieSpace,
+    DatabricksLakebase,
     DatabricksResource,
     DatabricksServingEndpoint,
     DatabricksSQLWarehouse,
@@ -44,8 +47,14 @@ from mlflow.models.resources import (
     DatabricksUCConnection,
     DatabricksVectorSearchIndex,
 )
-from mlflow.pyfunc import ChatModel
-from pydantic import
+from mlflow.pyfunc import ChatModel, ResponsesAgent
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    field_serializer,
+    model_validator,
+)


 class HasValue(ABC):
@@ -69,7 +78,7 @@ class IsDatabricksResource(ABC):
     on_behalf_of_user: Optional[bool] = False

     @abstractmethod
-    def
+    def as_resources(self) -> Sequence[DatabricksResource]: ...

     @property
     @abstractmethod
@@ -235,22 +244,70 @@ class SchemaModel(BaseModel, HasFullName):
 class TableModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    name: str
+    name: Optional[str] = None
+
+    @model_validator(mode="after")
+    def validate_name_or_schema_required(self) -> "TableModel":
+        if not self.name and not self.schema_model:
+            raise ValueError(
+                "Either 'name' or 'schema_model' must be provided for TableModel"
+            )
+        return self

     @property
     def full_name(self) -> str:
         if self.schema_model:
-
+            name: str = ""
+            if self.name:
+                name = f".{self.name}"
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
         return self.name

     @property
     def api_scopes(self) -> Sequence[str]:
         return []

-    def
-
-
-
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        resources: list[DatabricksResource] = []
+
+        excluded_suffixes: Sequence[str] = [
+            "_payload",
+            "_assessment_logs",
+            "_request_logs",
+        ]
+
+        excluded_prefixes: Sequence[str] = ["trace_logs_"]
+
+        if self.name:
+            resources.append(
+                DatabricksTable(
+                    table_name=self.full_name, on_behalf_of_user=self.on_behalf_of_user
+                )
+            )
+        else:
+            w: WorkspaceClient = self.workspace_client
+            schema_full_name: str = self.schema_model.full_name
+            tables: Iterator[TableInfo] = w.tables.list(
+                catalog_name=self.schema_model.catalog_name,
+                schema_name=self.schema_model.schema_name,
+            )
+            resources.extend(
+                [
+                    DatabricksTable(
+                        table_name=f"{schema_full_name}.{table.name}",
+                        on_behalf_of_user=self.on_behalf_of_user,
+                    )
+                    for table in tables
+                    if not any(
+                        table.name.endswith(suffix) for suffix in excluded_suffixes
+                    )
+                    and not any(
+                        table.name.startswith(prefix) for prefix in excluded_prefixes
+                    )
+                ]
+            )
+
+        return resources


 class LLMModel(BaseModel, IsDatabricksResource):
@@ -266,10 +323,12 @@ class LLMModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ]

-    def
-        return
-
-
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksServingEndpoint(
+                endpoint_name=self.name, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]

     def as_chat_model(self) -> LanguageModelLike:
         # Retrieve langchain chat client from workspace client to enable OBO
@@ -345,17 +404,19 @@ class IndexModel(BaseModel, HasFullName, IsDatabricksResource):
             return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}.{self.name}"
         return self.name

-    def
-        return
-
-
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksVectorSearchIndex(
+                index_name=self.full_name, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]


 class GenieRoomModel(BaseModel, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
     description: Optional[str] = None
-    space_id:
+    space_id: AnyVariable

     @property
     def api_scopes(self) -> Sequence[str]:
@@ -363,10 +424,17 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
         "dashboards.genie",
     ]

-    def
-        return
-
-
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksGenieSpace(
+                genie_space_id=self.space_id, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]
+
+    @model_validator(mode="after")
+    def update_space_id(self):
+        self.space_id = value_of(self.space_id)
+        return self


 class VolumeModel(BaseModel, HasFullName):
@@ -394,7 +462,7 @@ class VolumePathModel(BaseModel, HasFullName):
     path: Optional[str] = None

     @model_validator(mode="after")
-    def validate_path_or_volume(self):
+    def validate_path_or_volume(self) -> "VolumePathModel":
         if not self.volume and not self.path:
             raise ValueError("Either 'volume' or 'path' must be provided")
         return self
@@ -502,8 +570,8 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
         "serving.serving-endpoints",
     ] + self.index.api_scopes

-    def
-        return self.index.
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return self.index.as_resources()

     def as_index(self, vsc: VectorSearchClient | None = None) -> VectorSearchIndex:
         from dao_ai.providers.base import ServiceProvider
@@ -524,18 +592,52 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
 class FunctionModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict()
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
-    name: str
+    name: Optional[str] = None
+
+    @model_validator(mode="after")
+    def validate_name_or_schema_required(self) -> "FunctionModel":
+        if not self.name and not self.schema_model:
+            raise ValueError(
+                "Either 'name' or 'schema_model' must be provided for FunctionModel"
+            )
+        return self

     @property
     def full_name(self) -> str:
         if self.schema_model:
-
+            name: str = ""
+            if self.name:
+                name = f".{self.name}"
+            return f"{self.schema_model.catalog_name}.{self.schema_model.schema_name}{name}"
         return self.name

-    def
-
-
-
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        resources: list[DatabricksResource] = []
+        if self.name:
+            resources.append(
+                DatabricksFunction(
+                    function_name=self.full_name,
+                    on_behalf_of_user=self.on_behalf_of_user,
+                )
+            )
+        else:
+            w: WorkspaceClient = self.workspace_client
+            schema_full_name: str = self.schema_model.full_name
+            functions: Iterator[FunctionInfo] = w.functions.list(
+                catalog_name=self.schema_model.catalog_name,
+                schema_name=self.schema_model.schema_name,
+            )
+            resources.extend(
+                [
+                    DatabricksFunction(
+                        function_name=f"{schema_full_name}.{function.name}",
+                        on_behalf_of_user=self.on_behalf_of_user,
+                    )
+                    for function in functions
+                ]
+            )
+
+        return resources

     @property
     def api_scopes(self) -> Sequence[str]:
@@ -554,19 +656,22 @@ class ConnectionModel(BaseModel, HasFullName, IsDatabricksResource):
     def api_scopes(self) -> Sequence[str]:
         return [
             "catalog.connections",
+            "serving.serving-endpoints",
         ]

-    def
-        return
-
-
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksUCConnection(
+                connection_name=self.name, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]


 class WarehouseModel(BaseModel, IsDatabricksResource):
     model_config = ConfigDict()
     name: str
     description: Optional[str] = None
-    warehouse_id:
+    warehouse_id: AnyVariable

     @property
     def api_scopes(self) -> Sequence[str]:
@@ -575,13 +680,20 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
         "sql.statement-execution",
     ]

-    def
-        return
-
-
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksSQLWarehouse(
+                warehouse_id=self.warehouse_id, on_behalf_of_user=self.on_behalf_of_user
+            )
+        ]
+
+    @model_validator(mode="after")
+    def update_warehouse_id(self):
+        self.warehouse_id = value_of(self.warehouse_id)
+        return self


-class DatabaseModel(BaseModel):
+class DatabaseModel(BaseModel, IsDatabricksResource):
     model_config = ConfigDict(frozen=True)
     name: str
     description: Optional[str] = None
@@ -597,6 +709,18 @@ class DatabaseModel(BaseModel):
     client_secret: Optional[AnyVariable] = None
     workspace_host: Optional[AnyVariable] = None

+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return []
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return [
+            DatabricksLakebase(
+                database_instance_name=self.name,
+                on_behalf_of_user=self.on_behalf_of_user,
+            )
+        ]
+
     @model_validator(mode="after")
     def validate_auth_methods(self):
         oauth_fields: Sequence[Any] = [
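For orientation, a rough editor-added sketch of what the now-optional `name` on `TableModel` enables. This is illustration, not package code: it assumes `SchemaModel` can be constructed directly from the `catalog_name`/`schema_name` fields that `full_name` reads, and the catalog, schema, and table names are made up.

```python
# Editor sketch, not from the diff. Assumes SchemaModel(catalog_name=..., schema_name=...)
# is a valid construction; "main"/"retail"/"orders" are placeholder names.
from dao_ai.config import SchemaModel, TableModel

schema = SchemaModel(catalog_name="main", schema_name="retail")

# Explicitly named table: exactly one DatabricksTable resource, as before.
orders = TableModel(schema=schema, name="orders")
print([r.name for r in orders.as_resources()])  # expect ['main.retail.orders']

# Name omitted (new in 0.0.19): the workspace client lists every table in
# main.retail and emits one DatabricksTable per table, skipping *_payload,
# *_assessment_logs, *_request_logs and trace_logs_* tables. Requires
# Databricks credentials, since it calls w.tables.list(...).
whole_schema = TableModel(schema=schema)
resources = whole_schema.as_resources()
```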
@@ -1034,9 +1158,37 @@ class Message(BaseModel):

 class ChatPayload(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-
+    input: Optional[list[Message]] = None
+    messages: Optional[list[Message]] = None
     custom_inputs: dict

+    @model_validator(mode="after")
+    def validate_mutual_exclusion_and_alias(self) -> "ChatPayload":
+        """Handle dual field support with automatic aliasing."""
+        # If both fields are provided and they're the same, that's okay (redundant but valid)
+        if self.input is not None and self.messages is not None:
+            # Allow if they're identical (redundant specification)
+            if self.input == self.messages:
+                return self
+            # If they're different, prefer input and copy to messages
+            else:
+                self.messages = self.input
+                return self
+
+        # If neither field is provided, that's an error
+        if self.input is None and self.messages is None:
+            raise ValueError("Must specify either 'input' or 'messages' field.")
+
+        # Create alias: copy messages to input if input is None
+        if self.input is None and self.messages is not None:
+            self.input = self.messages
+
+        # Create alias: copy input to messages if messages is None
+        elif self.messages is None and self.input is not None:
+            self.messages = self.input
+
+        return self
+

 class ChatHistoryModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -1064,13 +1216,13 @@ class AppModel(BaseModel):
     endpoint_name: Optional[str] = None
     tags: Optional[dict[str, Any]] = Field(default_factory=dict)
     scale_to_zero: Optional[bool] = True
-    environment_vars: Optional[dict[str,
+    environment_vars: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
     budget_policy_id: Optional[str] = None
     workload_size: Optional[WorkloadSize] = "Small"
     permissions: Optional[list[AppPermissionModel]] = Field(default_factory=list)
     agents: list[AgentModel] = Field(default_factory=list)

-    orchestration: OrchestrationModel
+    orchestration: Optional[OrchestrationModel] = None
     alias: Optional[str] = None
     initialization_hooks: Optional[FunctionHook | list[FunctionHook]] = Field(
         default_factory=list
@@ -1089,7 +1241,38 @@ class AppModel(BaseModel):
     @model_validator(mode="after")
     def validate_agents_not_empty(self):
         if not self.agents:
-            raise ValueError("
+            raise ValueError("At least one agent must be specified")
+        return self
+
+    @model_validator(mode="after")
+    def update_environment_vars(self):
+        for key, value in self.environment_vars.items():
+            if isinstance(value, SecretVariableModel):
+                updated_value = str(value)
+            else:
+                updated_value = value_of(value)
+
+            self.environment_vars[key] = updated_value
+        return self
+
+    @model_validator(mode="after")
+    def set_default_orchestration(self):
+        if self.orchestration is None:
+            if len(self.agents) > 1:
+                default_agent: AgentModel = self.agents[0]
+                self.orchestration = OrchestrationModel(
+                    swarm=SupervisorModel(model=default_agent.model)
+                )
+            elif len(self.agents) == 1:
+                default_agent: AgentModel = self.agents[0]
+                self.orchestration = OrchestrationModel(
+                    supervisor=SwarmModel(
+                        model=default_agent.model, default_agent=default_agent
+                    )
+                )
+            else:
+                raise ValueError("At least one agent must be specified")
+
         return self

     @model_validator(mode="after")
@@ -1374,3 +1557,10 @@ class AppConfig(BaseModel):
         graph: CompiledStateGraph = self.as_graph()
         app: ChatModel = create_agent(graph)
         return app
+
+    def as_responses_agent(self) -> ResponsesAgent:
+        from dao_ai.models import create_responses_agent
+
+        graph: CompiledStateGraph = self.as_graph()
+        app: ResponsesAgent = create_responses_agent(graph)
+        return app
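A small editor-added sketch of the dual `input`/`messages` aliasing added to `ChatPayload` above. It assumes the `Message` config model can be coerced from `{"role", "content"}` dicts, which the diff itself does not show; the message text is made up.

```python
# Editor sketch, not from the diff. Field names come from ChatPayload above;
# the assumption is that Message coerces from {"role", "content"} dicts.
from dao_ai.config import ChatPayload

p1 = ChatPayload(messages=[{"role": "user", "content": "hi"}], custom_inputs={})
assert p1.input == p1.messages     # 'messages' is mirrored into 'input'

p2 = ChatPayload(input=[{"role": "user", "content": "hi"}], custom_inputs={})
assert p2.messages == p2.input     # and 'input' is mirrored into 'messages'

ChatPayload(custom_inputs={})      # raises: must specify 'input' or 'messages'
```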
dao_ai/graph.py
CHANGED
@@ -136,8 +136,8 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:

     workflow: StateGraph = StateGraph(
         SharedState,
-
-
+        input_schema=IncomingState,
+        output_schema=OutgoingState,
         context_schema=Context,
     )

@@ -200,8 +200,8 @@ def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:

     workflow: StateGraph = StateGraph(
         SharedState,
-
-
+        input_schema=IncomingState,
+        output_schema=OutgoingState,
         context_schema=Context,
     )
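The move to explicit `input_schema`/`output_schema` keywords gives the compiled graph a narrower public interface than its internal state. A generic, editor-added illustration of the same LangGraph pattern follows; the TypedDicts below are stand-ins, not dao-ai's `IncomingState`/`SharedState`/`OutgoingState`.

```python
# Generic sketch of separate input/output schemas on a StateGraph
# (keyword names as used in the diff; the state classes are illustrative).
from typing import TypedDict

from langgraph.graph import StateGraph


class InputState(TypedDict):
    question: str


class InternalState(TypedDict):
    question: str
    scratch: str
    answer: str


class OutputState(TypedDict):
    answer: str


def answer(state: InternalState) -> dict:
    # 'scratch' stays internal; callers only ever see OutputState keys.
    return {"scratch": "working notes", "answer": state["question"].upper()}


workflow = StateGraph(
    InternalState,
    input_schema=InputState,
    output_schema=OutputState,
)
workflow.add_node("answer", answer)
workflow.set_entry_point("answer")
workflow.set_finish_point("answer")
graph = workflow.compile()

print(graph.invoke({"question": "hello"}))  # -> {'answer': 'HELLO'}
```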
dao_ai/memory/core.py
CHANGED
@@ -99,13 +99,13 @@ class CheckpointManager:
                     checkpointer_manager
                 )
             case StorageType.POSTGRES:
-                from dao_ai.memory.postgres import
+                from dao_ai.memory.postgres import AsyncPostgresCheckpointerManager

                 checkpointer_manager = cls.checkpoint_managers.get(
                     checkpointer_model.database.name
                 )
                 if checkpointer_manager is None:
                     checkpointer_manager = AsyncPostgresCheckpointerManager(
                         checkpointer_model
                     )
                     cls.checkpoint_managers[checkpointer_model.database.name] = (
dao_ai/memory/postgres.py
CHANGED
@@ -20,6 +20,137 @@ from dao_ai.memory.base import (
 )


+class PatchedAsyncPostgresStore(AsyncPostgresStore):
+    """
+    Patched version of AsyncPostgresStore that properly handles event loop initialization
+    and task lifecycle management.
+
+    The issues occur because:
+    1. AsyncBatchedBaseStore.__init__ calls asyncio.get_running_loop() and fails if no event loop is running
+    2. The background _task can complete/fail, causing assertions in asearch/other methods to fail
+    3. Destructor tries to access _task even when it doesn't exist
+
+    This patch ensures proper initialization and handles task lifecycle robustly.
+    """
+
+    def __init__(self, *args, **kwargs):
+        # Ensure we have a running event loop before calling super().__init__()
+        loop = None
+        try:
+            loop = asyncio.get_running_loop()
+        except RuntimeError:
+            # No running loop - create one temporarily for initialization
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+        try:
+            super().__init__(*args, **kwargs)
+        except Exception as e:
+            # If parent initialization fails, ensure _task is at least defined
+            if not hasattr(self, "_task"):
+                self._task = None
+            logger.warning(f"AsyncPostgresStore initialization failed: {e}")
+            raise
+
+    def _ensure_task_running(self):
+        """
+        Ensure the background task is running. Recreate it if necessary.
+        """
+        if not hasattr(self, "_task") or self._task is None:
+            logger.error("AsyncPostgresStore task not initialized")
+            raise RuntimeError("Store task not properly initialized")
+
+        if self._task.done():
+            logger.warning(
+                "AsyncPostgresStore background task completed, attempting to restart"
+            )
+            # Try to get the task exception for debugging
+            try:
+                exception = self._task.exception()
+                if exception:
+                    logger.error(f"Background task failed with: {exception}")
+                else:
+                    logger.info("Background task completed normally")
+            except Exception as e:
+                logger.warning(f"Could not determine task completion reason: {e}")
+
+            # Try to restart the task
+            try:
+                import weakref
+
+                from langgraph.store.base.batch import _run
+
+                self._task = self._loop.create_task(
+                    _run(self._aqueue, weakref.ref(self))
+                )
+                logger.info("Successfully restarted AsyncPostgresStore background task")
+            except Exception as e:
+                logger.error(f"Failed to restart background task: {e}")
+                raise RuntimeError(
+                    f"Store background task failed and could not be restarted: {e}"
+                )
+
+    async def asearch(
+        self,
+        namespace_prefix,
+        /,
+        *,
+        query=None,
+        filter=None,
+        limit=10,
+        offset=0,
+        refresh_ttl=None,
+    ):
+        """
+        Override asearch to handle task lifecycle issues gracefully.
+        """
+        self._ensure_task_running()
+
+        # Call parent implementation if task is healthy
+        return await super().asearch(
+            namespace_prefix,
+            query=query,
+            filter=filter,
+            limit=limit,
+            offset=offset,
+            refresh_ttl=refresh_ttl,
+        )
+
+    async def aget(self, namespace, key, /, *, refresh_ttl=None):
+        """Override aget with task lifecycle management."""
+        self._ensure_task_running()
+        return await super().aget(namespace, key, refresh_ttl=refresh_ttl)
+
+    async def aput(self, namespace, key, value, /, *, refresh_ttl=None):
+        """Override aput with task lifecycle management."""
+        self._ensure_task_running()
+        return await super().aput(namespace, key, value, refresh_ttl=refresh_ttl)
+
+    async def adelete(self, namespace, key):
+        """Override adelete with task lifecycle management."""
+        self._ensure_task_running()
+        return await super().adelete(namespace, key)
+
+    async def alist_namespaces(self, *, prefix=None):
+        """Override alist_namespaces with task lifecycle management."""
+        self._ensure_task_running()
+        return await super().alist_namespaces(prefix=prefix)
+
+    def __del__(self):
+        """
+        Override destructor to handle missing _task attribute gracefully.
+        """
+        try:
+            # Only try to cancel if _task exists and is not None
+            if hasattr(self, "_task") and self._task is not None:
+                if not self._task.done():
+                    self._task.cancel()
+        except Exception as e:
+            # Log but don't raise - destructors should not raise exceptions
+            logger.debug(f"AsyncPostgresStore destructor cleanup: {e}")
+            pass
+
+
 class AsyncPostgresPoolManager:
     _pools: dict[str, AsyncConnectionPool] = {}
     _lock: asyncio.Lock = asyncio.Lock()

@@ -119,8 +250,9 @@ class AsyncPostgresStoreManager(StoreManagerBase):
             self.store_model.database
         )

-        # Create store with the shared pool
-        self._store =
+        # Create store with the shared pool (using patched version)
+        self._store = PatchedAsyncPostgresStore(conn=self.pool)
+
         await self._store.setup()

         self._setup_complete = True
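The docstring above describes the core trick: before every batched call, check whether the background task is still alive and recreate it if not. Below is a self-contained, editor-added sketch of that same guard pattern in plain asyncio, with no langgraph internals.

```python
# Editor sketch of the "restart a dead background worker" guard in generic
# asyncio; it mirrors _ensure_task_running() without touching langgraph APIs.
import asyncio


class BatchedClient:
    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()
        self._loop = asyncio.get_running_loop()
        self._task = self._loop.create_task(self._worker())

    async def _worker(self) -> None:
        while True:
            item, fut = await self._queue.get()
            fut.set_result(f"processed {item}")

    def _ensure_task_running(self) -> None:
        # If the worker died (cancelled, crashed, or finished), restart it
        # instead of letting later calls hang or fail an assertion.
        if self._task is None or self._task.done():
            self._task = self._loop.create_task(self._worker())

    async def submit(self, item: str) -> str:
        self._ensure_task_running()
        fut: asyncio.Future = self._loop.create_future()
        await self._queue.put((item, fut))
        return await fut


async def main() -> None:
    client = BatchedClient()
    client._task.cancel()             # simulate the background task dying
    try:
        await client._task            # wait until the cancellation lands
    except asyncio.CancelledError:
        pass
    print(await client.submit("x"))   # the guard restarts the worker


asyncio.run(main())
```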
dao_ai/messages.py
CHANGED
@@ -11,6 +11,9 @@ from langchain_core.messages import (
 )
 from langchain_core.messages.modifier import RemoveMessage
 from mlflow.types.llm import ChatMessage
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+)


 def remove_messages(
@@ -96,6 +99,10 @@ def has_mlflow_messages(messages: ChatMessage | Sequence[ChatMessage]) -> bool:
     return any(isinstance(m, ChatMessage) for m in messages)


+def has_mlflow_responses_messages(messages: ResponsesAgentRequest) -> bool:
+    return isinstance(messages, ResponsesAgentRequest)
+
+
 def has_image(messages: BaseMessage | Sequence[BaseMessage]) -> bool:
     """
     Check if a message contains an image.
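A tiny editor-added example of the new type guard, which is what lets `process_messages()` in dao_ai/models.py route Responses-style payloads separately; the request construction follows the `ResponsesAgentRequest(input=[...])` shape used elsewhere in this diff.

```python
# Editor sketch; assumes ResponsesAgentRequest coerces role/content dicts,
# as in the ResponsesAgentRequest(input=...) calls elsewhere in this diff.
from mlflow.types.responses import ResponsesAgentRequest

from dao_ai.messages import has_mlflow_responses_messages

request = ResponsesAgentRequest(input=[{"role": "user", "content": "hi"}])
assert has_mlflow_responses_messages(request)
assert not has_mlflow_responses_messages([{"role": "user", "content": "hi"}])
```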
dao_ai/models.py
CHANGED
@@ -1,13 +1,14 @@
 import uuid
 from os import PathLike
 from pathlib import Path
-from typing import Any, Generator, Optional, Sequence
+from typing import Any, Generator, Optional, Sequence, Union

 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langgraph.graph.state import CompiledStateGraph
 from loguru import logger
 from mlflow import MlflowClient
-from mlflow.pyfunc import ChatAgent, ChatModel
+from mlflow.pyfunc import ChatAgent, ChatModel, ResponsesAgent
+from mlflow.types.agent import ChatContext
 from mlflow.types.llm import (
     ChatChoice,
     ChatChoiceDelta,
@@ -17,8 +18,21 @@ from mlflow.types.llm import (
     ChatMessage,
     ChatParams,
 )
+from mlflow.types.responses import (
+    ResponsesAgentRequest,
+    ResponsesAgentResponse,
+    ResponsesAgentStreamEvent,
+)
+from mlflow.types.responses_helpers import (
+    Message,
+    ResponseInputTextParam,
+)

-from dao_ai.messages import
+from dao_ai.messages import (
+    has_langchain_messages,
+    has_mlflow_messages,
+    has_mlflow_responses_messages,
+)
 from dao_ai.state import Context


@@ -185,6 +199,266 @@ class LanggraphChatModel(ChatModel):
         return [m.to_dict() for m in messages]


+class LanggraphResponsesAgent(ResponsesAgent):
+    """
+    ResponsesAgent that delegates requests to a LangGraph CompiledStateGraph.
+
+    This is the modern replacement for LanggraphChatModel, providing better
+    support for streaming, tool calling, and async execution.
+    """
+
+    def __init__(self, graph: CompiledStateGraph) -> None:
+        self.graph = graph
+
+    def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:
+        """
+        Process a ResponsesAgentRequest and return a ResponsesAgentResponse.
+        """
+        logger.debug(f"ResponsesAgent request: {request}")
+
+        # Convert ResponsesAgent input to LangChain messages
+        messages = self._convert_request_to_langchain_messages(request)
+
+        # Prepare context
+        context: Context = self._convert_request_to_context(request)
+        custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
+
+        # Use async ainvoke internally for parallel execution
+        import asyncio
+
+        async def _async_invoke():
+            try:
+                return await self.graph.ainvoke(
+                    {"messages": messages}, context=context, config=custom_inputs
+                )
+            except Exception as e:
+                logger.error(f"Error in graph.ainvoke: {e}")
+                raise
+
+        try:
+            loop = asyncio.get_event_loop()
+        except RuntimeError:
+            # Handle case where no event loop exists (common in some deployment scenarios)
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+        try:
+            response: dict[str, Sequence[BaseMessage]] = loop.run_until_complete(
+                _async_invoke()
+            )
+        except Exception as e:
+            logger.error(f"Error in async execution: {e}")
+            raise
+
+        # Convert response to ResponsesAgent format
+        last_message: BaseMessage = response["messages"][-1]
+
+        output_item = self.create_text_output_item(
+            text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
+        )
+
+        custom_outputs = custom_inputs
+        return ResponsesAgentResponse(
+            output=[output_item], custom_outputs=custom_outputs
+        )
+
+    def predict_stream(
+        self, request: ResponsesAgentRequest
+    ) -> Generator[ResponsesAgentStreamEvent, None, None]:
+        """
+        Process a ResponsesAgentRequest and yield ResponsesAgentStreamEvent objects.
+        """
+        logger.debug(f"ResponsesAgent stream request: {request}")
+
+        # Convert ResponsesAgent input to LangChain messages
+        messages: list[BaseMessage] = self._convert_request_to_langchain_messages(
+            request
+        )
+
+        # Prepare context
+        context: Context = self._convert_request_to_context(request)
+        custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
+
+        # Use async astream internally for parallel execution
+        import asyncio
+
+        async def _async_stream():
+            item_id = f"msg_{uuid.uuid4().hex[:8]}"
+            accumulated_content = ""
+
+            try:
+                async for nodes, stream_mode, messages_batch in self.graph.astream(
+                    {"messages": messages},
+                    context=context,
+                    config=custom_inputs,
+                    stream_mode=["messages", "custom"],
+                    subgraphs=True,
+                ):
+                    nodes: tuple[str, ...]
+                    stream_mode: str
+                    messages_batch: Sequence[BaseMessage]
+
+                    for message in messages_batch:
+                        if (
+                            isinstance(
+                                message,
+                                (
+                                    AIMessageChunk,
+                                    AIMessage,
+                                ),
+                            )
+                            and message.content
+                            and "summarization" not in nodes
+                        ):
+                            content = message.content
+                            accumulated_content += content
+
+                            # Yield streaming delta
+                            yield ResponsesAgentStreamEvent(
+                                **self.create_text_delta(delta=content, item_id=item_id)
+                            )
+
+                custom_outputs = custom_inputs
+                # Yield final output item
+                yield ResponsesAgentStreamEvent(
+                    type="response.output_item.done",
+                    item=self.create_text_output_item(
+                        text=accumulated_content, id=item_id
+                    ),
+                    custom_outputs=custom_outputs,
+                )
+            except Exception as e:
+                logger.error(f"Error in graph.astream: {e}")
+                raise
+
+        # Convert async generator to sync generator
+        try:
+            loop = asyncio.get_event_loop()
+        except RuntimeError:
+            # Handle case where no event loop exists (common in some deployment scenarios)
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+        async_gen = _async_stream()
+
+        try:
+            while True:
+                try:
+                    item = loop.run_until_complete(async_gen.__anext__())
+                    yield item
+                except StopAsyncIteration:
+                    break
+        except Exception as e:
+            logger.error(f"Error in streaming: {e}")
+            raise
+        finally:
+            try:
+                loop.run_until_complete(async_gen.aclose())
+            except Exception as e:
+                logger.warning(f"Error closing async generator: {e}")
+
+    def _extract_text_from_content(
+        self,
+        content: Union[str, list[Union[ResponseInputTextParam, str, dict[str, Any]]]],
+    ) -> str:
+        """Extract text content from various MLflow content formats.
+
+        MLflow ResponsesAgent supports multiple content formats:
+        - str: Simple text content
+        - list[ResponseInputTextParam]: Structured text objects with .text attribute
+        - list[dict]: Dictionaries with "text" key
+        - Mixed lists of the above types
+
+        This method normalizes all formats to a single concatenated string.
+
+        Args:
+            content: The content to extract text from
+
+        Returns:
+            Concatenated text string from all content items
+        """
+        if isinstance(content, str):
+            return content
+        elif isinstance(content, list):
+            text_parts = []
+            for content_item in content:
+                if isinstance(content_item, ResponseInputTextParam):
+                    text_parts.append(content_item.text)
+                elif isinstance(content_item, str):
+                    text_parts.append(content_item)
+                elif isinstance(content_item, dict) and "text" in content_item:
+                    text_parts.append(content_item["text"])
+            return "".join(text_parts)
+        else:
+            # Fallback for unknown types - try to extract text attribute
+            return getattr(content, "text", str(content))
+
+    def _convert_request_to_langchain_messages(
+        self, request: ResponsesAgentRequest
+    ) -> list[dict[str, Any]]:
+        """Convert ResponsesAgent input to LangChain message format."""
+        messages = []
+
+        for input_item in request.input:
+            if isinstance(input_item, Message):
+                # Handle MLflow Message objects
+                content = self._extract_text_from_content(input_item.content)
+                messages.append({"role": input_item.role, "content": content})
+            elif isinstance(input_item, dict):
+                # Handle dict format
+                if "role" in input_item and "content" in input_item:
+                    content = self._extract_text_from_content(input_item["content"])
+                    messages.append({"role": input_item["role"], "content": content})
+            else:
+                # Fallback for other object types with role/content attributes
+                role = getattr(input_item, "role", "user")
+                content = self._extract_text_from_content(
+                    getattr(input_item, "content", "")
+                )
+                messages.append({"role": role, "content": content})
+
+        return messages
+
+    def _convert_request_to_context(self, request: ResponsesAgentRequest) -> Context:
+        """Convert ResponsesAgent context to internal Context."""
+
+        logger.debug(f"request.context: {request.context}")
+        logger.debug(f"request.custom_inputs: {request.custom_inputs}")
+
+        configurable: dict[str, Any] = {}
+
+        # Process context values first (lower priority)
+        # Use strong typing with forward-declared type hints instead of hasattr checks
+        chat_context: Optional[ChatContext] = request.context
+        if chat_context is not None:
+            conversation_id: Optional[str] = chat_context.conversation_id
+            user_id: Optional[str] = chat_context.user_id
+
+            if conversation_id is not None:
+                configurable["conversation_id"] = conversation_id
+                configurable["thread_id"] = conversation_id
+
+            if user_id is not None:
+                configurable["user_id"] = user_id
+
+        # Process custom_inputs after context so they can override context values (higher priority)
+        if request.custom_inputs:
+            if "configurable" in request.custom_inputs:
+                configurable.update(request.custom_inputs.pop("configurable"))
+
+            configurable.update(request.custom_inputs)
+
+        if "user_id" in configurable:
+            configurable["user_id"] = configurable["user_id"].replace(".", "_")
+
+        if "thread_id" not in configurable:
+            configurable["thread_id"] = str(uuid.uuid4())
+
+        logger.debug(f"Creating context from: {configurable}")
+
+        return Context(**configurable)
+
+
 def create_agent(graph: CompiledStateGraph) -> ChatAgent:
     """
     Create an MLflow-compatible ChatAgent from a LangGraph state machine.
@@ -201,6 +475,22 @@ def create_agent(graph: CompiledStateGraph) -> ChatAgent:
     return LanggraphChatModel(graph)


+def create_responses_agent(graph: CompiledStateGraph) -> ResponsesAgent:
+    """
+    Create an MLflow-compatible ResponsesAgent from a LangGraph state machine.
+
+    Factory function that wraps a compiled LangGraph in the LanggraphResponsesAgent
+    class to make it deployable through MLflow.
+
+    Args:
+        graph: A compiled LangGraph state machine
+
+    Returns:
+        An MLflow-compatible ResponsesAgent instance
+    """
+    return LanggraphResponsesAgent(graph)
+
+
 def _process_langchain_messages(
     app: LanggraphChatModel | CompiledStateGraph,
     messages: Sequence[BaseMessage],
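An editor-added sketch of using this factory directly (not from the diff): `graph` is assumed to be an already-compiled LangGraph, for example the result of `AppConfig.as_graph()`, and the user id is made up.

```python
# Editor sketch; assumes `graph` is a CompiledStateGraph (e.g. AppConfig.as_graph()).
from mlflow.types.responses import ResponsesAgentRequest

from dao_ai.models import create_responses_agent

agent = create_responses_agent(graph)

request = ResponsesAgentRequest(
    input=[{"role": "user", "content": "What stores are open right now?"}],
    custom_inputs={"configurable": {"user_id": "someone@example.com"}},
)

response = agent.predict(request)              # single batched response
for event in agent.predict_stream(request):    # text deltas, then a final item
    print(event)
```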
@@ -266,7 +556,14 @@ def _process_langchain_messages_stream(
         yield message

     # Convert async generator to sync generator
-
+
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:
+        # Handle case where no event loop exists (common in some deployment scenarios)
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
     async_gen = _async_stream()

     try:
@@ -288,6 +585,14 @@ def _process_mlflow_messages(
     return app.predict(None, messages, custom_inputs)


+def _process_mlflow_response_messages(
+    app: ResponsesAgent,
+    messages: ResponsesAgentRequest,
+) -> ResponsesAgentResponse:
+    """Process MLflow ResponsesAgent request in batch mode."""
+    return app.predict(messages)
+
+
 def _process_mlflow_messages_stream(
     app: ChatModel,
     messages: Sequence[ChatMessage],
@@ -298,32 +603,68 @@ def _process_mlflow_messages_stream(
         yield event


+def _process_mlflow_response_messages_stream(
+    app: ResponsesAgent,
+    messages: ResponsesAgentRequest,
+) -> Generator[ResponsesAgentStreamEvent, None, None]:
+    """Process MLflow ResponsesAgent request in streaming mode."""
+    for event in app.predict_stream(messages):
+        event: ResponsesAgentStreamEvent
+        yield event
+
+
 def _process_config_messages(
-    app:
+    app: LanggraphChatModel | LanggraphResponsesAgent,
     messages: dict[str, Any],
     custom_inputs: Optional[dict[str, Any]] = None,
-) -> ChatCompletionResponse:
-
-
-
+) -> ChatCompletionResponse | ResponsesAgentResponse:
+    if isinstance(app, LanggraphChatModel):
+        messages: Sequence[ChatMessage] = [ChatMessage(**m) for m in messages]
+        params: ChatParams = ChatParams(**{"custom_inputs": custom_inputs})
+        return _process_mlflow_messages(app, messages, params)
+
+    elif isinstance(app, LanggraphResponsesAgent):
+        input_messages: list[Message] = [Message(**m) for m in messages]
+        request = ResponsesAgentRequest(
+            input=input_messages, custom_inputs=custom_inputs
+        )
+        return _process_mlflow_response_messages(app, request)


 def _process_config_messages_stream(
-    app:
-
-
-
+    app: LanggraphChatModel | LanggraphResponsesAgent,
+    messages: dict[str, Any],
+    custom_inputs: dict[str, Any],
+) -> Generator[ChatCompletionChunk | ResponsesAgentStreamEvent, None, None]:
+    if isinstance(app, LanggraphChatModel):
+        messages: Sequence[ChatMessage] = [ChatMessage(**m) for m in messages]
+        params: ChatParams = ChatParams(**{"custom_inputs": custom_inputs})

-
-
+        for event in _process_mlflow_messages_stream(
+            app, messages, custom_inputs=params
+        ):
+            yield event
+
+    elif isinstance(app, LanggraphResponsesAgent):
+        input_messages: list[Message] = [Message(**m) for m in messages]
+        request = ResponsesAgentRequest(
+            input=input_messages, custom_inputs=custom_inputs
+        )
+
+        for event in _process_mlflow_response_messages_stream(app, request):
+            yield event


 def process_messages_stream(
-    app: LanggraphChatModel,
-    messages: Sequence[BaseMessage]
+    app: LanggraphChatModel | LanggraphResponsesAgent,
+    messages: Sequence[BaseMessage]
+    | Sequence[ChatMessage]
+    | ResponsesAgentRequest
+    | dict[str, Any],
     custom_inputs: Optional[dict[str, Any]] = None,
-) -> Generator[
+) -> Generator[
+    ChatCompletionChunk | ResponsesAgentStreamEvent | AIMessageChunk, None, None
+]:
     """
     Process messages through a ChatAgent in streaming mode.

@@ -338,7 +679,10 @@ def process_messages_stream(
         Individual message chunks from the streaming response
     """

-    if
+    if has_mlflow_responses_messages(messages):
+        for event in _process_mlflow_response_messages_stream(app, messages):
+            yield event
+    elif has_mlflow_messages(messages):
         for event in _process_mlflow_messages_stream(app, messages, custom_inputs):
             yield event
     elif has_langchain_messages(messages):
@@ -350,10 +694,13 @@ def process_messages_stream(


 def process_messages(
-    app: LanggraphChatModel,
-    messages: Sequence[BaseMessage]
+    app: LanggraphChatModel | LanggraphResponsesAgent,
+    messages: Sequence[BaseMessage]
+    | Sequence[ChatMessage]
+    | ResponsesAgentRequest
+    | dict[str, Any],
     custom_inputs: Optional[dict[str, Any]] = None,
-) -> ChatCompletionResponse | dict[str, Any] | Any:
+) -> ChatCompletionResponse | ResponsesAgentResponse | dict[str, Any] | Any:
     """
     Process messages through a ChatAgent in batch mode.

@@ -368,7 +715,9 @@ def process_messages(
         Complete response from the agent
     """

-    if
+    if has_mlflow_responses_messages(messages):
+        return _process_mlflow_response_messages(app, messages)
+    elif has_mlflow_messages(messages):
         return _process_mlflow_messages(app, messages, custom_inputs)
     elif has_langchain_messages(messages):
         return _process_langchain_messages(app, messages, custom_inputs)
dao_ai/providers/databricks.py
CHANGED
@@ -42,6 +42,7 @@ import dao_ai
 from dao_ai.config import (
     AppConfig,
     ConnectionModel,
+    DatabaseModel,
     DatasetModel,
     FunctionModel,
     GenieRoomModel,
@@ -224,6 +225,7 @@ class DatabricksProvider(ServiceProvider):
         connections: Sequence[ConnectionModel] = list(
             config.resources.connections.values()
         )
+        databases: Sequence[DatabaseModel] = list(config.resources.databases.values())

         resources: Sequence[IsDatabricksResource] = (
             llms
@@ -233,14 +235,19 @@ class DatabricksProvider(ServiceProvider):
             + functions
             + tables
             + connections
+            + databases
         )

-        #
-
-
+        # Flatten all resources from all models into a single list
+        all_resources: list[DatabricksResource] = []
+        for r in resources:
+            all_resources.extend(r.as_resources())

         system_resources: Sequence[DatabricksResource] = [
-
+            resource
+            for r in resources
+            for resource in r.as_resources()
+            if not r.on_behalf_of_user
         ]
         logger.debug(f"system_resources: {[r.name for r in system_resources]}")
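A self-contained, editor-added illustration of the flatten-and-filter pattern used above: every config model contributes zero or more resources via `as_resources()`, and only models not marked as on-behalf-of-user end up in the system resources. The stand-in classes below are not dao-ai's real models.

```python
# Editor sketch with stand-in classes; mirrors the comprehension above.
from dataclasses import dataclass, field


@dataclass
class FakeResource:
    name: str


@dataclass
class FakeModel:
    provided: list[FakeResource] = field(default_factory=list)
    on_behalf_of_user: bool = False

    def as_resources(self) -> list[FakeResource]:
        return self.provided


models = [
    FakeModel([FakeResource("llm-endpoint")]),
    FakeModel([FakeResource("table-a"), FakeResource("table-b")], on_behalf_of_user=True),
]

all_resources = [res for m in models for res in m.as_resources()]
system_resources = [
    res for m in models for res in m.as_resources() if not m.on_behalf_of_user
]

print([r.name for r in all_resources])     # ['llm-endpoint', 'table-a', 'table-b']
print([r.name for r in system_resources])  # ['llm-endpoint']
```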
{dao_ai-0.0.17.dist-info → dao_ai-0.0.19.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.0.
+Version: 0.0.19
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai
@@ -24,22 +24,22 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
 Classifier: Topic :: Software Development :: Libraries :: Python Modules
 Classifier: Topic :: System :: Distributed Computing
 Requires-Python: >=3.12
-Requires-Dist: databricks-agents>=1.
-Requires-Dist: databricks-langchain>=0.
-Requires-Dist: databricks-sdk[openai]>=0.
+Requires-Dist: databricks-agents>=1.6.0
+Requires-Dist: databricks-langchain>=0.8.0
+Requires-Dist: databricks-sdk[openai]>=0.66.0
 Requires-Dist: duckduckgo-search>=8.0.2
 Requires-Dist: grandalf>=0.8
-Requires-Dist: langchain-mcp-adapters>=0.1.
+Requires-Dist: langchain-mcp-adapters>=0.1.10
 Requires-Dist: langchain-tavily>=0.2.11
 Requires-Dist: langchain>=0.3.27
 Requires-Dist: langgraph-checkpoint-postgres>=2.0.23
 Requires-Dist: langgraph-supervisor>=0.0.29
 Requires-Dist: langgraph-swarm>=0.0.14
-Requires-Dist: langgraph>=0.6.
+Requires-Dist: langgraph>=0.6.7
 Requires-Dist: langmem>=0.0.29
 Requires-Dist: loguru>=0.7.3
-Requires-Dist: mcp>=1.
-Requires-Dist: mlflow>=3.
+Requires-Dist: mcp>=1.14.1
+Requires-Dist: mlflow>=3.4.0
 Requires-Dist: nest-asyncio>=1.6.0
 Requires-Dist: openevals>=0.0.19
 Requires-Dist: openpyxl>=3.1.5
{dao_ai-0.0.17.dist-info → dao_ai-0.0.19.dist-info}/RECORD
CHANGED
@@ -1,13 +1,13 @@
 dao_ai/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dao_ai/agent_as_code.py,sha256=
+dao_ai/agent_as_code.py,sha256=kPSeDz2-1jRaed1TMs4LA3VECoyqe9_Ed2beRLB9gXQ,472
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
 dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
 dao_ai/cli.py,sha256=Aez2TQW3Q8Ho1IaIkRggt0NevDxAAVPjXkePC5GPJF0,20429
-dao_ai/config.py,sha256=
-dao_ai/graph.py,sha256=
+dao_ai/config.py,sha256=N_Vc-rJHvBzbia4TyAExGhCvZKXlk49bskrI_sbxwjg,51869
+dao_ai/graph.py,sha256=gmD9mxODfXuvn9xWeBfewm1FiuVAWMLEdnZz7DNmSH0,7859
 dao_ai/guardrails.py,sha256=-Qh0f_2Db9t4Nbrrx9FM7tnpqShjMoyxepZ0HByItfU,4027
-dao_ai/messages.py,sha256=
-dao_ai/models.py,sha256=
+dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
+dao_ai/models.py,sha256=Xb23U-lhDG8KyNRIijcJ4InluadlaGNy4rrYx7Cjgfg,26939
 dao_ai/nodes.py,sha256=SSuFNTXOdFaKg_aX-yUkQO7fM9wvNGu14lPXKDapU1U,8461
 dao_ai/prompts.py,sha256=vpmIbWs_szXUgNNDs5Gh2LcxKZti5pHDKSfoClUcgX0,1289
 dao_ai/state.py,sha256=GwbMbd1TWZx1T5iQrEOX6_rpxOitlmyeJ8dMr2o_pag,1031
@@ -18,11 +18,11 @@ dao_ai/hooks/__init__.py,sha256=LlHGIuiZt6vGW8K5AQo1XJEkBP5vDVtMhq0IdjcLrD4,417
 dao_ai/hooks/core.py,sha256=ZShHctUSoauhBgdf1cecy9-D7J6-sGn-pKjuRMumW5U,6663
 dao_ai/memory/__init__.py,sha256=1kHx_p9abKYFQ6EYD05nuc1GS5HXVEpufmjBGw_7Uho,260
 dao_ai/memory/base.py,sha256=99nfr2UZJ4jmfTL_KrqUlRSCoRxzkZyWyx5WqeUoMdQ,338
-dao_ai/memory/core.py,sha256=
-dao_ai/memory/postgres.py,sha256=
+dao_ai/memory/core.py,sha256=g7chjBgVgx3iKjR2hghl0QL1j3802uIM_e7mgszur9M,4151
+dao_ai/memory/postgres.py,sha256=ncvEKFYX-ZjUDYVmuWBMcZnykcp2eK4TP-ojzqkwDsk,17433
 dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
-dao_ai/providers/databricks.py,sha256=
+dao_ai/providers/databricks.py,sha256=fZ8mGotfA3W3t5yUej2xGmGHSybjBFYr895mOctT418,28203
 dao_ai/tools/__init__.py,sha256=ye6MHaJY7tUnJ8336YJiLxuZr55zDPNdOw6gm7j5jlc,1103
 dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
 dao_ai/tools/core.py,sha256=Kei33S8vrmvPOAyrFNekaWmV2jqZ-IPS1QDSvU7RZF0,1984
@@ -33,8 +33,8 @@ dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
 dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
 dao_ai/tools/unity_catalog.py,sha256=PXfLj2EgyQgaXq4Qq3t25AmTC4KyVCF_-sCtg6enens,1404
 dao_ai/tools/vector_search.py,sha256=EDYQs51zIPaAP0ma1D81wJT77GQ-v-cjb2XrFVWfWdg,2621
-dao_ai-0.0.
-dao_ai-0.0.
-dao_ai-0.0.
-dao_ai-0.0.
-dao_ai-0.0.
+dao_ai-0.0.19.dist-info/METADATA,sha256=hus4RZHOCTgDR6Rs8zS9l0OusplrFzryWCLsXZpTxgw,41380
+dao_ai-0.0.19.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+dao_ai-0.0.19.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.0.19.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.0.19.dist-info/RECORD,,
{dao_ai-0.0.17.dist-info → dao_ai-0.0.19.dist-info}/WHEEL
File without changes
{dao_ai-0.0.17.dist-info → dao_ai-0.0.19.dist-info}/entry_points.txt
File without changes
{dao_ai-0.0.17.dist-info → dao_ai-0.0.19.dist-info}/licenses/LICENSE
File without changes