dao-ai 0.0.20__py3-none-any.whl → 0.0.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dao_ai/config.py CHANGED
@@ -23,6 +23,7 @@ from databricks.sdk.credentials_provider import (
     ModelServingUserCredentials,
 )
 from databricks.sdk.service.catalog import FunctionInfo, TableInfo
+from databricks.sdk.service.database import DatabaseInstance
 from databricks.vector_search.client import VectorSearchClient
 from databricks.vector_search.index import VectorSearchIndex
 from databricks_langchain import (
@@ -427,7 +428,8 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksGenieSpace(
-                genie_space_id=self.space_id, on_behalf_of_user=self.on_behalf_of_user
+                genie_space_id=value_of(self.space_id),
+                on_behalf_of_user=self.on_behalf_of_user,
             )
         ]
 
@@ -437,7 +439,7 @@ class GenieRoomModel(BaseModel, IsDatabricksResource):
         return self
 
 
-class VolumeModel(BaseModel, HasFullName):
+class VolumeModel(BaseModel, HasFullName, IsDatabricksResource):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     schema_model: Optional[SchemaModel] = Field(default=None, alias="schema")
     name: str
@@ -455,6 +457,13 @@ class VolumeModel(BaseModel, HasFullName):
         provider: ServiceProvider = DatabricksProvider(w=w)
         provider.create_volume(self)
 
+    @property
+    def api_scopes(self) -> Sequence[str]:
+        return ["files.files", "catalog.volumes"]
+
+    def as_resources(self) -> Sequence[DatabricksResource]:
+        return []
+
 
 class VolumePathModel(BaseModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
@@ -683,7 +692,8 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksSQLWarehouse(
-                warehouse_id=self.warehouse_id, on_behalf_of_user=self.on_behalf_of_user
+                warehouse_id=value_of(self.warehouse_id),
+                on_behalf_of_user=self.on_behalf_of_user,
             )
         ]
 
@@ -694,15 +704,18 @@ class WarehouseModel(BaseModel, IsDatabricksResource):
 
 
 class DatabaseModel(BaseModel, IsDatabricksResource):
-    model_config = ConfigDict(frozen=True)
+    model_config = ConfigDict(use_enum_values=True, extra="forbid")
     name: str
+    instance_name: Optional[str] = None
     description: Optional[str] = None
-    host: Optional[AnyVariable]
+    host: Optional[AnyVariable] = None
     database: Optional[AnyVariable] = "databricks_postgres"
     port: Optional[AnyVariable] = 5432
     connection_kwargs: Optional[dict[str, Any]] = Field(default_factory=dict)
     max_pool_size: Optional[int] = 10
-    timeout_seconds: Optional[int] = 5
+    timeout_seconds: Optional[int] = 10
+    capacity: Optional[Literal["CU_1", "CU_2"]] = "CU_2"
+    node_count: Optional[int] = None
     user: Optional[AnyVariable] = None
     password: Optional[AnyVariable] = None
     client_id: Optional[AnyVariable] = None
@@ -716,11 +729,44 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
     def as_resources(self) -> Sequence[DatabricksResource]:
         return [
             DatabricksLakebase(
-                database_instance_name=self.name,
+                database_instance_name=self.instance_name,
                 on_behalf_of_user=self.on_behalf_of_user,
             )
         ]
 
+    @model_validator(mode="after")
+    def update_instance_name(self):
+        if self.instance_name is None:
+            self.instance_name = self.name
+
+        return self
+
+    @model_validator(mode="after")
+    def update_user(self):
+        if self.client_id or self.user:
+            return self
+
+        self.user = self.workspace_client.current_user.me().user_name
+        if not self.user:
+            raise ValueError(
+                "Unable to determine current user. Please provide a user name or OAuth credentials."
+            )
+
+        return self
+
+    @model_validator(mode="after")
+    def update_host(self):
+        if self.host is not None:
+            return self
+
+        existing_instance: DatabaseInstance = (
+            self.workspace_client.database.get_database_instance(
+                name=self.instance_name
+            )
+        )
+        self.host = existing_instance.read_write_dns
+        return self
+
     @model_validator(mode="after")
     def validate_auth_methods(self):
         oauth_fields: Sequence[Any] = [
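
Taken together, the three new validators make DatabaseModel largely self-configuring: instance_name falls back to name, user falls back to the current workspace user, and host is resolved from the Lakebase instance's read-write DNS. A minimal sketch of the resulting behavior, assuming a reachable Databricks workspace (the instance name below is hypothetical):

    from dao_ai.config import DatabaseModel

    # Only `name` is supplied; the validators fill in the rest at construction time.
    db = DatabaseModel(name="my-lakebase-instance")

    assert db.instance_name == "my-lakebase-instance"  # update_instance_name fallback
    print(db.user)  # update_user: current workspace user, unless client_id/user is set
    print(db.host)  # update_host: read_write_dns of the existing database instance
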
@@ -730,7 +776,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         ]
         has_oauth: bool = all(field is not None for field in oauth_fields)
 
-        pat_fields: Sequence[Any] = [self.user, self.password]
+        pat_fields: Sequence[Any] = [self.user]
         has_user_auth: bool = all(field is not None for field in pat_fields)
 
         if has_oauth and has_user_auth:
@@ -749,7 +795,14 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
         return self
 
     @property
-    def connection_url(self) -> str:
+    def connection_params(self) -> dict[str, Any]:
+        """
+        Get database connection parameters as a dictionary.
+
+        Returns a dict with connection parameters suitable for psycopg ConnectionPool.
+        If username is configured, it will be included; otherwise it will be omitted
+        to allow Lakebase to authenticate using the token's identity.
+        """
         from dao_ai.providers.base import ServiceProvider
         from dao_ai.providers.databricks import DatabricksProvider
 
@@ -757,7 +810,7 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
 
         if self.client_id and self.client_secret and self.workspace_host:
             username = value_of(self.client_id)
-        else:
+        elif self.user:
             username = value_of(self.user)
 
         host: str = value_of(self.host)
@@ -770,11 +823,48 @@ class DatabaseModel(BaseModel, IsDatabricksResource):
             workspace_host=value_of(self.workspace_host),
             pat=value_of(self.password),
         )
-        token: str = provider.create_token()
 
-        return (
-            f"postgresql://{username}:{token}@{host}:{port}/{database}?sslmode=require"
-        )
+        token: str = provider.lakebase_password_provider(self.instance_name)
+
+        # Build connection parameters dictionary
+        params: dict[str, Any] = {
+            "dbname": database,
+            "host": host,
+            "port": port,
+            "password": token,
+            "sslmode": "require",
+        }
+
+        # Only include user if explicitly configured
+        if username:
+            params["user"] = username
+            logger.debug(
+                f"Connection params: dbname={database} user={username} host={host} port={port} password=******** sslmode=require"
+            )
+        else:
+            logger.debug(
+                f"Connection params: dbname={database} host={host} port={port} password=******** sslmode=require (using token identity)"
+            )
+
+        return params
+
+    @property
+    def connection_url(self) -> str:
+        """
+        Get database connection URL as a string (for backwards compatibility).
+
+        Note: It's recommended to use connection_params instead for better flexibility.
+        """
+        params = self.connection_params
+        parts = [f"{k}={v}" for k, v in params.items()]
+        return " ".join(parts)
+
+    def create(self, w: WorkspaceClient | None = None) -> None:
+        from dao_ai.providers.databricks import DatabricksProvider
+
+        provider: DatabricksProvider = DatabricksProvider()
+        provider.create_lakebase(self)
+        provider.create_lakebase_instance_role(self)
 
 
 class SearchParametersModel(BaseModel):
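
Note the behavioral change to connection_url: it now renders the same parameters as a space-separated libpq keyword/value string rather than a postgresql:// URL, so callers that parsed the old URL form will need updating. A minimal sketch of how a caller might consume connection_params directly, mirroring the pool helpers added in dao_ai/memory/postgres.py below (the `database` variable stands in for a DatabaseModel instance):

    from psycopg_pool import ConnectionPool

    params = database.connection_params  # {"dbname": ..., "host": ..., "password": <token>, ...}

    pool = ConnectionPool(
        conninfo="",                       # parameters travel via kwargs, not a conninfo string
        min_size=1,
        max_size=database.max_pool_size,
        open=False,
        timeout=database.timeout_seconds,
        kwargs=params,                     # psycopg applies these to every new connection
    )
    pool.open(wait=True, timeout=database.timeout_seconds)
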
@@ -867,7 +957,7 @@ class PythonFunctionModel(BaseFunctionModel, HasFullName):
 
 class FactoryFunctionModel(BaseFunctionModel, HasFullName):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
-    args: Optional[dict[str, AnyVariable]] = Field(default_factory=dict)
+    args: Optional[dict[str, Any]] = Field(default_factory=dict)
     type: Literal[FunctionType.FACTORY] = FunctionType.FACTORY
 
     @property
@@ -879,6 +969,12 @@ class FactoryFunctionModel(BaseFunctionModel, HasFullName):
 
         return [create_factory_tool(self, **kwargs)]
 
+    @model_validator(mode="after")
+    def update_args(self):
+        for key, value in self.args.items():
+            self.args[key] = value_of(value)
+        return self
+
 
 class TransportType(str, Enum):
     STREAMABLE_HTTP = "streamable_http"
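
Widening args from dict[str, AnyVariable] to dict[str, Any], together with the new update_args validator, means factory arguments are resolved through value_of once at validation time instead of at call time. A hedged sketch of the effect, assuming value_of passes plain literals through unchanged and dereferences variable-style values (the factory path below is hypothetical):

    fn = FactoryFunctionModel(
        name="tools.create_search_tool",          # hypothetical factory function path
        args={"index": "main.catalog.idx", "k": 5},
    )
    # After validation, every value has already been passed through value_of()
    assert fn.args == {"index": "main.catalog.idx", "k": 5}
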
@@ -1078,6 +1174,7 @@ class AgentModel(BaseModel):
 class SupervisorModel(BaseModel):
     model_config = ConfigDict(use_enum_values=True, extra="forbid")
     model: LLMModel
+    tools: list[ToolModel] = Field(default_factory=list)
     prompt: Optional[str] = None
 
 
@@ -1262,12 +1359,12 @@ class AppModel(BaseModel):
         if len(self.agents) > 1:
             default_agent: AgentModel = self.agents[0]
             self.orchestration = OrchestrationModel(
-                swarm=SupervisorModel(model=default_agent.model)
+                supervisor=SupervisorModel(model=default_agent.model)
             )
         elif len(self.agents) == 1:
             default_agent: AgentModel = self.agents[0]
             self.orchestration = OrchestrationModel(
-                supervisor=SwarmModel(
+                swarm=SwarmModel(
                     model=default_agent.model, default_agent=default_agent
                 )
             )
dao_ai/graph.py CHANGED
@@ -27,6 +27,7 @@ from dao_ai.nodes import (
 )
 from dao_ai.prompts import make_prompt
 from dao_ai.state import Context, IncomingState, OutgoingState, SharedState
+from dao_ai.tools import create_tools
 
 
 def route_message(state: SharedState) -> str:
@@ -91,6 +92,8 @@ def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
     orchestration: OrchestrationModel = config.app.orchestration
     supervisor: SupervisorModel = orchestration.supervisor
 
+    tools += create_tools(orchestration.supervisor.tools)
+
     store: BaseStore = None
     if orchestration.memory and orchestration.memory.store:
         store = orchestration.memory.store.as_store()
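
With tools now a first-class field on SupervisorModel, the supervisor node gets its own tool set merged in alongside the agents' tools. A minimal configuration sketch (the model and tool values are hypothetical):

    supervisor = SupervisorModel(
        model=default_llm,            # hypothetical LLMModel
        tools=[lookup_tool_model],    # hypothetical ToolModel entries
        prompt="Route requests to the best agent.",
    )
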
dao_ai/memory/core.py CHANGED
@@ -70,10 +70,14 @@ class StoreManager:
             case StorageType.POSTGRES:
                 from dao_ai.memory.postgres import PostgresStoreManager
 
-                store_manager = cls.store_managers.get(store_model.database.name)
+                store_manager = cls.store_managers.get(
+                    store_model.database.instance_name
+                )
                 if store_manager is None:
                     store_manager = PostgresStoreManager(store_model)
-                    cls.store_managers[store_model.database.name] = store_manager
+                    cls.store_managers[store_model.database.instance_name] = (
+                        store_manager
+                    )
             case _:
                 raise ValueError(f"Unknown store type: {store_model.type}")
 
@@ -102,15 +106,15 @@ class CheckpointManager:
                 from dao_ai.memory.postgres import AsyncPostgresCheckpointerManager
 
                 checkpointer_manager = cls.checkpoint_managers.get(
-                    checkpointer_model.database.name
+                    checkpointer_model.database.instance_name
                 )
                 if checkpointer_manager is None:
                     checkpointer_manager = AsyncPostgresCheckpointerManager(
                         checkpointer_model
                     )
-                    cls.checkpoint_managers[checkpointer_model.database.name] = (
-                        checkpointer_manager
-                    )
+                    cls.checkpoint_managers[
+                        checkpointer_model.database.instance_name
+                    ] = checkpointer_manager
             case _:
                 raise ValueError(f"Unknown store type: {checkpointer_model.type}")
 
dao_ai/memory/postgres.py CHANGED
@@ -20,6 +20,59 @@ from dao_ai.memory.base import (
 )
 
 
+def _create_pool(
+    connection_params: dict[str, Any],
+    database_name: str,
+    max_pool_size: int,
+    timeout_seconds: int,
+    kwargs: dict,
+) -> ConnectionPool:
+    """Create a connection pool using the provided connection parameters."""
+    logger.debug(
+        f"Connection params for {database_name}: {', '.join(k + '=' + (str(v) if k != 'password' else '***') for k, v in connection_params.items())}"
+    )
+
+    # Merge connection_params into kwargs for psycopg
+    connection_kwargs = kwargs | connection_params
+    pool = ConnectionPool(
+        conninfo="",  # Empty conninfo, params come from kwargs
+        min_size=1,
+        max_size=max_pool_size,
+        open=False,
+        timeout=timeout_seconds,
+        kwargs=connection_kwargs,
+    )
+    pool.open(wait=True, timeout=timeout_seconds)
+    logger.info(f"Successfully connected to {database_name}")
+    return pool
+
+
+async def _create_async_pool(
+    connection_params: dict[str, Any],
+    database_name: str,
+    max_pool_size: int,
+    timeout_seconds: int,
+    kwargs: dict,
+) -> AsyncConnectionPool:
+    """Create an async connection pool using the provided connection parameters."""
+    logger.debug(
+        f"Connection params for {database_name}: {', '.join(k + '=' + (str(v) if k != 'password' else '***') for k, v in connection_params.items())}"
+    )
+
+    # Merge connection_params into kwargs for psycopg
+    connection_kwargs = kwargs | connection_params
+    pool = AsyncConnectionPool(
+        conninfo="",  # Empty conninfo, params come from kwargs
+        max_size=max_pool_size,
+        open=False,
+        timeout=timeout_seconds,
+        kwargs=connection_kwargs,
+    )
+    await pool.open(wait=True, timeout=timeout_seconds)
+    logger.info(f"Successfully connected to {database_name}")
+    return pool
+
+
 class AsyncPostgresPoolManager:
     _pools: dict[str, AsyncConnectionPool] = {}
     _lock: asyncio.Lock = asyncio.Lock()
@@ -27,7 +80,7 @@ class AsyncPostgresPoolManager:
     @classmethod
     async def get_pool(cls, database: DatabaseModel) -> AsyncConnectionPool:
         connection_key: str = database.name
-        connection_url: str = database.connection_url
+        connection_params: dict[str, Any] = database.connection_params
 
         async with cls._lock:
             if connection_key in cls._pools:
@@ -41,23 +94,17 @@ class AsyncPostgresPoolManager:
                 "autocommit": True,
             } | database.connection_kwargs or {}
 
-            pool: AsyncConnectionPool = AsyncConnectionPool(
-                conninfo=connection_url,
-                max_size=database.max_pool_size,
-                open=False,
-                timeout=database.timeout_seconds,
+            # Create connection pool
+            pool: AsyncConnectionPool = await _create_async_pool(
+                connection_params=connection_params,
+                database_name=database.name,
+                max_pool_size=database.max_pool_size,
+                timeout_seconds=database.timeout_seconds,
                 kwargs=kwargs,
             )
 
-            try:
-                await pool.open(wait=True, timeout=database.timeout_seconds)
-                cls._pools[connection_key] = pool
-                return pool
-            except Exception as e:
-                logger.error(
-                    f"Failed to create PostgreSQL pool for {database.name}: {e}"
-                )
-                raise e
+            cls._pools[connection_key] = pool
+            return pool
 
     @classmethod
     async def close_pool(cls, database: DatabaseModel):
@@ -74,8 +121,17 @@ class AsyncPostgresPoolManager:
         async with cls._lock:
             for connection_key, pool in cls._pools.items():
                 try:
-                    await pool.close()
+                    # Use a short timeout to avoid blocking on pool closure
+                    await asyncio.wait_for(pool.close(), timeout=2.0)
                     logger.debug(f"Closed PostgreSQL pool: {connection_key}")
+                except asyncio.TimeoutError:
+                    logger.warning(
+                        f"Timeout closing pool {connection_key}, forcing closure"
+                    )
+                except asyncio.CancelledError:
+                    logger.warning(
+                        f"Pool closure cancelled for {connection_key} (shutdown in progress)"
+                    )
                 except Exception as e:
                     logger.error(f"Error closing pool {connection_key}: {e}")
             cls._pools.clear()
@@ -209,7 +265,7 @@ class PostgresPoolManager:
     @classmethod
     def get_pool(cls, database: DatabaseModel) -> ConnectionPool:
         connection_key: str = str(database.name)
-        connection_url: str = database.connection_url
+        connection_params: dict[str, Any] = database.connection_params
 
         with cls._lock:
             if connection_key in cls._pools:
@@ -223,23 +279,17 @@ class PostgresPoolManager:
                 "autocommit": True,
             } | database.connection_kwargs or {}
 
-            pool: ConnectionPool = ConnectionPool(
-                conninfo=connection_url,
-                max_size=database.max_pool_size,
-                open=False,
-                timeout=database.timeout_seconds,
+            # Create connection pool
+            pool: ConnectionPool = _create_pool(
+                connection_params=connection_params,
+                database_name=database.name,
+                max_pool_size=database.max_pool_size,
+                timeout_seconds=database.timeout_seconds,
                 kwargs=kwargs,
             )
 
-            try:
-                pool.open(wait=True, timeout=database.timeout_seconds)
-                cls._pools[connection_key] = pool
-                return pool
-            except Exception as e:
-                logger.error(
-                    f"Failed to create PostgreSQL pool for {database.name}: {e}"
-                )
-                raise e
+            cls._pools[connection_key] = pool
+            return pool
 
     @classmethod
     def close_pool(cls, database: DatabaseModel):
@@ -369,8 +419,27 @@ def _shutdown_pools():
 
 def _shutdown_async_pools():
     try:
-        asyncio.run(AsyncPostgresPoolManager.close_all_pools())
-        logger.debug("Successfully closed all asynchronous PostgreSQL pools")
+        # Try to get the current event loop first
+        try:
+            loop = asyncio.get_running_loop()
+            # If we're already in an event loop, create a task
+            loop.create_task(AsyncPostgresPoolManager.close_all_pools())
+            logger.debug("Scheduled async pool closure in running event loop")
+        except RuntimeError:
+            # No running loop, try to get or create one
+            try:
+                loop = asyncio.get_event_loop()
+                if loop.is_closed():
+                    # Loop is closed, create a new one
+                    loop = asyncio.new_event_loop()
+                    asyncio.set_event_loop(loop)
+                loop.run_until_complete(AsyncPostgresPoolManager.close_all_pools())
+                logger.debug("Successfully closed all asynchronous PostgreSQL pools")
+            except Exception as inner_e:
+                # If all else fails, just log the error
+                logger.warning(
+                    f"Could not close async pools cleanly during shutdown: {inner_e}"
+                )
     except Exception as e:
         logger.error(
             f"Error closing asynchronous PostgreSQL pools during shutdown: {e}"
dao_ai/models.py CHANGED
@@ -5,6 +5,7 @@ from typing import Any, Generator, Optional, Sequence, Union
 
 from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
 from langgraph.graph.state import CompiledStateGraph
+from langgraph.types import StateSnapshot
 from loguru import logger
 from mlflow import MlflowClient
 from mlflow.pyfunc import ChatAgent, ChatModel, ResponsesAgent
@@ -59,6 +60,113 @@ def get_latest_model_version(model_name: str) -> int:
     return latest_version
 
 
+async def get_state_snapshot_async(
+    graph: CompiledStateGraph, thread_id: str
+) -> Optional[StateSnapshot]:
+    """
+    Retrieve the state snapshot from the graph for a given thread_id asynchronously.
+
+    This utility function accesses the graph's checkpointer to retrieve the current
+    state snapshot, which contains the full state values and metadata.
+
+    Args:
+        graph: The compiled LangGraph state machine
+        thread_id: The thread/conversation ID to retrieve state for
+
+    Returns:
+        StateSnapshot if found, None otherwise
+    """
+    logger.debug(f"Retrieving state snapshot for thread_id: {thread_id}")
+    try:
+        # Check if graph has a checkpointer
+        if graph.checkpointer is None:
+            logger.debug("No checkpointer available in graph")
+            return None
+
+        # Get the current state from the checkpointer (use async version)
+        config: dict[str, Any] = {"configurable": {"thread_id": thread_id}}
+        state_snapshot: Optional[StateSnapshot] = await graph.aget_state(config)
+
+        if state_snapshot is None:
+            logger.debug(f"No state found for thread_id: {thread_id}")
+            return None
+
+        return state_snapshot
+
+    except Exception as e:
+        logger.warning(f"Error retrieving state snapshot for thread {thread_id}: {e}")
+        return None
+
+
+def get_state_snapshot(
+    graph: CompiledStateGraph, thread_id: str
+) -> Optional[StateSnapshot]:
+    """
+    Retrieve the state snapshot from the graph for a given thread_id.
+
+    This is a synchronous wrapper around get_state_snapshot_async.
+    Use this for backward compatibility in synchronous contexts.
+
+    Args:
+        graph: The compiled LangGraph state machine
+        thread_id: The thread/conversation ID to retrieve state for
+
+    Returns:
+        StateSnapshot if found, None otherwise
+    """
+    import asyncio
+
+    try:
+        loop = asyncio.get_event_loop()
+    except RuntimeError:
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+    try:
+        return loop.run_until_complete(get_state_snapshot_async(graph, thread_id))
+    except Exception as e:
+        logger.warning(f"Error in synchronous state snapshot retrieval: {e}")
+        return None
+
+
+def get_genie_conversation_ids_from_state(
+    state_snapshot: Optional[StateSnapshot],
+) -> dict[str, str]:
+    """
+    Extract genie_conversation_ids from a state snapshot.
+
+    This function extracts the genie_conversation_ids dictionary from the state
+    snapshot values if present.
+
+    Args:
+        state_snapshot: The state snapshot to extract conversation IDs from
+
+    Returns:
+        A dictionary mapping genie space_id to conversation_id, or empty dict if not found
+    """
+    if state_snapshot is None:
+        return {}
+
+    try:
+        # Extract state values - these contain the actual state data
+        state_values: dict[str, Any] = state_snapshot.values
+
+        # Extract genie_conversation_ids from state values
+        genie_conversation_ids: dict[str, str] = state_values.get(
+            "genie_conversation_ids", {}
+        )
+
+        if genie_conversation_ids:
+            logger.debug(f"Retrieved genie_conversation_ids: {genie_conversation_ids}")
+            return genie_conversation_ids
+
+        return {}
+
+    except Exception as e:
+        logger.warning(f"Error extracting genie_conversation_ids from state: {e}")
+        return {}
+
+
 class LanggraphChatModel(ChatModel):
     """
     ChatModel that delegates requests to a LangGraph CompiledStateGraph.
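
A short usage sketch for the new helpers (assumed caller code; `graph` is any CompiledStateGraph compiled with a checkpointer, and the thread ID is hypothetical):

    snapshot = get_state_snapshot(graph, thread_id="thread-123")
    conversation_ids = get_genie_conversation_ids_from_state(snapshot)
    for space_id, conversation_id in conversation_ids.items():
        print(f"Genie space {space_id} -> conversation {conversation_id}")
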
@@ -257,7 +365,19 @@ class LanggraphResponsesAgent(ResponsesAgent):
                 text=last_message.content, id=f"msg_{uuid.uuid4().hex[:8]}"
             )
 
-            custom_outputs = custom_inputs
+            # Retrieve genie_conversation_ids from state if available
+            custom_outputs: dict[str, Any] = custom_inputs.copy()
+            thread_id: Optional[str] = context.thread_id
+            if thread_id:
+                state_snapshot: Optional[StateSnapshot] = loop.run_until_complete(
+                    get_state_snapshot_async(self.graph, thread_id)
+                )
+                genie_conversation_ids: dict[str, str] = (
+                    get_genie_conversation_ids_from_state(state_snapshot)
+                )
+                if genie_conversation_ids:
+                    custom_outputs["genie_conversation_ids"] = genie_conversation_ids
+
             return ResponsesAgentResponse(
                 output=[output_item], custom_outputs=custom_outputs
             )
@@ -318,7 +438,22 @@ class LanggraphResponsesAgent(ResponsesAgent):
                         **self.create_text_delta(delta=content, item_id=item_id)
                     )
 
-                custom_outputs = custom_inputs
+                # Retrieve genie_conversation_ids from state if available
+                custom_outputs: dict[str, Any] = custom_inputs.copy()
+                thread_id: Optional[str] = context.thread_id
+
+                if thread_id:
+                    state_snapshot: Optional[
+                        StateSnapshot
+                    ] = await get_state_snapshot_async(self.graph, thread_id)
+                    genie_conversation_ids: dict[str, str] = (
+                        get_genie_conversation_ids_from_state(state_snapshot)
+                    )
+                    if genie_conversation_ids:
+                        custom_outputs["genie_conversation_ids"] = (
+                            genie_conversation_ids
+                        )
+
                 # Yield final output item
                 yield ResponsesAgentStreamEvent(
                     type="response.output_item.done",