uipath-langchain 0.1.34__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- uipath_langchain/_cli/_templates/langgraph.json.template +2 -4
- uipath_langchain/_cli/cli_new.py +1 -2
- uipath_langchain/agent/guardrails/actions/escalate_action.py +252 -108
- uipath_langchain/agent/guardrails/actions/filter_action.py +247 -12
- uipath_langchain/agent/guardrails/guardrail_nodes.py +47 -12
- uipath_langchain/agent/guardrails/guardrails_factory.py +40 -15
- uipath_langchain/agent/guardrails/utils.py +64 -33
- uipath_langchain/agent/react/agent.py +4 -2
- uipath_langchain/agent/react/file_type_handler.py +123 -0
- uipath_langchain/agent/react/guardrails/guardrails_subgraph.py +67 -12
- uipath_langchain/agent/react/init_node.py +16 -1
- uipath_langchain/agent/react/job_attachments.py +125 -0
- uipath_langchain/agent/react/json_utils.py +183 -0
- uipath_langchain/agent/react/jsonschema_pydantic_converter.py +76 -0
- uipath_langchain/agent/react/llm_with_files.py +76 -0
- uipath_langchain/agent/react/types.py +4 -0
- uipath_langchain/agent/react/utils.py +29 -3
- uipath_langchain/agent/tools/__init__.py +5 -1
- uipath_langchain/agent/tools/context_tool.py +151 -1
- uipath_langchain/agent/tools/escalation_tool.py +46 -15
- uipath_langchain/agent/tools/integration_tool.py +20 -16
- uipath_langchain/agent/tools/internal_tools/__init__.py +5 -0
- uipath_langchain/agent/tools/internal_tools/analyze_files_tool.py +113 -0
- uipath_langchain/agent/tools/internal_tools/internal_tool_factory.py +54 -0
- uipath_langchain/agent/tools/process_tool.py +8 -1
- uipath_langchain/agent/tools/static_args.py +18 -40
- uipath_langchain/agent/tools/tool_factory.py +13 -5
- uipath_langchain/agent/tools/tool_node.py +133 -4
- uipath_langchain/agent/tools/utils.py +31 -0
- uipath_langchain/agent/wrappers/__init__.py +6 -0
- uipath_langchain/agent/wrappers/job_attachment_wrapper.py +62 -0
- uipath_langchain/agent/wrappers/static_args_wrapper.py +34 -0
- uipath_langchain/chat/mapper.py +60 -42
- uipath_langchain/runtime/factory.py +10 -5
- uipath_langchain/runtime/runtime.py +38 -35
- uipath_langchain/runtime/storage.py +178 -71
- {uipath_langchain-0.1.34.dist-info → uipath_langchain-0.3.1.dist-info}/METADATA +5 -4
- {uipath_langchain-0.1.34.dist-info → uipath_langchain-0.3.1.dist-info}/RECORD +41 -30
- {uipath_langchain-0.1.34.dist-info → uipath_langchain-0.3.1.dist-info}/WHEEL +0 -0
- {uipath_langchain-0.1.34.dist-info → uipath_langchain-0.3.1.dist-info}/entry_points.txt +0 -0
- {uipath_langchain-0.1.34.dist-info → uipath_langchain-0.3.1.dist-info}/licenses/LICENSE +0 -0
uipath_langchain/chat/mapper.py
CHANGED
|
@@ -41,6 +41,7 @@ class UiPathChatMessagesMapper:
|
|
|
41
41
|
def __init__(self):
|
|
42
42
|
"""Initialize the mapper with empty state."""
|
|
43
43
|
self.tool_call_to_ai_message: dict[str, str] = {}
|
|
44
|
+
self.current_message: AIMessageChunk
|
|
44
45
|
self.seen_message_ids: set[str] = set()
|
|
45
46
|
|
|
46
47
|
def _extract_text(self, content: Any) -> str:
|
|
@@ -141,7 +142,7 @@ class UiPathChatMessagesMapper:
|
|
|
141
142
|
def map_event(
|
|
142
143
|
self,
|
|
143
144
|
message: BaseMessage,
|
|
144
|
-
) -> UiPathConversationMessageEvent | None:
|
|
145
|
+
) -> list[UiPathConversationMessageEvent] | None:
|
|
145
146
|
"""Convert LangGraph BaseMessage (chunk or full) into a UiPathConversationMessageEvent.
|
|
146
147
|
|
|
147
148
|
Args:
|
|
@@ -168,16 +169,45 @@ class UiPathChatMessagesMapper:
|
|
|
168
169
|
|
|
169
170
|
# Check if this is the last chunk by examining chunk_position
|
|
170
171
|
if message.chunk_position == "last":
|
|
172
|
+
events: list[UiPathConversationMessageEvent] = []
|
|
173
|
+
|
|
174
|
+
# Loop through all content_blocks in current_message and create toolCallStart events for each tool_call_chunk
|
|
175
|
+
if self.current_message and self.current_message.content_blocks:
|
|
176
|
+
for block in self.current_message.content_blocks:
|
|
177
|
+
if block.get("type") == "tool_call_chunk":
|
|
178
|
+
tool_chunk_block = cast(ToolCallChunk, block)
|
|
179
|
+
tool_call_id = tool_chunk_block.get("id")
|
|
180
|
+
tool_name = tool_chunk_block.get("name")
|
|
181
|
+
tool_args = tool_chunk_block.get("args")
|
|
182
|
+
|
|
183
|
+
if tool_call_id:
|
|
184
|
+
tool_event = UiPathConversationMessageEvent(
|
|
185
|
+
message_id=message.id,
|
|
186
|
+
tool_call=UiPathConversationToolCallEvent(
|
|
187
|
+
tool_call_id=tool_call_id,
|
|
188
|
+
start=UiPathConversationToolCallStartEvent(
|
|
189
|
+
tool_name=tool_name,
|
|
190
|
+
timestamp=timestamp,
|
|
191
|
+
input=UiPathInlineValue(inline=tool_args),
|
|
192
|
+
),
|
|
193
|
+
),
|
|
194
|
+
)
|
|
195
|
+
events.append(tool_event)
|
|
196
|
+
|
|
197
|
+
# Create the final event for the message
|
|
171
198
|
msg_event.end = UiPathConversationMessageEndEvent(timestamp=timestamp)
|
|
172
199
|
msg_event.content_part = UiPathConversationContentPartEvent(
|
|
173
200
|
content_part_id=f"chunk-{message.id}-0",
|
|
174
201
|
end=UiPathConversationContentPartEndEvent(),
|
|
175
202
|
)
|
|
176
|
-
|
|
203
|
+
events.append(msg_event)
|
|
204
|
+
|
|
205
|
+
return events
|
|
177
206
|
|
|
178
207
|
# For every new message_id, start a new message
|
|
179
208
|
if message.id not in self.seen_message_ids:
|
|
180
209
|
self.seen_message_ids.add(message.id)
|
|
210
|
+
self.current_message = message
|
|
181
211
|
msg_event.start = UiPathConversationMessageStartEvent(
|
|
182
212
|
role="assistant", timestamp=timestamp
|
|
183
213
|
)
|
|
@@ -200,7 +230,6 @@ class UiPathChatMessagesMapper:
|
|
|
200
230
|
content_part_id=f"chunk-{message.id}-0",
|
|
201
231
|
chunk=UiPathConversationContentPartChunkEvent(
|
|
202
232
|
data=text,
|
|
203
|
-
content_part_sequence=0,
|
|
204
233
|
),
|
|
205
234
|
)
|
|
206
235
|
|
|
@@ -210,19 +239,10 @@ class UiPathChatMessagesMapper:
|
|
|
210
239
|
tool_call_id = tool_chunk_block.get("id")
|
|
211
240
|
if tool_call_id:
|
|
212
241
|
# Track tool_call_id -> ai_message_id mapping
|
|
213
|
-
self.tool_call_to_ai_message[
|
|
214
|
-
|
|
215
|
-
args = tool_chunk_block.get("args") or ""
|
|
242
|
+
self.tool_call_to_ai_message[tool_call_id] = message.id
|
|
216
243
|
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
chunk=UiPathConversationContentPartChunkEvent(
|
|
220
|
-
data=args,
|
|
221
|
-
content_part_sequence=0,
|
|
222
|
-
),
|
|
223
|
-
)
|
|
224
|
-
# Continue so that multiple tool_call_chunks in the same block list
|
|
225
|
-
# are handled correctly
|
|
244
|
+
# Accumulate the message chunk
|
|
245
|
+
self.current_message = self.current_message + message
|
|
226
246
|
continue
|
|
227
247
|
|
|
228
248
|
# Fallback: raw string content on the chunk (rare when using content_blocks)
|
|
@@ -231,7 +251,6 @@ class UiPathChatMessagesMapper:
|
|
|
231
251
|
content_part_id=f"content-{message.id}",
|
|
232
252
|
chunk=UiPathConversationContentPartChunkEvent(
|
|
233
253
|
data=message.content,
|
|
234
|
-
content_part_sequence=0,
|
|
235
254
|
),
|
|
236
255
|
)
|
|
237
256
|
|
|
@@ -241,7 +260,7 @@ class UiPathChatMessagesMapper:
|
|
|
241
260
|
or msg_event.tool_call
|
|
242
261
|
or msg_event.end
|
|
243
262
|
):
|
|
244
|
-
return msg_event
|
|
263
|
+
return [msg_event]
|
|
245
264
|
|
|
246
265
|
return None
|
|
247
266
|
|
|
@@ -275,35 +294,34 @@ class UiPathChatMessagesMapper:
|
|
|
275
294
|
# Keep as string if not valid JSON
|
|
276
295
|
pass
|
|
277
296
|
|
|
278
|
-
return
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
297
|
+
return [
|
|
298
|
+
UiPathConversationMessageEvent(
|
|
299
|
+
message_id=result_message_id or str(uuid4()),
|
|
300
|
+
tool_call=UiPathConversationToolCallEvent(
|
|
301
|
+
tool_call_id=message.tool_call_id,
|
|
302
|
+
end=UiPathConversationToolCallEndEvent(
|
|
303
|
+
timestamp=timestamp,
|
|
304
|
+
output=UiPathInlineValue(inline=content_value),
|
|
305
|
+
),
|
|
286
306
|
),
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
output=UiPathInlineValue(inline=content_value),
|
|
290
|
-
),
|
|
291
|
-
),
|
|
292
|
-
)
|
|
307
|
+
)
|
|
308
|
+
]
|
|
293
309
|
|
|
294
310
|
# --- Fallback for other BaseMessage types ---
|
|
295
311
|
text_content = self._extract_text(message.content)
|
|
296
|
-
return
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
312
|
+
return [
|
|
313
|
+
UiPathConversationMessageEvent(
|
|
314
|
+
message_id=message.id,
|
|
315
|
+
start=UiPathConversationMessageStartEvent(
|
|
316
|
+
role="assistant", timestamp=timestamp
|
|
317
|
+
),
|
|
318
|
+
content_part=UiPathConversationContentPartEvent(
|
|
319
|
+
content_part_id=f"cp-{message.id}",
|
|
320
|
+
chunk=UiPathConversationContentPartChunkEvent(data=text_content),
|
|
321
|
+
),
|
|
322
|
+
end=UiPathConversationMessageEndEvent(),
|
|
323
|
+
)
|
|
324
|
+
]
|
|
307
325
|
|
|
308
326
|
|
|
309
327
|
__all__ = ["UiPathChatMessagesMapper"]
|
|
@@ -92,7 +92,7 @@ class UiPathLangGraphRuntimeFactory:
|
|
|
92
92
|
return self._config
|
|
93
93
|
|
|
94
94
|
async def _load_graph(
|
|
95
|
-
self, entrypoint: str
|
|
95
|
+
self, entrypoint: str, **kwargs
|
|
96
96
|
) -> StateGraph[Any, Any, Any] | CompiledStateGraph[Any, Any, Any, Any]:
|
|
97
97
|
"""
|
|
98
98
|
Load a graph for the given entrypoint.
|
|
@@ -181,7 +181,7 @@ class UiPathLangGraphRuntimeFactory:
|
|
|
181
181
|
return builder.compile(checkpointer=memory)
|
|
182
182
|
|
|
183
183
|
async def _resolve_and_compile_graph(
|
|
184
|
-
self, entrypoint: str, memory: AsyncSqliteSaver
|
|
184
|
+
self, entrypoint: str, memory: AsyncSqliteSaver, **kwargs
|
|
185
185
|
) -> CompiledStateGraph[Any, Any, Any, Any]:
|
|
186
186
|
"""
|
|
187
187
|
Resolve a graph from configuration and compile it.
|
|
@@ -201,7 +201,7 @@ class UiPathLangGraphRuntimeFactory:
|
|
|
201
201
|
if entrypoint in self._graph_cache:
|
|
202
202
|
return self._graph_cache[entrypoint]
|
|
203
203
|
|
|
204
|
-
loaded_graph = await self._load_graph(entrypoint)
|
|
204
|
+
loaded_graph = await self._load_graph(entrypoint, **kwargs)
|
|
205
205
|
|
|
206
206
|
compiled_graph = await self._compile_graph(loaded_graph, memory)
|
|
207
207
|
|
|
@@ -249,6 +249,7 @@ class UiPathLangGraphRuntimeFactory:
|
|
|
249
249
|
compiled_graph: CompiledStateGraph[Any, Any, Any, Any],
|
|
250
250
|
runtime_id: str,
|
|
251
251
|
entrypoint: str,
|
|
252
|
+
**kwargs,
|
|
252
253
|
) -> UiPathRuntimeProtocol:
|
|
253
254
|
"""
|
|
254
255
|
Create a runtime instance from a compiled graph.
|
|
@@ -275,10 +276,11 @@ class UiPathLangGraphRuntimeFactory:
|
|
|
275
276
|
delegate=base_runtime,
|
|
276
277
|
storage=storage,
|
|
277
278
|
trigger_manager=trigger_manager,
|
|
279
|
+
runtime_id=runtime_id,
|
|
278
280
|
)
|
|
279
281
|
|
|
280
282
|
async def new_runtime(
|
|
281
|
-
self, entrypoint: str, runtime_id: str
|
|
283
|
+
self, entrypoint: str, runtime_id: str, **kwargs
|
|
282
284
|
) -> UiPathRuntimeProtocol:
|
|
283
285
|
"""
|
|
284
286
|
Create a new LangGraph runtime instance.
|
|
@@ -293,12 +295,15 @@ class UiPathLangGraphRuntimeFactory:
|
|
|
293
295
|
# Get shared memory instance
|
|
294
296
|
memory = await self._get_memory()
|
|
295
297
|
|
|
296
|
-
compiled_graph = await self._resolve_and_compile_graph(
|
|
298
|
+
compiled_graph = await self._resolve_and_compile_graph(
|
|
299
|
+
entrypoint, memory, **kwargs
|
|
300
|
+
)
|
|
297
301
|
|
|
298
302
|
return await self._create_runtime_instance(
|
|
299
303
|
compiled_graph=compiled_graph,
|
|
300
304
|
runtime_id=runtime_id,
|
|
301
305
|
entrypoint=entrypoint,
|
|
306
|
+
**kwargs,
|
|
302
307
|
)
|
|
303
308
|
|
|
304
309
|
async def dispose(self) -> None:
|
|
@@ -3,6 +3,7 @@ import os
|
|
|
3
3
|
from typing import Any, AsyncGenerator
|
|
4
4
|
from uuid import uuid4
|
|
5
5
|
|
|
6
|
+
from langchain_core.callbacks import BaseCallbackHandler
|
|
6
7
|
from langchain_core.runnables.config import RunnableConfig
|
|
7
8
|
from langgraph.errors import EmptyInputError, GraphRecursionError, InvalidUpdateError
|
|
8
9
|
from langgraph.graph.state import CompiledStateGraph
|
|
@@ -41,6 +42,7 @@ class UiPathLangGraphRuntime:
|
|
|
41
42
|
graph: CompiledStateGraph[Any, Any, Any, Any],
|
|
42
43
|
runtime_id: str | None = None,
|
|
43
44
|
entrypoint: str | None = None,
|
|
45
|
+
callbacks: list[BaseCallbackHandler] | None = None,
|
|
44
46
|
):
|
|
45
47
|
"""
|
|
46
48
|
Initialize the runtime.
|
|
@@ -53,6 +55,7 @@ class UiPathLangGraphRuntime:
|
|
|
53
55
|
self.graph: CompiledStateGraph[Any, Any, Any, Any] = graph
|
|
54
56
|
self.runtime_id: str = runtime_id or "default"
|
|
55
57
|
self.entrypoint: str | None = entrypoint
|
|
58
|
+
self.callbacks: list[BaseCallbackHandler] = callbacks or []
|
|
56
59
|
self.chat = UiPathChatMessagesMapper()
|
|
57
60
|
self._middleware_node_names: set[str] = self._detect_middleware_nodes()
|
|
58
61
|
|
|
@@ -135,10 +138,17 @@ class UiPathLangGraphRuntime:
|
|
|
135
138
|
if chunk_type == "messages":
|
|
136
139
|
if isinstance(data, tuple):
|
|
137
140
|
message, _ = data
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
141
|
+
try:
|
|
142
|
+
events = self.chat.map_event(message)
|
|
143
|
+
except Exception as e:
|
|
144
|
+
logger.warning(f"Error mapping message event: {e}")
|
|
145
|
+
events = None
|
|
146
|
+
if events:
|
|
147
|
+
for mapped_event in events:
|
|
148
|
+
event = UiPathRuntimeMessageEvent(
|
|
149
|
+
payload=mapped_event,
|
|
150
|
+
)
|
|
151
|
+
yield event
|
|
142
152
|
|
|
143
153
|
# Emit UiPathRuntimeStateEvent for state updates
|
|
144
154
|
elif chunk_type == "updates":
|
|
@@ -153,6 +163,8 @@ class UiPathLangGraphRuntime:
|
|
|
153
163
|
|
|
154
164
|
# Emit state update event for each node
|
|
155
165
|
for node_name, agent_data in data.items():
|
|
166
|
+
if node_name in ("__metadata__",):
|
|
167
|
+
continue
|
|
156
168
|
if isinstance(agent_data, dict):
|
|
157
169
|
state_event = UiPathRuntimeStateEvent(
|
|
158
170
|
payload=serialize_output(agent_data),
|
|
@@ -189,7 +201,7 @@ class UiPathLangGraphRuntime:
|
|
|
189
201
|
"""Build graph execution configuration."""
|
|
190
202
|
graph_config: RunnableConfig = {
|
|
191
203
|
"configurable": {"thread_id": self.runtime_id},
|
|
192
|
-
"callbacks":
|
|
204
|
+
"callbacks": self.callbacks,
|
|
193
205
|
}
|
|
194
206
|
|
|
195
207
|
# Add optional config from environment
|
|
@@ -283,29 +295,9 @@ class UiPathLangGraphRuntime:
|
|
|
283
295
|
|
|
284
296
|
def _is_interrupted(self, state: StateSnapshot) -> bool:
|
|
285
297
|
"""Check if execution was interrupted (static or dynamic)."""
|
|
286
|
-
#
|
|
287
|
-
if
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
# Check for dynamic interrupts (interrupt() inside node)
|
|
291
|
-
if hasattr(state, "tasks"):
|
|
292
|
-
for task in state.tasks:
|
|
293
|
-
if hasattr(task, "interrupts") and task.interrupts:
|
|
294
|
-
return True
|
|
295
|
-
|
|
296
|
-
return False
|
|
297
|
-
|
|
298
|
-
def _get_dynamic_interrupt(self, state: StateSnapshot) -> Interrupt | None:
|
|
299
|
-
"""Get the first dynamic interrupt if any."""
|
|
300
|
-
if not hasattr(state, "tasks"):
|
|
301
|
-
return None
|
|
302
|
-
|
|
303
|
-
for task in state.tasks:
|
|
304
|
-
if hasattr(task, "interrupts") and task.interrupts:
|
|
305
|
-
for interrupt in task.interrupts:
|
|
306
|
-
if isinstance(interrupt, Interrupt):
|
|
307
|
-
return interrupt
|
|
308
|
-
return None
|
|
298
|
+
# An execution is considered interrupted if there are any next nodes (static interrupt)
|
|
299
|
+
# or if there are any dynamic interrupts present
|
|
300
|
+
return bool(state.next) or bool(state.interrupts)
|
|
309
301
|
|
|
310
302
|
async def _create_runtime_result(
|
|
311
303
|
self,
|
|
@@ -334,13 +326,24 @@ class UiPathLangGraphRuntime:
|
|
|
334
326
|
graph_state: StateSnapshot,
|
|
335
327
|
) -> UiPathRuntimeResult:
|
|
336
328
|
"""Create result for suspended execution."""
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
329
|
+
interrupt_map: dict[str, Any] = {}
|
|
330
|
+
|
|
331
|
+
if graph_state.interrupts:
|
|
332
|
+
for interrupt in graph_state.interrupts:
|
|
333
|
+
if isinstance(interrupt, Interrupt):
|
|
334
|
+
# Find which task this interrupt belongs to
|
|
335
|
+
for task in graph_state.tasks:
|
|
336
|
+
if task.interrupts and interrupt in task.interrupts:
|
|
337
|
+
# Only include if this task is still waiting for interrupt resolution
|
|
338
|
+
if task.interrupts and not task.result:
|
|
339
|
+
interrupt_map[interrupt.id] = interrupt.value
|
|
340
|
+
break
|
|
341
|
+
|
|
342
|
+
# If we have dynamic interrupts, return suspended with interrupt map
|
|
343
|
+
# The output is used to create the resume triggers
|
|
344
|
+
if interrupt_map:
|
|
342
345
|
return UiPathRuntimeResult(
|
|
343
|
-
output=
|
|
346
|
+
output=interrupt_map,
|
|
344
347
|
status=UiPathRuntimeStatus.SUSPENDED,
|
|
345
348
|
)
|
|
346
349
|
else:
|
|
@@ -360,7 +363,7 @@ class UiPathLangGraphRuntime:
|
|
|
360
363
|
if next_nodes:
|
|
361
364
|
# Breakpoint is BEFORE these nodes (interrupt_before)
|
|
362
365
|
breakpoint_type = "before"
|
|
363
|
-
breakpoint_node = next_nodes
|
|
366
|
+
breakpoint_node = ", ".join(next_nodes)
|
|
364
367
|
else:
|
|
365
368
|
# Breakpoint is AFTER the last executed node (interrupt_after)
|
|
366
369
|
# Get the last executed node from tasks
|
|
@@ -1,115 +1,222 @@
|
|
|
1
1
|
"""SQLite implementation of UiPathResumableStorageProtocol."""
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
|
-
from typing import cast
|
|
4
|
+
from typing import Any, cast
|
|
5
5
|
|
|
6
6
|
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
|
7
7
|
from pydantic import BaseModel
|
|
8
|
-
from uipath.runtime import
|
|
9
|
-
UiPathApiTrigger,
|
|
10
|
-
UiPathResumeTrigger,
|
|
11
|
-
UiPathResumeTriggerName,
|
|
12
|
-
UiPathResumeTriggerType,
|
|
13
|
-
)
|
|
8
|
+
from uipath.runtime import UiPathResumeTrigger
|
|
14
9
|
|
|
15
10
|
|
|
16
11
|
class SqliteResumableStorage:
|
|
17
|
-
"""SQLite storage for resume triggers."""
|
|
12
|
+
"""SQLite storage for resume triggers and arbitrary kv pairs."""
|
|
18
13
|
|
|
19
14
|
def __init__(
|
|
20
|
-
self,
|
|
15
|
+
self,
|
|
16
|
+
memory: AsyncSqliteSaver,
|
|
21
17
|
):
|
|
22
18
|
self.memory = memory
|
|
23
|
-
self.
|
|
19
|
+
self.rs_table_name = "__uipath_resume_triggers"
|
|
20
|
+
self.kv_table_name = "__uipath_runtime_kv"
|
|
24
21
|
self._initialized = False
|
|
25
22
|
|
|
26
23
|
async def _ensure_table(self) -> None:
|
|
27
|
-
"""Create
|
|
24
|
+
"""Create tables if needed."""
|
|
28
25
|
if self._initialized:
|
|
29
26
|
return
|
|
30
27
|
|
|
31
28
|
await self.memory.setup()
|
|
32
29
|
async with self.memory.lock, self.memory.conn.cursor() as cur:
|
|
33
|
-
|
|
34
|
-
|
|
30
|
+
# Enable WAL mode for high concurrency
|
|
31
|
+
await cur.execute("PRAGMA journal_mode=WAL")
|
|
32
|
+
|
|
33
|
+
await cur.execute(
|
|
34
|
+
f"""
|
|
35
|
+
CREATE TABLE IF NOT EXISTS {self.rs_table_name} (
|
|
35
36
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
folder_key TEXT,
|
|
40
|
-
folder_path TEXT,
|
|
41
|
-
payload TEXT,
|
|
37
|
+
runtime_id TEXT NOT NULL,
|
|
38
|
+
interrupt_id TEXT NOT NULL,
|
|
39
|
+
data TEXT NOT NULL,
|
|
42
40
|
timestamp DATETIME DEFAULT (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc'))
|
|
43
41
|
)
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
self._initialized = True
|
|
42
|
+
"""
|
|
43
|
+
)
|
|
47
44
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
45
|
+
await cur.execute(
|
|
46
|
+
f"""
|
|
47
|
+
CREATE INDEX IF NOT EXISTS idx_{self.rs_table_name}_runtime_id
|
|
48
|
+
ON {self.rs_table_name}(runtime_id)
|
|
49
|
+
"""
|
|
50
|
+
)
|
|
51
51
|
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
else json.dumps(payload)
|
|
52
|
+
await cur.execute(
|
|
53
|
+
f"""
|
|
54
|
+
CREATE TABLE IF NOT EXISTS {self.kv_table_name} (
|
|
55
|
+
runtime_id TEXT NOT NULL,
|
|
56
|
+
namespace TEXT NOT NULL,
|
|
57
|
+
key TEXT NOT NULL,
|
|
58
|
+
value TEXT,
|
|
59
|
+
timestamp DATETIME DEFAULT (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc')),
|
|
60
|
+
PRIMARY KEY (runtime_id, namespace, key)
|
|
62
61
|
)
|
|
63
|
-
|
|
64
|
-
else str(payload)
|
|
62
|
+
"""
|
|
65
63
|
)
|
|
66
64
|
|
|
65
|
+
await self.memory.conn.commit()
|
|
66
|
+
|
|
67
|
+
self._initialized = True
|
|
68
|
+
|
|
69
|
+
async def save_triggers(
|
|
70
|
+
self, runtime_id: str, triggers: list[UiPathResumeTrigger]
|
|
71
|
+
) -> None:
|
|
72
|
+
"""Save resume triggers to database, replacing all existing triggers for this runtime_id."""
|
|
73
|
+
await self._ensure_table()
|
|
74
|
+
|
|
67
75
|
async with self.memory.lock, self.memory.conn.cursor() as cur:
|
|
76
|
+
# Delete all existing triggers for this runtime_id
|
|
68
77
|
await cur.execute(
|
|
69
|
-
f"
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
payload,
|
|
75
|
-
trigger.folder_path,
|
|
76
|
-
trigger.folder_key,
|
|
77
|
-
),
|
|
78
|
+
f"""
|
|
79
|
+
DELETE FROM {self.rs_table_name}
|
|
80
|
+
WHERE runtime_id = ?
|
|
81
|
+
""",
|
|
82
|
+
(runtime_id,),
|
|
78
83
|
)
|
|
84
|
+
|
|
85
|
+
# Insert new triggers
|
|
86
|
+
for trigger in triggers:
|
|
87
|
+
trigger_data = trigger.model_dump()
|
|
88
|
+
trigger_data["payload"] = trigger.payload
|
|
89
|
+
trigger_data["trigger_name"] = trigger.trigger_name
|
|
90
|
+
|
|
91
|
+
await cur.execute(
|
|
92
|
+
f"""
|
|
93
|
+
INSERT INTO {self.rs_table_name}
|
|
94
|
+
(runtime_id, interrupt_id, data)
|
|
95
|
+
VALUES (?, ?, ?)
|
|
96
|
+
""",
|
|
97
|
+
(
|
|
98
|
+
runtime_id,
|
|
99
|
+
trigger.interrupt_id,
|
|
100
|
+
json.dumps(trigger_data),
|
|
101
|
+
),
|
|
102
|
+
)
|
|
79
103
|
await self.memory.conn.commit()
|
|
80
104
|
|
|
81
|
-
async def
|
|
82
|
-
"""Get
|
|
105
|
+
async def get_triggers(self, runtime_id: str) -> list[UiPathResumeTrigger] | None:
|
|
106
|
+
"""Get all triggers for runtime_id from database."""
|
|
83
107
|
await self._ensure_table()
|
|
84
108
|
|
|
85
109
|
async with self.memory.lock, self.memory.conn.cursor() as cur:
|
|
86
|
-
await cur.execute(
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
110
|
+
await cur.execute(
|
|
111
|
+
f"""
|
|
112
|
+
SELECT data
|
|
113
|
+
FROM {self.rs_table_name}
|
|
114
|
+
WHERE runtime_id = ?
|
|
115
|
+
ORDER BY timestamp ASC
|
|
116
|
+
""",
|
|
117
|
+
(runtime_id,),
|
|
118
|
+
)
|
|
119
|
+
results = await cur.fetchall()
|
|
120
|
+
|
|
121
|
+
if not results:
|
|
122
|
+
return None
|
|
93
123
|
|
|
94
|
-
|
|
95
|
-
|
|
124
|
+
triggers = []
|
|
125
|
+
for result in results:
|
|
126
|
+
data_text = cast(str, result[0])
|
|
127
|
+
trigger = UiPathResumeTrigger.model_validate_json(data_text)
|
|
128
|
+
triggers.append(trigger)
|
|
96
129
|
|
|
97
|
-
|
|
98
|
-
|
|
130
|
+
return triggers
|
|
131
|
+
|
|
132
|
+
async def delete_trigger(
|
|
133
|
+
self, runtime_id: str, trigger: UiPathResumeTrigger
|
|
134
|
+
) -> None:
|
|
135
|
+
"""Delete resume trigger from storage."""
|
|
136
|
+
await self._ensure_table()
|
|
137
|
+
|
|
138
|
+
async with self.memory.lock, self.memory.conn.cursor() as cur:
|
|
139
|
+
await cur.execute(
|
|
140
|
+
f"""
|
|
141
|
+
DELETE FROM {self.rs_table_name}
|
|
142
|
+
WHERE runtime_id = ? AND interrupt_id = ?
|
|
143
|
+
""",
|
|
144
|
+
(
|
|
145
|
+
runtime_id,
|
|
146
|
+
trigger.interrupt_id,
|
|
147
|
+
),
|
|
99
148
|
)
|
|
149
|
+
await self.memory.conn.commit()
|
|
150
|
+
|
|
151
|
+
async def set_value(
|
|
152
|
+
self,
|
|
153
|
+
runtime_id: str,
|
|
154
|
+
namespace: str,
|
|
155
|
+
key: str,
|
|
156
|
+
value: Any,
|
|
157
|
+
) -> None:
|
|
158
|
+
"""Save arbitrary key-value pair to database."""
|
|
159
|
+
if not (
|
|
160
|
+
isinstance(value, str)
|
|
161
|
+
or isinstance(value, dict)
|
|
162
|
+
or isinstance(value, BaseModel)
|
|
163
|
+
or value is None
|
|
164
|
+
):
|
|
165
|
+
raise TypeError("Value must be str, dict, BaseModel or None.")
|
|
166
|
+
|
|
167
|
+
await self._ensure_table()
|
|
100
168
|
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
169
|
+
value_text = self._dump_value(value)
|
|
170
|
+
|
|
171
|
+
async with self.memory.lock, self.memory.conn.cursor() as cur:
|
|
172
|
+
await cur.execute(
|
|
173
|
+
f"""
|
|
174
|
+
INSERT INTO {self.kv_table_name} (runtime_id, namespace, key, value)
|
|
175
|
+
VALUES (?, ?, ?, ?)
|
|
176
|
+
ON CONFLICT(runtime_id, namespace, key)
|
|
177
|
+
DO UPDATE SET
|
|
178
|
+
value = excluded.value,
|
|
179
|
+
timestamp = (strftime('%Y-%m-%d %H:%M:%S', 'now', 'utc'))
|
|
180
|
+
""",
|
|
181
|
+
(runtime_id, namespace, key, value_text),
|
|
108
182
|
)
|
|
183
|
+
await self.memory.conn.commit()
|
|
109
184
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
)
|
|
185
|
+
async def get_value(self, runtime_id: str, namespace: str, key: str) -> Any:
|
|
186
|
+
"""Get arbitrary key-value pair from database (scoped by runtime_id + namespace)."""
|
|
187
|
+
await self._ensure_table()
|
|
114
188
|
|
|
115
|
-
|
|
189
|
+
async with self.memory.lock, self.memory.conn.cursor() as cur:
|
|
190
|
+
await cur.execute(
|
|
191
|
+
f"""
|
|
192
|
+
SELECT value
|
|
193
|
+
FROM {self.kv_table_name}
|
|
194
|
+
WHERE runtime_id = ? AND namespace = ? AND key = ?
|
|
195
|
+
LIMIT 1
|
|
196
|
+
""",
|
|
197
|
+
(runtime_id, namespace, key),
|
|
198
|
+
)
|
|
199
|
+
row = await cur.fetchone()
|
|
200
|
+
|
|
201
|
+
if not row:
|
|
202
|
+
return None
|
|
203
|
+
|
|
204
|
+
return self._load_value(cast(str | None, row[0]))
|
|
205
|
+
|
|
206
|
+
def _dump_value(self, value: str | dict[str, Any] | BaseModel | None) -> str | None:
|
|
207
|
+
if value is None:
|
|
208
|
+
return None
|
|
209
|
+
if isinstance(value, BaseModel):
|
|
210
|
+
return "j:" + json.dumps(value.model_dump())
|
|
211
|
+
if isinstance(value, dict):
|
|
212
|
+
return "j:" + json.dumps(value)
|
|
213
|
+
return "s:" + value
|
|
214
|
+
|
|
215
|
+
def _load_value(self, raw: str | None) -> Any:
|
|
216
|
+
if raw is None:
|
|
217
|
+
return None
|
|
218
|
+
if raw.startswith("s:"):
|
|
219
|
+
return raw[2:]
|
|
220
|
+
if raw.startswith("j:"):
|
|
221
|
+
return json.loads(raw[2:])
|
|
222
|
+
return raw
|