dao-ai 0.0.32__py3-none-any.whl → 0.0.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/config.py CHANGED
@@ -12,6 +12,7 @@ from typing import (
12
12
  Iterator,
13
13
  Literal,
14
14
  Optional,
15
+ Self,
15
16
  Sequence,
16
17
  TypeAlias,
17
18
  Union,
@@ -200,6 +201,15 @@ AnyVariable: TypeAlias = (
200
201
  )
201
202
 
202
203
 
204
+ class ServicePrincipalModel(BaseModel):
205
+ model_config = ConfigDict(
206
+ frozen=True,
207
+ use_enum_values=True,
208
+ )
209
+ client_id: AnyVariable
210
+ client_secret: AnyVariable
211
+
212
+
203
213
  class Privilege(str, Enum):
204
214
  ALL_PRIVILEGES = "ALL_PRIVILEGES"
205
215
  USE_CATALOG = "USE_CATALOG"
@@ -226,9 +236,21 @@ class Privilege(str, Enum):
226
236
 
227
237
  class PermissionModel(BaseModel):
228
238
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
229
- principals: list[str] = Field(default_factory=list)
239
+ principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
230
240
  privileges: list[Privilege]
231
241
 
242
+ @model_validator(mode="after")
243
+ def resolve_principals(self) -> Self:
244
+ """Resolve ServicePrincipalModel objects to their client_id."""
245
+ resolved: list[str] = []
246
+ for principal in self.principals:
247
+ if isinstance(principal, ServicePrincipalModel):
248
+ resolved.append(value_of(principal.client_id))
249
+ else:
250
+ resolved.append(principal)
251
+ self.principals = resolved
252
+ return self
253
+
232
254
 
233
255
  class SchemaModel(BaseModel, HasFullName):
234
256
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -451,7 +473,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
451
473
  ]
452
474
 
453
475
  @model_validator(mode="after")
454
- def update_space_id(self):
476
+ def update_space_id(self) -> Self:
455
477
  self.space_id = value_of(self.space_id)
456
478
  return self
457
479
 
@@ -530,13 +552,13 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
530
552
  embedding_source_column: str
531
553
 
532
554
  @model_validator(mode="after")
533
- def set_default_embedding_model(self):
555
+ def set_default_embedding_model(self) -> Self:
534
556
  if not self.embedding_model:
535
557
  self.embedding_model = LLMModel(name="databricks-gte-large-en")
536
558
  return self
537
559
 
538
560
  @model_validator(mode="after")
539
- def set_default_primary_key(self):
561
+ def set_default_primary_key(self) -> Self:
540
562
  if self.primary_key is None:
541
563
  from dao_ai.providers.databricks import DatabricksProvider
542
564
 
@@ -557,14 +579,14 @@ class VectorStoreModel(BaseModel, IsDatabricksResource):
557
579
  return self
558
580
 
559
581
  @model_validator(mode="after")
560
- def set_default_index(self):
582
+ def set_default_index(self) -> Self:
561
583
  if self.index is None:
562
584
  name: str = f"{self.source_table.name}_index"
563
585
  self.index = IndexModel(schema=self.source_table.schema_model, name=name)
564
586
  return self
565
587
 
566
588
  @model_validator(mode="after")
567
- def set_default_endpoint(self):
589
+ def set_default_endpoint(self) -> Self:
568
590
  if self.endpoint is None:
569
591
  from dao_ai.providers.databricks import (
570
592
  DatabricksProvider,
@@ -719,7 +741,7 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
719
741
  ]
720
742
 
721
743
  @model_validator(mode="after")
722
- def update_warehouse_id(self):
744
+ def update_warehouse_id(self) -> Self:
723
745
  self.warehouse_id = value_of(self.warehouse_id)
724
746
  return self
725
747
 
@@ -737,12 +759,28 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
737
759
  - Used for: discovering database instance, getting host DNS, checking instance status
738
760
  - Controlled by: DATABRICKS_HOST, DATABRICKS_TOKEN env vars, or SDK default config
739
761
 
740
- 2. **Database Connection Authentication** (configured via client_id/client_secret OR user):
762
+ 2. **Database Connection Authentication** (configured via service_principal, client_id/client_secret, OR user):
741
763
  - Used for: connecting to the PostgreSQL database as a specific identity
764
+ - Service Principal: Set service_principal with workspace_host to connect as a service principal
742
765
  - OAuth M2M: Set client_id, client_secret, workspace_host to connect as a service principal
743
766
  - User Auth: Set user (and optionally password) to connect as a user identity
744
767
 
745
- Example OAuth M2M Configuration:
768
+ Example Service Principal Configuration:
769
+ ```yaml
770
+ databases:
771
+ my_lakebase:
772
+ name: my-database
773
+ service_principal:
774
+ client_id:
775
+ env: SERVICE_PRINCIPAL_CLIENT_ID
776
+ client_secret:
777
+ scope: my-scope
778
+ secret: sp-client-secret
779
+ workspace_host:
780
+ env: DATABRICKS_HOST
781
+ ```
782
+
783
+ Example OAuth M2M Configuration (alternative):
746
784
  ```yaml
747
785
  databases:
748
786
  my_lakebase:
@@ -779,6 +817,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
779
817
  node_count: Optional[int] = None
780
818
  user: Optional[AnyVariable] = None
781
819
  password: Optional[AnyVariable] = None
820
+ service_principal: Optional[ServicePrincipalModel] = None
782
821
  client_id: Optional[AnyVariable] = None
783
822
  client_secret: Optional[AnyVariable] = None
784
823
  workspace_host: Optional[AnyVariable] = None
@@ -796,14 +835,24 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
796
835
  ]
797
836
 
798
837
  @model_validator(mode="after")
799
- def update_instance_name(self):
838
+ def update_instance_name(self) -> Self:
800
839
  if self.instance_name is None:
801
840
  self.instance_name = self.name
802
841
 
803
842
  return self
804
843
 
805
844
  @model_validator(mode="after")
806
- def update_user(self):
845
+ def expand_service_principal(self) -> Self:
846
+ """Expand service_principal into client_id and client_secret if provided."""
847
+ if self.service_principal is not None:
848
+ if self.client_id is None:
849
+ self.client_id = self.service_principal.client_id
850
+ if self.client_secret is None:
851
+ self.client_secret = self.service_principal.client_secret
852
+ return self
853
+
854
+ @model_validator(mode="after")
855
+ def update_user(self) -> Self:
807
856
  if self.client_id or self.user:
808
857
  return self
809
858
 
@@ -816,7 +865,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
816
865
  return self
817
866
 
818
867
  @model_validator(mode="after")
819
- def update_host(self):
868
+ def update_host(self) -> Self:
820
869
  if self.host is not None:
821
870
  return self
822
871
 
@@ -829,7 +878,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
829
878
  return self
830
879
 
831
880
  @model_validator(mode="after")
832
- def validate_auth_methods(self):
881
+ def validate_auth_methods(self) -> Self:
833
882
  oauth_fields: Sequence[Any] = [
834
883
  self.workspace_host,
835
884
  self.client_id,
@@ -849,8 +898,8 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
849
898
  if not has_oauth and not has_user_auth:
850
899
  raise ValueError(
851
900
  "At least one authentication method must be provided: "
852
- "either OAuth credentials (workspace_host, client_id, client_secret) "
853
- "or user credentials (user, password)."
901
+ "either OAuth credentials (workspace_host, client_id, client_secret), "
902
+ "service_principal with workspace_host, or user credentials (user, password)."
854
903
  )
855
904
 
856
905
  return self
@@ -997,14 +1046,14 @@ class RetrieverModel(BaseModel):
997
1046
  )
998
1047
 
999
1048
  @model_validator(mode="after")
1000
- def set_default_columns(self):
1049
+ def set_default_columns(self) -> Self:
1001
1050
  if not self.columns:
1002
1051
  columns: Sequence[str] = self.vector_store.columns
1003
1052
  self.columns = columns
1004
1053
  return self
1005
1054
 
1006
1055
  @model_validator(mode="after")
1007
- def set_default_reranker(self):
1056
+ def set_default_reranker(self) -> Self:
1008
1057
  """Convert bool to ReRankParametersModel with defaults."""
1009
1058
  if isinstance(self.rerank, bool) and self.rerank:
1010
1059
  self.rerank = RerankParametersModel()
@@ -1091,7 +1140,7 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
1091
1140
  return [create_factory_tool(self, **kwargs)]
1092
1141
 
1093
1142
  @model_validator(mode="after")
1094
- def update_args(self):
1143
+ def update_args(self) -> Self:
1095
1144
  for key, value in self.args.items():
1096
1145
  self.args[key] = value_of(value)
1097
1146
  return self
@@ -1111,6 +1160,7 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1111
1160
  headers: dict[str, AnyVariable] = Field(default_factory=dict)
1112
1161
  args: list[str] = Field(default_factory=list)
1113
1162
  pat: Optional[AnyVariable] = None
1163
+ service_principal: Optional[ServicePrincipalModel] = None
1114
1164
  client_id: Optional[AnyVariable] = None
1115
1165
  client_secret: Optional[AnyVariable] = None
1116
1166
  workspace_host: Optional[AnyVariable] = None
@@ -1120,6 +1170,16 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1120
1170
  sql: Optional[bool] = None
1121
1171
  vector_search: Optional[VectorStoreModel] = None
1122
1172
 
1173
+ @model_validator(mode="after")
1174
+ def expand_service_principal(self) -> Self:
1175
+ """Expand service_principal into client_id and client_secret if provided."""
1176
+ if self.service_principal is not None:
1177
+ if self.client_id is None:
1178
+ self.client_id = self.service_principal.client_id
1179
+ if self.client_secret is None:
1180
+ self.client_secret = self.service_principal.client_secret
1181
+ return self
1182
+
1123
1183
  @property
1124
1184
  def full_name(self) -> str:
1125
1185
  return self.name
@@ -1129,12 +1189,12 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1129
1189
  Get the workspace host, either from config or from workspace client.
1130
1190
 
1131
1191
  If connection is provided, uses its workspace client.
1132
- Otherwise, falls back to creating a new workspace client.
1192
+ Otherwise, falls back to the default Databricks host.
1133
1193
 
1134
1194
  Returns:
1135
1195
  str: The workspace host URL without trailing slash
1136
1196
  """
1137
- from databricks.sdk import WorkspaceClient
1197
+ from dao_ai.utils import get_default_databricks_host
1138
1198
 
1139
1199
  # Try to get workspace_host from config
1140
1200
  workspace_host: str | None = (
@@ -1147,9 +1207,13 @@ class McpFunctionModel(BaseFunctionModel, HasFullName):
1147
1207
  if self.connection:
1148
1208
  workspace_host = self.connection.workspace_client.config.host
1149
1209
  else:
1150
- # Create a default workspace client
1151
- w: WorkspaceClient = WorkspaceClient()
1152
- workspace_host = w.config.host
1210
+ workspace_host = get_default_databricks_host()
1211
+
1212
+ if not workspace_host:
1213
+ raise ValueError(
1214
+ "Could not determine workspace host. "
1215
+ "Please set workspace_host in config or DATABRICKS_HOST environment variable."
1216
+ )
1153
1217
 
1154
1218
  # Remove trailing slash
1155
1219
  return workspace_host.rstrip("/")
@@ -1356,7 +1420,7 @@ class CheckpointerModel(BaseModel):
1356
1420
  database: Optional[DatabaseModel] = None
1357
1421
 
1358
1422
  @model_validator(mode="after")
1359
- def validate_postgres_requires_database(self):
1423
+ def validate_postgres_requires_database(self) -> Self:
1360
1424
  if self.type == StorageType.POSTGRES and not self.database:
1361
1425
  raise ValueError("Database must be provided when storage type is POSTGRES")
1362
1426
  return self
@@ -1381,7 +1445,7 @@ class StoreModel(BaseModel):
1381
1445
  namespace: Optional[str] = None
1382
1446
 
1383
1447
  @model_validator(mode="after")
1384
- def validate_postgres_requires_database(self):
1448
+ def validate_postgres_requires_database(self) -> Self:
1385
1449
  if self.type == StorageType.POSTGRES and not self.database:
1386
1450
  raise ValueError("Database must be provided when storage type is POSTGRES")
1387
1451
  return self
@@ -1445,7 +1509,7 @@ class PromptModel(BaseModel, HasFullName):
1445
1509
  return prompt_version
1446
1510
 
1447
1511
  @model_validator(mode="after")
1448
- def validate_mutually_exclusive(self):
1512
+ def validate_mutually_exclusive(self) -> Self:
1449
1513
  if self.alias and self.version:
1450
1514
  raise ValueError("Cannot specify both alias and version")
1451
1515
  return self
@@ -1499,7 +1563,7 @@ class OrchestrationModel(BaseModel):
1499
1563
  memory: Optional[MemoryModel] = None
1500
1564
 
1501
1565
  @model_validator(mode="after")
1502
- def validate_mutually_exclusive(self):
1566
+ def validate_mutually_exclusive(self) -> Self:
1503
1567
  if self.supervisor is not None and self.swarm is not None:
1504
1568
  raise ValueError("Cannot specify both supervisor and swarm")
1505
1569
  if self.supervisor is None and self.swarm is None:
@@ -1529,9 +1593,21 @@ class Entitlement(str, Enum):
1529
1593
 
1530
1594
  class AppPermissionModel(BaseModel):
1531
1595
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1532
- principals: list[str] = Field(default_factory=list)
1596
+ principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
1533
1597
  entitlements: list[Entitlement]
1534
1598
 
1599
+ @model_validator(mode="after")
1600
+ def resolve_principals(self) -> Self:
1601
+ """Resolve ServicePrincipalModel objects to their client_id."""
1602
+ resolved: list[str] = []
1603
+ for principal in self.principals:
1604
+ if isinstance(principal, ServicePrincipalModel):
1605
+ resolved.append(value_of(principal.client_id))
1606
+ else:
1607
+ resolved.append(principal)
1608
+ self.principals = resolved
1609
+ return self
1610
+
1535
1611
 
1536
1612
  class LogLevel(str, Enum):
1537
1613
  TRACE = "TRACE"
@@ -1626,6 +1702,7 @@ class ChatHistoryModel(BaseModel):
1626
1702
  class AppModel(BaseModel):
1627
1703
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1628
1704
  name: str
1705
+ service_principal: Optional[ServicePrincipalModel] = None
1629
1706
  description: Optional[str] = None
1630
1707
  log_level: Optional[LogLevel] = "WARNING"
1631
1708
  registered_model: RegisteredModelModel
@@ -1661,14 +1738,42 @@ class AppModel(BaseModel):
1661
1738
  )
1662
1739
 
1663
1740
  @model_validator(mode="after")
1664
- def validate_agents_not_empty(self):
1741
+ def set_databricks_env_vars(self) -> Self:
1742
+ """Set Databricks environment variables for Model Serving.
1743
+
1744
+ Sets DATABRICKS_HOST, DATABRICKS_CLIENT_ID, and DATABRICKS_CLIENT_SECRET.
1745
+ Values explicitly provided in environment_vars take precedence.
1746
+ """
1747
+ from dao_ai.utils import get_default_databricks_host
1748
+
1749
+ # Set DATABRICKS_HOST if not already provided
1750
+ if "DATABRICKS_HOST" not in self.environment_vars:
1751
+ host: str | None = get_default_databricks_host()
1752
+ if host:
1753
+ self.environment_vars["DATABRICKS_HOST"] = host
1754
+
1755
+ # Set service principal credentials if provided
1756
+ if self.service_principal is not None:
1757
+ if "DATABRICKS_CLIENT_ID" not in self.environment_vars:
1758
+ self.environment_vars["DATABRICKS_CLIENT_ID"] = (
1759
+ self.service_principal.client_id
1760
+ )
1761
+ if "DATABRICKS_CLIENT_SECRET" not in self.environment_vars:
1762
+ self.environment_vars["DATABRICKS_CLIENT_SECRET"] = (
1763
+ self.service_principal.client_secret
1764
+ )
1765
+ return self
1766
+
1767
+ @model_validator(mode="after")
1768
+ def validate_agents_not_empty(self) -> Self:
1665
1769
  if not self.agents:
1666
1770
  raise ValueError("At least one agent must be specified")
1667
1771
  return self
1668
1772
 
1669
1773
  @model_validator(mode="after")
1670
- def update_environment_vars(self):
1774
+ def resolve_environment_vars(self) -> Self:
1671
1775
  for key, value in self.environment_vars.items():
1776
+ updated_value: str
1672
1777
  if isinstance(value, SecretVariableModel):
1673
1778
  updated_value = str(value)
1674
1779
  else:
@@ -1678,7 +1783,7 @@ class AppModel(BaseModel):
1678
1783
  return self
1679
1784
 
1680
1785
  @model_validator(mode="after")
1681
- def set_default_orchestration(self):
1786
+ def set_default_orchestration(self) -> Self:
1682
1787
  if self.orchestration is None:
1683
1788
  if len(self.agents) > 1:
1684
1789
  default_agent: AgentModel = self.agents[0]
@@ -1698,14 +1803,14 @@ class AppModel(BaseModel):
1698
1803
  return self
1699
1804
 
1700
1805
  @model_validator(mode="after")
1701
- def set_default_endpoint_name(self):
1806
+ def set_default_endpoint_name(self) -> Self:
1702
1807
  if self.endpoint_name is None:
1703
1808
  self.endpoint_name = self.name
1704
1809
  return self
1705
1810
 
1706
1811
  @model_validator(mode="after")
1707
- def set_default_agent(self):
1708
- default_agent_name = self.agents[0].name
1812
+ def set_default_agent(self) -> Self:
1813
+ default_agent_name: str = self.agents[0].name
1709
1814
 
1710
1815
  if self.orchestration.swarm and not self.orchestration.swarm.default_agent:
1711
1816
  self.orchestration.swarm.default_agent = default_agent_name
@@ -1713,7 +1818,7 @@ class AppModel(BaseModel):
1713
1818
  return self
1714
1819
 
1715
1820
  @model_validator(mode="after")
1716
- def add_code_paths_to_sys_path(self):
1821
+ def add_code_paths_to_sys_path(self) -> Self:
1717
1822
  for code_path in self.code_paths:
1718
1823
  parent_path: str = str(Path(code_path).parent)
1719
1824
  if parent_path not in sys.path:
@@ -1746,7 +1851,7 @@ class EvaluationDatasetExpectationsModel(BaseModel):
1746
1851
  expected_facts: Optional[list[str]] = None
1747
1852
 
1748
1853
  @model_validator(mode="after")
1749
- def validate_mutually_exclusive(self):
1854
+ def validate_mutually_exclusive(self) -> Self:
1750
1855
  if self.expected_response is not None and self.expected_facts is not None:
1751
1856
  raise ValueError("Cannot specify both expected_response and expected_facts")
1752
1857
  return self
@@ -1854,7 +1959,7 @@ class PromptOptimizationModel(BaseModel):
1854
1959
  return optimized_prompt
1855
1960
 
1856
1961
  @model_validator(mode="after")
1857
- def set_defaults(self):
1962
+ def set_defaults(self) -> Self:
1858
1963
  # If no prompt is specified, try to use the agent's prompt
1859
1964
  if self.prompt is None:
1860
1965
  if isinstance(self.agent.prompt, PromptModel):
@@ -1976,6 +2081,7 @@ class ResourcesModel(BaseModel):
1976
2081
  class AppConfig(BaseModel):
1977
2082
  model_config = ConfigDict(use_enum_values=True, extra="forbid")
1978
2083
  variables: dict[str, AnyVariable] = Field(default_factory=dict)
2084
+ service_principals: dict[str, ServicePrincipalModel] = Field(default_factory=dict)
1979
2085
  schemas: dict[str, SchemaModel] = Field(default_factory=dict)
1980
2086
  resources: Optional[ResourcesModel] = None
1981
2087
  retrievers: dict[str, RetrieverModel] = Field(default_factory=dict)
dao_ai/models.py CHANGED
@@ -331,13 +331,23 @@ class LanggraphResponsesAgent(ResponsesAgent):
331
331
  context: Context = self._convert_request_to_context(request)
332
332
  custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
333
333
 
334
+ # Build the graph input state, including genie_conversation_ids if provided
335
+ graph_input: dict[str, Any] = {"messages": messages}
336
+ if request.custom_inputs and "genie_conversation_ids" in request.custom_inputs:
337
+ graph_input["genie_conversation_ids"] = request.custom_inputs[
338
+ "genie_conversation_ids"
339
+ ]
340
+ logger.debug(
341
+ f"Including genie_conversation_ids in graph input: {graph_input['genie_conversation_ids']}"
342
+ )
343
+
334
344
  # Use async ainvoke internally for parallel execution
335
345
  import asyncio
336
346
 
337
347
  async def _async_invoke():
338
348
  try:
339
349
  return await self.graph.ainvoke(
340
- {"messages": messages}, context=context, config=custom_inputs
350
+ graph_input, context=context, config=custom_inputs
341
351
  )
342
352
  except Exception as e:
343
353
  logger.error(f"Error in graph.ainvoke: {e}")
@@ -399,6 +409,16 @@ class LanggraphResponsesAgent(ResponsesAgent):
399
409
  context: Context = self._convert_request_to_context(request)
400
410
  custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
401
411
 
412
+ # Build the graph input state, including genie_conversation_ids if provided
413
+ graph_input: dict[str, Any] = {"messages": messages}
414
+ if request.custom_inputs and "genie_conversation_ids" in request.custom_inputs:
415
+ graph_input["genie_conversation_ids"] = request.custom_inputs[
416
+ "genie_conversation_ids"
417
+ ]
418
+ logger.debug(
419
+ f"Including genie_conversation_ids in graph input: {graph_input['genie_conversation_ids']}"
420
+ )
421
+
402
422
  # Use async astream internally for parallel execution
403
423
  import asyncio
404
424
 
@@ -408,7 +428,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
408
428
 
409
429
  try:
410
430
  async for nodes, stream_mode, messages_batch in self.graph.astream(
411
- {"messages": messages},
431
+ graph_input,
412
432
  context=context,
413
433
  config=custom_inputs,
414
434
  stream_mode=["messages", "custom"],
@@ -1151,7 +1151,7 @@ class DatabricksProvider(ServiceProvider):
1151
1151
  If an explicit version or alias is specified in the prompt_model, uses that directly.
1152
1152
  Otherwise, tries to load prompts in this order:
1153
1153
  1. champion alias
1154
- 2. latest version (max version number from search_prompt_versions)
1154
+ 2. latest alias
1155
1155
  3. default alias
1156
1156
  4. Register default_template if provided
1157
1157
 
@@ -1166,7 +1166,6 @@ class DatabricksProvider(ServiceProvider):
1166
1166
  """
1167
1167
 
1168
1168
  prompt_name: str = prompt_model.full_name
1169
- mlflow_client: MlflowClient = MlflowClient()
1170
1169
 
1171
1170
  # If explicit version or alias is specified, use it directly
1172
1171
  if prompt_model.version or prompt_model.alias:
@@ -1197,19 +1196,13 @@ class DatabricksProvider(ServiceProvider):
1197
1196
  except Exception as e:
1198
1197
  logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
1199
1198
 
1200
- # 2. Try to get latest version by finding the max version number
1199
+ # 2. Try latest alias
1201
1200
  try:
1202
- versions = mlflow_client.search_prompt_versions(
1203
- prompt_name, max_results=100
1204
- )
1205
- if versions:
1206
- latest = max(versions, key=lambda v: int(v.version))
1207
- logger.info(
1208
- f"Loaded prompt '{prompt_name}' version {latest.version} (latest by max version)"
1209
- )
1210
- return latest
1201
+ prompt_version = load_prompt(f"prompts:/{prompt_name}@latest")
1202
+ logger.info(f"Loaded prompt '{prompt_name}' from latest alias")
1203
+ return prompt_version
1211
1204
  except Exception as e:
1212
- logger.debug(f"Failed to find latest version for '{prompt_name}': {e}")
1205
+ logger.debug(f"Latest alias not found for '{prompt_name}': {e}")
1213
1206
 
1214
1207
  # 3. Try default alias
1215
1208
  try:
@@ -1225,7 +1218,7 @@ class DatabricksProvider(ServiceProvider):
1225
1218
  f"No existing prompt found for '{prompt_name}', "
1226
1219
  "attempting to register default_template"
1227
1220
  )
1228
- return self._sync_default_template_to_registry(
1221
+ return self._register_default_template(
1229
1222
  prompt_name, prompt_model.default_template, prompt_model.description
1230
1223
  )
1231
1224
 
@@ -1235,49 +1228,17 @@ class DatabricksProvider(ServiceProvider):
1235
1228
  "and no default_template provided"
1236
1229
  )
1237
1230
 
1238
- def _sync_default_template_to_registry(
1231
+ def _register_default_template(
1239
1232
  self, prompt_name: str, default_template: str, description: str | None = None
1240
1233
  ) -> PromptVersion:
1241
- """Get the best available prompt version, or register default_template if possible.
1242
-
1243
- Tries to load prompts in order: champion → latest (max version) → default.
1244
- If none found and we have write permissions, registers the default_template.
1245
- If registration fails (e.g., in Model Serving), logs the error and raises.
1246
- """
1247
- mlflow_client: MlflowClient = MlflowClient()
1248
-
1249
- # Try to find an existing prompt version in priority order
1250
- # 1. Try champion alias
1251
- try:
1252
- champion = mlflow.genai.load_prompt(f"prompts:/{prompt_name}@champion")
1253
- logger.info(f"Loaded prompt '{prompt_name}' from champion alias")
1254
- return champion
1255
- except Exception as e:
1256
- logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
1234
+ """Register default_template as a new prompt version.
1257
1235
 
1258
- # 2. Try to get the latest version by finding the max version number
1259
- try:
1260
- versions = mlflow_client.search_prompt_versions(
1261
- prompt_name, max_results=100
1262
- )
1263
- if versions:
1264
- latest = max(versions, key=lambda v: int(v.version))
1265
- logger.info(
1266
- f"Loaded prompt '{prompt_name}' version {latest.version} (latest by max version)"
1267
- )
1268
- return latest
1269
- except Exception as e:
1270
- logger.debug(f"Failed to search versions for '{prompt_name}': {e}")
1271
-
1272
- # 3. Try default alias
1273
- try:
1274
- default = mlflow.genai.load_prompt(f"prompts:/{prompt_name}@default")
1275
- logger.info(f"Loaded prompt '{prompt_name}' from default alias")
1276
- return default
1277
- except Exception as e:
1278
- logger.debug(f"Default alias not found for '{prompt_name}': {e}")
1236
+ Called when no existing prompt version is found (champion, latest, default all failed).
1237
+ Registers the template and sets both 'default' and 'champion' aliases.
1279
1238
 
1280
- # No existing prompt found - try to register if we have a template
1239
+ If registration fails (e.g., in Model Serving with restricted permissions),
1240
+ logs the error and raises.
1241
+ """
1281
1242
  logger.info(
1282
1243
  f"No existing prompt found for '{prompt_name}', attempting to register default_template"
1283
1244
  )
dao_ai/tools/core.py CHANGED
@@ -35,7 +35,7 @@ def create_tools(tool_models: Sequence[ToolModel]) -> Sequence[RunnableLike]:
35
35
  if name in tools:
36
36
  logger.warning(f"Tools already registered for: {name}, skipping creation.")
37
37
  continue
38
- registered_tools: Sequence[RunnableLike] = tool_registry.get(name)
38
+ registered_tools: Sequence[RunnableLike] | None = tool_registry.get(name)
39
39
  if registered_tools is None:
40
40
  logger.debug(f"Creating tools for: {name}...")
41
41
  function: AnyTool = tool_config.function
dao_ai/tools/genie.py CHANGED
@@ -5,8 +5,9 @@ from typing import Annotated, Any, Callable
5
5
 
6
6
  import pandas as pd
7
7
  from databricks_ai_bridge.genie import Genie, GenieResponse
8
+ from langchain.tools import tool
8
9
  from langchain_core.messages import ToolMessage
9
- from langchain_core.tools import InjectedToolCallId, tool
10
+ from langchain_core.tools import InjectedToolCallId
10
11
  from langgraph.prebuilt import InjectedState
11
12
  from langgraph.types import Command
12
13
  from loguru import logger
@@ -43,7 +44,7 @@ def create_genie_tool(
43
44
  genie_room: GenieRoomModel | dict[str, Any],
44
45
  name: str | None = None,
45
46
  description: str | None = None,
46
- persist_conversation: bool = False,
47
+ persist_conversation: bool = True,
47
48
  truncate_results: bool = False,
48
49
  ) -> Callable[..., Command]:
49
50
  """
@@ -64,6 +65,16 @@ def create_genie_tool(
64
65
  Returns:
65
66
  A LangGraph tool that processes natural language queries through Genie
66
67
  """
68
+ logger.debug("create_genie_tool")
69
+ logger.debug(f"genie_room type: {type(genie_room)}")
70
+ logger.debug(f"genie_room: {genie_room}")
71
+ logger.debug(f"persist_conversation: {persist_conversation}")
72
+ logger.debug(f"truncate_results: {truncate_results}")
73
+ logger.debug(f"name: {name}")
74
+ logger.debug(f"description: {description}")
75
+ logger.debug(f"genie_room: {genie_room}")
76
+ logger.debug(f"persist_conversation: {persist_conversation}")
77
+ logger.debug(f"truncate_results: {truncate_results}")
67
78
 
68
79
  if isinstance(genie_room, dict):
69
80
  genie_room = GenieRoomModel(**genie_room)
@@ -106,14 +117,13 @@ GenieResponse: A response object containing the conversation ID and result from
106
117
  state: Annotated[dict, InjectedState],
107
118
  tool_call_id: Annotated[str, InjectedToolCallId],
108
119
  ) -> Command:
109
- """Process a natural language question through Databricks Genie."""
110
- # Create Genie instance using databricks_langchain implementation
111
120
  genie: Genie = Genie(
112
121
  space_id=space_id,
113
122
  client=genie_room.workspace_client,
114
123
  truncate_results=truncate_results,
115
124
  )
116
125
 
126
+ """Process a natural language question through Databricks Genie."""
117
127
  # Get existing conversation mapping and retrieve conversation ID for this space
118
128
  conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
119
129
  existing_conversation_id: str | None = conversation_ids.get(space_id)
@@ -131,6 +141,7 @@ GenieResponse: A response object containing the conversation ID and result from
131
141
  )
132
142
 
133
143
  # Update the conversation mapping with the new conversation ID for this space
144
+
134
145
  update: dict[str, Any] = {
135
146
  "messages": [
136
147
  ToolMessage(_response_to_json(response), tool_call_id=tool_call_id)
@@ -265,23 +265,52 @@ def with_partial_args(
265
265
 
266
266
  Args:
267
267
  tool: ToolModel containing the Unity Catalog function configuration
268
- partial_args: Dictionary of arguments to pre-fill in the tool
268
+ partial_args: Dictionary of arguments to pre-fill in the tool.
269
+ Supports:
270
+ - client_id, client_secret: OAuth credentials directly
271
+ - service_principal: ServicePrincipalModel with client_id and client_secret
272
+ - host or workspace_host: Databricks workspace host
269
273
 
270
274
  Returns:
271
275
  StructuredTool: A LangChain tool with partial arguments pre-filled
272
276
  """
273
277
  from unitycatalog.ai.langchain.toolkit import generate_function_input_params_schema
274
278
 
279
+ from dao_ai.config import ServicePrincipalModel
280
+
275
281
  logger.debug(f"with_partial_args: {tool}")
276
282
 
277
283
  # Convert dict-based variables to CompositeVariableModel and resolve their values
278
- resolved_args = {}
284
+ resolved_args: dict[str, Any] = {}
279
285
  for k, v in partial_args.items():
280
286
  if isinstance(v, dict):
281
287
  resolved_args[k] = value_of(CompositeVariableModel(**v))
282
288
  else:
283
289
  resolved_args[k] = value_of(v)
284
290
 
291
+ # Handle service_principal - expand into client_id and client_secret
292
+ if "service_principal" in resolved_args:
293
+ sp = resolved_args.pop("service_principal")
294
+ if isinstance(sp, dict):
295
+ sp = ServicePrincipalModel(**sp)
296
+ if isinstance(sp, ServicePrincipalModel):
297
+ if "client_id" not in resolved_args:
298
+ resolved_args["client_id"] = value_of(sp.client_id)
299
+ if "client_secret" not in resolved_args:
300
+ resolved_args["client_secret"] = value_of(sp.client_secret)
301
+
302
+ # Normalize host/workspace_host - accept either key
303
+ if "workspace_host" in resolved_args and "host" not in resolved_args:
304
+ resolved_args["host"] = resolved_args.pop("workspace_host")
305
+
306
+ # Default host from WorkspaceClient if not provided
307
+ if "host" not in resolved_args:
308
+ from dao_ai.utils import get_default_databricks_host
309
+
310
+ host: str | None = get_default_databricks_host()
311
+ if host:
312
+ resolved_args["host"] = host
313
+
285
314
  logger.debug(f"Resolved partial args: {resolved_args.keys()}")
286
315
 
287
316
  if isinstance(tool, dict):
dao_ai/utils.py CHANGED
@@ -38,6 +38,32 @@ def normalize_name(name: str) -> str:
38
38
  return normalized.strip("_")
39
39
 
40
40
 
41
+ def get_default_databricks_host() -> str | None:
42
+ """Get the default Databricks workspace host.
43
+
44
+ Attempts to get the host from:
45
+ 1. DATABRICKS_HOST environment variable
46
+ 2. WorkspaceClient ambient authentication (e.g., from ~/.databrickscfg)
47
+
48
+ Returns:
49
+ The Databricks workspace host URL, or None if not available.
50
+ """
51
+ # Try environment variable first
52
+ host: str | None = os.environ.get("DATABRICKS_HOST")
53
+ if host:
54
+ return host
55
+
56
+ # Fall back to WorkspaceClient
57
+ try:
58
+ from databricks.sdk import WorkspaceClient
59
+
60
+ w: WorkspaceClient = WorkspaceClient()
61
+ return w.config.host
62
+ except Exception:
63
+ logger.debug("Could not get default Databricks host from WorkspaceClient")
64
+ return None
65
+
66
+
41
67
  def dao_ai_version() -> str:
42
68
  """
43
69
  Get the dao-ai package version, with fallback for source installations.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: dao-ai
3
- Version: 0.0.32
3
+ Version: 0.0.34
4
4
  Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
5
5
  Project-URL: Homepage, https://github.com/natefleming/dao-ai
6
6
  Project-URL: Documentation, https://natefleming.github.io/dao-ai
@@ -25,7 +25,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
25
25
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
26
26
  Classifier: Topic :: System :: Distributed Computing
27
27
  Requires-Python: >=3.11
28
- Requires-Dist: databricks-agents>=1.7.0
28
+ Requires-Dist: databricks-agents>=1.8.2
29
29
  Requires-Dist: databricks-langchain>=0.11.0
30
30
  Requires-Dist: databricks-mcp>=0.3.0
31
31
  Requires-Dist: databricks-sdk[openai]>=0.67.0
@@ -33,21 +33,21 @@ Requires-Dist: ddgs>=9.9.3
33
33
  Requires-Dist: flashrank>=0.2.8
34
34
  Requires-Dist: gepa>=0.0.17
35
35
  Requires-Dist: grandalf>=0.8
36
- Requires-Dist: langchain-mcp-adapters>=0.1.10
36
+ Requires-Dist: langchain-mcp-adapters>=0.2.1
37
37
  Requires-Dist: langchain-tavily>=0.2.11
38
38
  Requires-Dist: langchain>=1.1.3
39
- Requires-Dist: langgraph-checkpoint-postgres>=2.0.25
39
+ Requires-Dist: langgraph-checkpoint-postgres>=3.0.2
40
40
  Requires-Dist: langgraph-supervisor>=0.0.31
41
41
  Requires-Dist: langgraph-swarm>=0.1.0
42
42
  Requires-Dist: langgraph>=1.0.4
43
- Requires-Dist: langmem>=0.0.29
43
+ Requires-Dist: langmem>=0.0.30
44
44
  Requires-Dist: loguru>=0.7.3
45
- Requires-Dist: mcp>=1.17.0
45
+ Requires-Dist: mcp>=1.23.3
46
46
  Requires-Dist: mlflow>=3.7.0
47
47
  Requires-Dist: nest-asyncio>=1.6.0
48
48
  Requires-Dist: openevals>=0.0.19
49
49
  Requires-Dist: openpyxl>=3.1.5
50
- Requires-Dist: psycopg[binary,pool]>=3.2.9
50
+ Requires-Dist: psycopg[binary,pool]>=3.3.2
51
51
  Requires-Dist: pydantic>=2.12.0
52
52
  Requires-Dist: python-dotenv>=1.1.0
53
53
  Requires-Dist: pyyaml>=6.0.2
@@ -3,16 +3,16 @@ dao_ai/agent_as_code.py,sha256=sviZQV7ZPxE5zkZ9jAbfegI681nra5i8yYxw05e3X7U,552
3
3
  dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
4
4
  dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
5
5
  dao_ai/cli.py,sha256=gq-nsapWxDA1M6Jua3vajBvIwf0Oa6YLcB58lEtMKUo,22503
6
- dao_ai/config.py,sha256=sc9iYPui5tHitG5kmOTd9LVjzgLJ2Dn0M6s-Zu3dw04,75022
6
+ dao_ai/config.py,sha256=Jzb0ePrt2TM2WuXI_LtmTafbseKBlJ8J8J2ExyBowbM,79491
7
7
  dao_ai/graph.py,sha256=9kjJx0oFZKq5J9-Kpri4-0VCJILHYdYyhqQnj0_noxQ,8913
8
8
  dao_ai/guardrails.py,sha256=4TKArDONRy8RwHzOT1plZ1rhy3x9GF_aeGpPCRl6wYA,4016
9
9
  dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
10
- dao_ai/models.py,sha256=8r8GIG3EGxtVyWsRNI56lVaBjiNrPkzh4HdwMZRq8iw,31689
10
+ dao_ai/models.py,sha256=hvEZO2N0EC2sQoMgjJ9mbKmDWcdxnnAb2NqzpXh4Wgk,32691
11
11
  dao_ai/nodes.py,sha256=iQ_5vL6mt1UcRnhwgz-l1D8Ww4CMQrSMVnP_Lu7fFjU,8781
12
12
  dao_ai/prompts.py,sha256=iA2Iaky7yzjwWT5cxg0cUIgwo1z1UVQua__8WPnvV6g,1633
13
13
  dao_ai/state.py,sha256=_lF9krAYYjvFDMUwZzVKOn0ZnXKcOrbjWKdre0C5B54,1137
14
14
  dao_ai/types.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- dao_ai/utils.py,sha256=FLXbiUaCeBva4vJ-czs-sRP7QSxjoKjyDt1Q4yeI7sU,7727
15
+ dao_ai/utils.py,sha256=oIPmz02kZ3LMntbqxUajFXh4nswOhbvEjOTi4e5_cvI,8500
16
16
  dao_ai/vector_search.py,sha256=jlaFS_iizJ55wblgzZmswMM3UOL-qOp2BGJc0JqXYSg,2839
17
17
  dao_ai/hooks/__init__.py,sha256=LlHGIuiZt6vGW8K5AQo1XJEkBP5vDVtMhq0IdjcLrD4,417
18
18
  dao_ai/hooks/core.py,sha256=ZShHctUSoauhBgdf1cecy9-D7J6-sGn-pKjuRMumW5U,6663
@@ -22,20 +22,20 @@ dao_ai/memory/core.py,sha256=DnEjQO3S7hXr3CDDd7C2eE7fQUmcCS_8q9BXEgjPH3U,4271
22
22
  dao_ai/memory/postgres.py,sha256=vvI3osjx1EoU5GBA6SCUstTBKillcmLl12hVgDMjfJY,15346
23
23
  dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
24
  dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
25
- dao_ai/providers/databricks.py,sha256=rPBMdGcJvdGBRK9FZeBxkLfcTpXyxU1cs14YllyZKbY,67857
25
+ dao_ai/providers/databricks.py,sha256=WEigNPGRTlIPVjwp97My8o1zOHn5ftuMsMrpqrBeaLg,66012
26
26
  dao_ai/tools/__init__.py,sha256=G5-5Yi6zpQOH53b5IzLdtsC6g0Ep6leI5GxgxOmgw7Q,1203
27
27
  dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
28
- dao_ai/tools/core.py,sha256=Kei33S8vrmvPOAyrFNekaWmV2jqZ-IPS1QDSvU7RZF0,1984
29
- dao_ai/tools/genie.py,sha256=BPM_1Sk5bf7QSCFPPboWWkZKYwBwDwbGhMVp5-QDd10,5956
28
+ dao_ai/tools/core.py,sha256=kN77fWOzVY7qOs4NiW72cUxCsSTC0DnPp73s6VJEZOQ,1991
29
+ dao_ai/tools/genie.py,sha256=hWDLLGUNz1wgwOb69pXnMiLJnMbG_1YmMdfVKt1Qe8o,6426
30
30
  dao_ai/tools/human_in_the_loop.py,sha256=yk35MO9eNETnYFH-sqlgR-G24TrEgXpJlnZUustsLkI,3681
31
31
  dao_ai/tools/mcp.py,sha256=5aQoRtx2z4xm6zgRslc78rSfEQe-mfhqov2NsiybYfc,8416
32
32
  dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
33
33
  dao_ai/tools/slack.py,sha256=SCvyVcD9Pv_XXPXePE_fSU1Pd8VLTEkKDLvoGTZWy2Y,4775
34
34
  dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
35
- dao_ai/tools/unity_catalog.py,sha256=uX_h52BuBAr4c9UeqSMI7DNz3BPRLeai5tBVW4sJqRI,13113
35
+ dao_ai/tools/unity_catalog.py,sha256=K9t8M4spsbxbecWmV5yEZy16s_AG7AfaoxT-7IDW43I,14438
36
36
  dao_ai/tools/vector_search.py,sha256=3cdiUaFpox25GSRNec7FKceY3DuLp7dLVH8FRA0BgeY,12624
37
- dao_ai-0.0.32.dist-info/METADATA,sha256=1_BlILYdzDHCILhIxFNeWdM6CRg4uKqBNPiP_hjbXtE,42763
38
- dao_ai-0.0.32.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
39
- dao_ai-0.0.32.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
40
- dao_ai-0.0.32.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
41
- dao_ai-0.0.32.dist-info/RECORD,,
37
+ dao_ai-0.0.34.dist-info/METADATA,sha256=vq51NEV-pg7WTOD5z56jyOrC5_6Q-nUIL51RI5lL-Hg,42761
38
+ dao_ai-0.0.34.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
39
+ dao_ai-0.0.34.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
40
+ dao_ai-0.0.34.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
41
+ dao_ai-0.0.34.dist-info/RECORD,,