dao-ai: dao_ai-0.0.33-py3-none-any.whl → dao_ai-0.0.34-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/config.py CHANGED
@@ -236,9 +236,21 @@ class Privilege(str, Enum):
 
 class PermissionModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    principals: list[str] = Field(default_factory=list)
+    principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
     privileges: list[Privilege]
 
+    @model_validator(mode="after")
+    def resolve_principals(self) -> Self:
+        """Resolve ServicePrincipalModel objects to their client_id."""
+        resolved: list[str] = []
+        for principal in self.principals:
+            if isinstance(principal, ServicePrincipalModel):
+                resolved.append(value_of(principal.client_id))
+            else:
+                resolved.append(principal)
+        self.principals = resolved
+        return self
+
 
 class SchemaModel(BaseModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -1581,9 +1593,21 @@ class Entitlement(str, Enum):
 
 class AppPermissionModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    principals: list[str] = Field(default_factory=list)
+    principals: list[ServicePrincipalModel | str] = Field(default_factory=list)
    entitlements: list[Entitlement]
 
+    @model_validator(mode="after")
+    def resolve_principals(self) -> Self:
+        """Resolve ServicePrincipalModel objects to their client_id."""
+        resolved: list[str] = []
+        for principal in self.principals:
+            if isinstance(principal, ServicePrincipalModel):
+                resolved.append(value_of(principal.client_id))
+            else:
+                resolved.append(principal)
+        self.principals = resolved
+        return self
+
 
 class LogLevel(str, Enum):
     TRACE = "TRACE"
dao_ai/models.py CHANGED
@@ -331,13 +331,23 @@ class LanggraphResponsesAgent(ResponsesAgent):
         context: Context = self._convert_request_to_context(request)
         custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
 
+        # Build the graph input state, including genie_conversation_ids if provided
+        graph_input: dict[str, Any] = {"messages": messages}
+        if request.custom_inputs and "genie_conversation_ids" in request.custom_inputs:
+            graph_input["genie_conversation_ids"] = request.custom_inputs[
+                "genie_conversation_ids"
+            ]
+            logger.debug(
+                f"Including genie_conversation_ids in graph input: {graph_input['genie_conversation_ids']}"
+            )
+
         # Use async ainvoke internally for parallel execution
         import asyncio
 
         async def _async_invoke():
             try:
                 return await self.graph.ainvoke(
-                    {"messages": messages}, context=context, config=custom_inputs
+                    graph_input, context=context, config=custom_inputs
                 )
             except Exception as e:
                 logger.error(f"Error in graph.ainvoke: {e}")
@@ -399,6 +409,16 @@ class LanggraphResponsesAgent(ResponsesAgent):
         context: Context = self._convert_request_to_context(request)
         custom_inputs: dict[str, Any] = {"configurable": context.model_dump()}
 
+        # Build the graph input state, including genie_conversation_ids if provided
+        graph_input: dict[str, Any] = {"messages": messages}
+        if request.custom_inputs and "genie_conversation_ids" in request.custom_inputs:
+            graph_input["genie_conversation_ids"] = request.custom_inputs[
+                "genie_conversation_ids"
+            ]
+            logger.debug(
+                f"Including genie_conversation_ids in graph input: {graph_input['genie_conversation_ids']}"
+            )
+
         # Use async astream internally for parallel execution
         import asyncio
 
@@ -408,7 +428,7 @@ class LanggraphResponsesAgent(ResponsesAgent):
 
         try:
             async for nodes, stream_mode, messages_batch in self.graph.astream(
-                {"messages": messages},
+                graph_input,
                 context=context,
                 config=custom_inputs,
                 stream_mode=["messages", "custom"],
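
The practical effect of the models.py change: Genie conversation state can now round-trip through a request's custom_inputs in both the invoke and stream paths. A hedged usage sketch (the IDs are placeholders; the exact request wrapper depends on how the agent is served):

    # Sketch only: resuming Genie conversations across predict/stream calls.
    # The agent copies custom_inputs["genie_conversation_ids"] into the graph
    # input state; the Genie tool keys this mapping by Genie space ID.
    request_payload = {
        "input": [{"role": "user", "content": "What were sales last week?"}],
        "custom_inputs": {
            "genie_conversation_ids": {"<space-id>": "<conversation-id>"},
        },
    }

Feeding the conversation IDs returned by one call back into the next request keeps multi-turn Genie context alive across otherwise stateless invocations.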
dao_ai/providers/databricks.py CHANGED
@@ -1151,7 +1151,7 @@ class DatabricksProvider(ServiceProvider):
         If an explicit version or alias is specified in the prompt_model, uses that directly.
         Otherwise, tries to load prompts in this order:
         1. champion alias
-        2. latest version (max version number from search_prompt_versions)
+        2. latest alias
         3. default alias
         4. Register default_template if provided
 
@@ -1166,7 +1166,6 @@ class DatabricksProvider(ServiceProvider):
         """
 
         prompt_name: str = prompt_model.full_name
-        mlflow_client: MlflowClient = MlflowClient()
 
         # If explicit version or alias is specified, use it directly
         if prompt_model.version or prompt_model.alias:
@@ -1197,19 +1196,13 @@ class DatabricksProvider(ServiceProvider):
         except Exception as e:
             logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
 
-        # 2. Try to get latest version by finding the max version number
+        # 2. Try latest alias
         try:
-            versions = mlflow_client.search_prompt_versions(
-                prompt_name, max_results=100
-            )
-            if versions:
-                latest = max(versions, key=lambda v: int(v.version))
-                logger.info(
-                    f"Loaded prompt '{prompt_name}' version {latest.version} (latest by max version)"
-                )
-                return latest
+            prompt_version = load_prompt(f"prompts:/{prompt_name}@latest")
+            logger.info(f"Loaded prompt '{prompt_name}' from latest alias")
+            return prompt_version
         except Exception as e:
-            logger.debug(f"Failed to find latest version for '{prompt_name}': {e}")
+            logger.debug(f"Latest alias not found for '{prompt_name}': {e}")
 
         # 3. Try default alias
         try:
@@ -1225,7 +1218,7 @@ class DatabricksProvider(ServiceProvider):
                 f"No existing prompt found for '{prompt_name}', "
                 "attempting to register default_template"
             )
-            return self._sync_default_template_to_registry(
+            return self._register_default_template(
                 prompt_name, prompt_model.default_template, prompt_model.description
             )
 
@@ -1235,49 +1228,17 @@ class DatabricksProvider(ServiceProvider):
                 "and no default_template provided"
             )
 
-    def _sync_default_template_to_registry(
+    def _register_default_template(
         self, prompt_name: str, default_template: str, description: str | None = None
     ) -> PromptVersion:
-        """Get the best available prompt version, or register default_template if possible.
-
-        Tries to load prompts in order: champion → latest (max version) → default.
-        If none found and we have write permissions, registers the default_template.
-        If registration fails (e.g., in Model Serving), logs the error and raises.
-        """
-        mlflow_client: MlflowClient = MlflowClient()
-
-        # Try to find an existing prompt version in priority order
-        # 1. Try champion alias
-        try:
-            champion = mlflow.genai.load_prompt(f"prompts:/{prompt_name}@champion")
-            logger.info(f"Loaded prompt '{prompt_name}' from champion alias")
-            return champion
-        except Exception as e:
-            logger.debug(f"Champion alias not found for '{prompt_name}': {e}")
+        """Register default_template as a new prompt version.
 
-        # 2. Try to get the latest version by finding the max version number
-        try:
-            versions = mlflow_client.search_prompt_versions(
-                prompt_name, max_results=100
-            )
-            if versions:
-                latest = max(versions, key=lambda v: int(v.version))
-                logger.info(
-                    f"Loaded prompt '{prompt_name}' version {latest.version} (latest by max version)"
-                )
-                return latest
-        except Exception as e:
-            logger.debug(f"Failed to search versions for '{prompt_name}': {e}")
-
-        # 3. Try default alias
-        try:
-            default = mlflow.genai.load_prompt(f"prompts:/{prompt_name}@default")
-            logger.info(f"Loaded prompt '{prompt_name}' from default alias")
-            return default
-        except Exception as e:
-            logger.debug(f"Default alias not found for '{prompt_name}': {e}")
+        Called when no existing prompt version is found (champion, latest, default all failed).
+        Registers the template and sets both 'default' and 'champion' aliases.
 
-        # No existing prompt found - try to register if we have a template
+        If registration fails (e.g., in Model Serving with restricted permissions),
+        logs the error and raises.
+        """
        logger.info(
             f"No existing prompt found for '{prompt_name}', attempting to register default_template"
         )
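
The provider refactor replaces version scanning (MlflowClient.search_prompt_versions plus a max() over version numbers) with a third alias lookup, so all three fallbacks now go through the same URI scheme. A minimal sketch of the new resolution order, assuming an MLflow prompt registry where the aliases have been assigned; note that "latest" is treated here as an ordinary alias that must be set on some version, not computed from version numbers as in 0.0.33:

    # Sketch only: alias-based fallback mirroring the new provider behavior.
    from mlflow.genai import load_prompt

    prompt_name = "catalog.schema.my_prompt"  # hypothetical prompt name
    prompt = None
    for alias in ("champion", "latest", "default"):
        try:
            prompt = load_prompt(f"prompts:/{prompt_name}@{alias}")
            print(f"Loaded '{prompt_name}' via @{alias}")
            break
        except Exception:
            continue  # fall through to the next alias, as the provider does

The rename from _sync_default_template_to_registry to _register_default_template also drops that method's duplicated champion/latest/default probing, leaving it with the single registration job its new docstring describes.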
dao_ai/tools/genie.py CHANGED
@@ -5,8 +5,9 @@ from typing import Annotated, Any, Callable
 
 import pandas as pd
 from databricks_ai_bridge.genie import Genie, GenieResponse
+from langchain.tools import tool
 from langchain_core.messages import ToolMessage
-from langchain_core.tools import InjectedToolCallId, tool
+from langchain_core.tools import InjectedToolCallId
 from langgraph.prebuilt import InjectedState
 from langgraph.types import Command
 from loguru import logger
@@ -43,7 +44,7 @@ def create_genie_tool(
     genie_room: GenieRoomModel | dict[str, Any],
     name: str | None = None,
     description: str | None = None,
-    persist_conversation: bool = False,
+    persist_conversation: bool = True,
     truncate_results: bool = False,
 ) -> Callable[..., Command]:
     """
@@ -64,6 +65,16 @@ def create_genie_tool(
     Returns:
         A LangGraph tool that processes natural language queries through Genie
     """
+    logger.debug("create_genie_tool")
+    logger.debug(f"genie_room type: {type(genie_room)}")
+    logger.debug(f"genie_room: {genie_room}")
+    logger.debug(f"persist_conversation: {persist_conversation}")
+    logger.debug(f"truncate_results: {truncate_results}")
+    logger.debug(f"name: {name}")
+    logger.debug(f"description: {description}")
+    logger.debug(f"genie_room: {genie_room}")
+    logger.debug(f"persist_conversation: {persist_conversation}")
+    logger.debug(f"truncate_results: {truncate_results}")
 
     if isinstance(genie_room, dict):
         genie_room = GenieRoomModel(**genie_room)
@@ -106,14 +117,13 @@ GenieResponse: A response object containing the conversation ID and result from
         state: Annotated[dict, InjectedState],
         tool_call_id: Annotated[str, InjectedToolCallId],
     ) -> Command:
-        """Process a natural language question through Databricks Genie."""
-        # Create Genie instance using databricks_langchain implementation
         genie: Genie = Genie(
             space_id=space_id,
             client=genie_room.workspace_client,
             truncate_results=truncate_results,
        )
 
+        """Process a natural language question through Databricks Genie."""
         # Get existing conversation mapping and retrieve conversation ID for this space
         conversation_ids: dict[str, str] = state.get("genie_conversation_ids", {})
         existing_conversation_id: str | None = conversation_ids.get(space_id)
@@ -131,6 +141,7 @@ GenieResponse: A response object containing the conversation ID and result from
         )
 
         # Update the conversation mapping with the new conversation ID for this space
+
         update: dict[str, Any] = {
             "messages": [
                 ToolMessage(_response_to_json(response), tool_call_id=tool_call_id)
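
Note the behavioral change here: persist_conversation now defaults to True, so a Genie tool reuses its conversation across calls unless the caller opts out. A usage sketch follows; the genie_room field names are assumptions about GenieRoomModel, not confirmed by this diff.

    # Sketch only: persist_conversation defaults to True as of 0.0.34.
    from dao_ai.tools.genie import create_genie_tool

    genie_tool = create_genie_tool(
        genie_room={"name": "retail", "space_id": "<genie-space-id>"},
        name="retail_genie",
        description="Answer natural language questions about retail data.",
    )

    # Passing persist_conversation=False restores the pre-0.0.34 behavior.
    stateless_tool = create_genie_tool(
        genie_room={"name": "retail", "space_id": "<genie-space-id>"},
        persist_conversation=False,
    )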
{dao_ai-0.0.33.dist-info → dao_ai-0.0.34.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: dao-ai
-Version: 0.0.33
+Version: 0.0.34
 Summary: DAO AI: A modular, multi-agent orchestration framework for complex AI workflows. Supports agent handoff, tool integration, and dynamic configuration via YAML.
 Project-URL: Homepage, https://github.com/natefleming/dao-ai
 Project-URL: Documentation, https://natefleming.github.io/dao-ai
{dao_ai-0.0.33.dist-info → dao_ai-0.0.34.dist-info}/RECORD RENAMED
@@ -3,11 +3,11 @@ dao_ai/agent_as_code.py,sha256=sviZQV7ZPxE5zkZ9jAbfegI681nra5i8yYxw05e3X7U,552
 dao_ai/catalog.py,sha256=sPZpHTD3lPx4EZUtIWeQV7VQM89WJ6YH__wluk1v2lE,4947
 dao_ai/chat_models.py,sha256=uhwwOTeLyHWqoTTgHrs4n5iSyTwe4EQcLKnh3jRxPWI,8626
 dao_ai/cli.py,sha256=gq-nsapWxDA1M6Jua3vajBvIwf0Oa6YLcB58lEtMKUo,22503
-dao_ai/config.py,sha256=Uj0FgOhjnYp0qEmY44mCnp3Ijafg-381FNXt8R_QuWw,78513
+dao_ai/config.py,sha256=Jzb0ePrt2TM2WuXI_LtmTafbseKBlJ8J8J2ExyBowbM,79491
 dao_ai/graph.py,sha256=9kjJx0oFZKq5J9-Kpri4-0VCJILHYdYyhqQnj0_noxQ,8913
 dao_ai/guardrails.py,sha256=4TKArDONRy8RwHzOT1plZ1rhy3x9GF_aeGpPCRl6wYA,4016
 dao_ai/messages.py,sha256=xl_3-WcFqZKCFCiov8sZOPljTdM3gX3fCHhxq-xFg2U,7005
-dao_ai/models.py,sha256=8r8GIG3EGxtVyWsRNI56lVaBjiNrPkzh4HdwMZRq8iw,31689
+dao_ai/models.py,sha256=hvEZO2N0EC2sQoMgjJ9mbKmDWcdxnnAb2NqzpXh4Wgk,32691
 dao_ai/nodes.py,sha256=iQ_5vL6mt1UcRnhwgz-l1D8Ww4CMQrSMVnP_Lu7fFjU,8781
 dao_ai/prompts.py,sha256=iA2Iaky7yzjwWT5cxg0cUIgwo1z1UVQua__8WPnvV6g,1633
 dao_ai/state.py,sha256=_lF9krAYYjvFDMUwZzVKOn0ZnXKcOrbjWKdre0C5B54,1137
@@ -22,11 +22,11 @@ dao_ai/memory/core.py,sha256=DnEjQO3S7hXr3CDDd7C2eE7fQUmcCS_8q9BXEgjPH3U,4271
 dao_ai/memory/postgres.py,sha256=vvI3osjx1EoU5GBA6SCUstTBKillcmLl12hVgDMjfJY,15346
 dao_ai/providers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 dao_ai/providers/base.py,sha256=-fjKypCOk28h6vioPfMj9YZSw_3Kcbi2nMuAyY7vX9k,1383
-dao_ai/providers/databricks.py,sha256=rPBMdGcJvdGBRK9FZeBxkLfcTpXyxU1cs14YllyZKbY,67857
+dao_ai/providers/databricks.py,sha256=WEigNPGRTlIPVjwp97My8o1zOHn5ftuMsMrpqrBeaLg,66012
 dao_ai/tools/__init__.py,sha256=G5-5Yi6zpQOH53b5IzLdtsC6g0Ep6leI5GxgxOmgw7Q,1203
 dao_ai/tools/agent.py,sha256=WbQnyziiT12TLMrA7xK0VuOU029tdmUBXbUl-R1VZ0Q,1886
 dao_ai/tools/core.py,sha256=kN77fWOzVY7qOs4NiW72cUxCsSTC0DnPp73s6VJEZOQ,1991
-dao_ai/tools/genie.py,sha256=BPM_1Sk5bf7QSCFPPboWWkZKYwBwDwbGhMVp5-QDd10,5956
+dao_ai/tools/genie.py,sha256=hWDLLGUNz1wgwOb69pXnMiLJnMbG_1YmMdfVKt1Qe8o,6426
 dao_ai/tools/human_in_the_loop.py,sha256=yk35MO9eNETnYFH-sqlgR-G24TrEgXpJlnZUustsLkI,3681
 dao_ai/tools/mcp.py,sha256=5aQoRtx2z4xm6zgRslc78rSfEQe-mfhqov2NsiybYfc,8416
 dao_ai/tools/python.py,sha256=XcQiTMshZyLUTVR5peB3vqsoUoAAy8gol9_pcrhddfI,1831
@@ -34,8 +34,8 @@ dao_ai/tools/slack.py,sha256=SCvyVcD9Pv_XXPXePE_fSU1Pd8VLTEkKDLvoGTZWy2Y,4775
 dao_ai/tools/time.py,sha256=Y-23qdnNHzwjvnfkWvYsE7PoWS1hfeKy44tA7sCnNac,8759
 dao_ai/tools/unity_catalog.py,sha256=K9t8M4spsbxbecWmV5yEZy16s_AG7AfaoxT-7IDW43I,14438
 dao_ai/tools/vector_search.py,sha256=3cdiUaFpox25GSRNec7FKceY3DuLp7dLVH8FRA0BgeY,12624
-dao_ai-0.0.33.dist-info/METADATA,sha256=aa4BvkiG1dEvLorpgADosf1LCKRVBg-n8LtReVYJNxc,42761
-dao_ai-0.0.33.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
-dao_ai-0.0.33.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
-dao_ai-0.0.33.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
-dao_ai-0.0.33.dist-info/RECORD,,
+dao_ai-0.0.34.dist-info/METADATA,sha256=vq51NEV-pg7WTOD5z56jyOrC5_6Q-nUIL51RI5lL-Hg,42761
+dao_ai-0.0.34.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+dao_ai-0.0.34.dist-info/entry_points.txt,sha256=Xa-UFyc6gWGwMqMJOt06ZOog2vAfygV_DSwg1AiP46g,43
+dao_ai-0.0.34.dist-info/licenses/LICENSE,sha256=YZt3W32LtPYruuvHE9lGk2bw6ZPMMJD8yLrjgHybyz4,1069
+dao_ai-0.0.34.dist-info/RECORD,,